Newer
Older
package jrpc
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"reflect"
stdjson "encoding/json"
"github.com/goccy/go-json"
// defaultWriteTimeout bounds a single write when the caller's context carries no
// deadline of its own.
const defaultWriteTimeout = time.Second * 10

// null is the JSON literal `null` in raw (already-encoded) form.
var null = json.RawMessage(`null`)
// jsonrpcMessage is the wire form of a single JSON-RPC message. A value of this type can
// be a JSON-RPC request, notification, successful response or error response. Which one
// it is depends on which fields are set.
//
// Params and Result are kept as raw JSON so their decoding can be deferred.
//
// NOTE(review): methods in this file also read msg.ID and msg.Error (and call
// msg.hasValidID), but no ID or Error field appears in the declaration below — the
// field list looks incomplete; confirm against the upstream definition.
type jsonrpcMessage struct {
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
}
// MakeCall presumably builds a call message for method with the given id and params.
//
// NOTE(review): as written, the composite literal sets no fields and id, method and
// params are never used, so every call returns a pointer to a zero-value message. The
// field assignments look like they were lost; confirm against the upstream source.
func MakeCall(id int, method string, params []any) *JsonRpcMessage {
return &JsonRpcMessage{
}
}
// JsonRpcMessage is an exported alias of jsonrpcMessage, so code outside the package
// can name the message type (it is, for example, MakeCall's return type).
type JsonRpcMessage = jsonrpcMessage
// isNotification is meant to report whether msg is a notification.
//
// NOTE(review): this definition is broken as it stands. It has no closing brace before
// the next func. The condition below (valid ID, no method, no params, and a result or
// error present) looks like a test for a *response* rather than a notification. Lines
// appear to have been lost here; restore from the upstream source.
func (msg *jsonrpcMessage) isNotification() bool {
return msg.hasValidID() && len(msg.Method) == 0 && msg.Params == nil && (msg.Result != nil || msg.Error != nil)
// isSubscribe reports whether the method name ends with subscribeMethodSuffix.
func (msg *jsonrpcMessage) isSubscribe() bool {
	method := msg.Method
	return strings.HasSuffix(method, subscribeMethodSuffix)
}

// isUnsubscribe reports whether the method name ends with unsubscribeMethodSuffix.
func (msg *jsonrpcMessage) isUnsubscribe() bool {
	method := msg.Method
	return strings.HasSuffix(method, unsubscribeMethodSuffix)
}
// namespace returns the first element of the method name when it is split on
// serviceMethodSeparator (at most two parts).
func (msg *jsonrpcMessage) namespace() string {
	return strings.SplitN(msg.Method, serviceMethodSeparator, 2)[0]
}
// NOTE(review): the header of the function owning these lines is missing (they look
// like the body of a String method on jsonrpcMessage); restore from upstream.
// The json.Marshal error is discarded; on failure b is presumably nil, which would
// render as an empty string.
b, _ := json.Marshal(msg)
return string(b)
}
// errorResponse converts err into an error response message via errorMessage.
//
// NOTE(review): the receiver is never consulted, so the result keeps whatever ID
// errorMessage assigned instead of echoing msg's ID — confirm this is intended.
func (msg *jsonrpcMessage) errorResponse(err error) *jsonrpcMessage {
	return errorMessage(err)
}
// response is presumably meant to build a success response carrying result.
//
// NOTE(review): this definition is truncated. enc and err are never used, and the body
// has no closing brace before the next declaration. It also refers to a package `jzon`
// that is not among the imports shown in this file; confirm where it comes from and
// restore the missing lines from upstream.
func (msg *jsonrpcMessage) response(result any) *jsonrpcMessage {
// do a funny marshaling
enc, err := jzon.Marshal(result)
// jsonError is the "error" member of a JSON-RPC error response: a numeric code, a
// human-readable message and optional extra data (omitted from the encoding when nil).
type jsonError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Data    any    `json:"data,omitempty"`
}

// JsonError is an exported alias of jsonError.
type JsonError = jsonError

// Error implements the error interface. When no message was supplied it falls back to
// a generic text that still carries the error code, so the error never renders as an
// empty string.
func (err *jsonError) Error() string {
	if err.Message == "" {
		return fmt.Sprintf("json-rpc error %d", err.Code)
	}
	return err.Message
}

// ErrorCode returns the JSON-RPC error code.
func (err *jsonError) ErrorCode() int {
	return err.Code
}

// ErrorData returns the optional data payload (nil when absent).
func (err *jsonError) ErrorData() any {
	return err.Data
}
// error message produces json rpc message with error message
func errorMessage(err error) *jsonrpcMessage {
msg := &jsonrpcMessage{
ID: NewNullIDPtr(),
Error: &jsonError{
Code: defaultErrorCode,
Message: err.Error(),
}}
ec, ok := err.(Error)
if ok {
msg.Error.Code = ec.ErrorCode()
}
de, ok := err.(DataError)
if ok {
msg.Error.Data = de.ErrorData()
}
return msg
}
// Conn is a subset of the methods of net.Conn which are sufficient for ServerCodec.
// Any net.Conn satisfies it. SetWriteDeadline is called before encoding an outgoing
// message so that a write cannot block indefinitely.
type Conn interface {
io.ReadWriteCloser
SetWriteDeadline(time.Time) error
}
// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has
// support for parsing arguments and serializing (result) objects.
//
// NOTE(review): code in this file sets `conn:` in the jsonCodec literal and calls
// c.conn.SetWriteDeadline, but no conn field is declared below — the field list looks
// incomplete; confirm against upstream.
type jsonCodec struct {
remote string // remote address, set from conn's RemoteAddr when available
closer sync.Once // ensures closeCh is closed only once
closeFunc func() error
closeCh chan any // closed on Close
decode func(v any) error // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder
encode func(v any) error // encoder to allow multiple transports
}
// NewFuncCodec creates a codec which uses the given functions to read and write. If conn
// implements ConnRemoteAddr, log messages will use it to include the remote address of
// the connection.
//
// NOTE(review): this definition is incomplete as shown. The function header with the
// conn parameter, the opening of the jsonCodec composite literal, and whatever
// statement binds `ra` are missing, and `codec` is never declared. Restore from upstream.
encode, decode func(v any) error,
closeFunc func() error,
) ServerCodec {
closeFunc: closeFunc,
closeCh: make(chan any),
encode: encode,
decode: decode,
conn: conn,
codec.remote = ra.RemoteAddr()
}
return codec
}
// NewCodec creates a codec on the given connection. If conn implements ConnRemoteAddr, log
// messages will use it to include the remote address of the connection.
//
// NOTE(review): this definition is incomplete as shown. The NewFuncCodec call at the
// end is never closed, and the two lines after it (returning PeerInfo built from
// c.remote) look like the body of a separate method whose header is missing. It also
// uses a package `jzon` that is not among the imports shown in this file. Restore from
// upstream.
func NewCodec(conn Conn) ServerCodec {
// encr writes v followed by a newline to conn, using a stream borrowed from jzon and
// returned when done. Flush's return value is discarded; the stream's Error field is
// checked afterwards instead.
encr := func(v any) error {
enc := jzon.BorrowStream(conn)
defer jzon.ReturnStream(enc)
enc.WriteVal(v)
enc.WriteRaw("\n")
enc.Flush()
if enc.Error != nil {
return enc.Error
}
return nil
}
// TODO:
// Other JSON decoders turned out to be incompatible with the test suite, most likely
// because of how EOF is handled, so the standard library decoder is used here.
dec := stdjson.NewDecoder(conn)
// Decode numbers as json.Number instead of float64 so large integers keep precision.
dec.UseNumber()
return NewFuncCodec(conn, encr, dec.Decode, func() error {
// This returns "ipc" because all other built-in transports have a separate codec type.
return PeerInfo{Transport: "ipc", RemoteAddr: c.remote}
}
// ReadBatch decodes the next JSON value from the input stream and splits it into
// messages. batch reports whether the value was a JSON array. JSON null entries are
// swapped for zero-value messages so they are treated like any other invalid message.
// A decode error is returned as-is, with no messages.
func (c *jsonCodec) ReadBatch() (messages []*jsonrpcMessage, batch bool, err error) {
	// Pull one complete JSON value; the decoder rejects syntactically invalid input.
	var raw json.RawMessage
	if decErr := c.decode(&raw); decErr != nil {
		return nil, false, decErr
	}

	messages, batch = parseMessage(raw)
	for i := range messages {
		if messages[i] == nil {
			messages[i] = &jsonrpcMessage{}
		}
	}
	return messages, batch, nil
}
// NOTE(review): the header of the method owning the lines below is missing. From the
// body it presumably takes (ctx context.Context, v any) and returns error; confirm
// against upstream.
// Writes are serialized by encMu, which guards the encoder.
c.encMu.Lock()
defer c.encMu.Unlock()
deadline, ok := ctx.Deadline()
if !ok {
// ctx has no deadline: bound the write by defaultWriteTimeout from now.
deadline = time.Now().Add(defaultWriteTimeout)
}
// NOTE(review): the error returned by SetWriteDeadline is discarded.
c.conn.SetWriteDeadline(deadline)
return c.encode(v)
}
// NOTE(review): another method header is missing here; the line below returns
// closeCh, the channel that is closed on Close.
return c.closeCh
}
// parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error
// checks in this function because the raw message has already been syntax-checked when it
// is called. Any non-JSON-RPC messages in the input return the zero value of
// jsonrpcMessage.
//
// NOTE(review): the single-message branch is incomplete. After msgs is created, nothing
// decodes into it or returns it, and the `if` is not closed before the batch-decoding
// code that follows. Lines appear to have been lost; restore from upstream.
func parseMessage(raw json.RawMessage) ([]*jsonrpcMessage, bool) {
if !isBatch(raw) {
msgs := []*jsonrpcMessage{{}}
// TODO:
// for some reason other json decoders are incompatible with our test suite
// pretty sure its how we handle EOFs and stuff
dec := stdjson.NewDecoder(bytes.NewReader(raw))
dec.Token() // skip '['
var msgs []*jsonrpcMessage
for dec.More() {
msgs = append(msgs, new(jsonrpcMessage))
// Errors are ignored on purpose (see above). Because the target is the pointer
// slot, a JSON null element sets it to nil; ReadBatch replaces such entries.
dec.Decode(&msgs[len(msgs)-1])
}
return msgs, true
}
// isBatch reports whether raw holds a JSON array, i.e. whether its first non-whitespace
// byte is '['. Leading insignificant whitespace (space, horizontal tab, line feed and
// carriage return, as defined by RFC 4627 / RFC 8259) is skipped. Empty or all-whitespace
// input is not a batch.
func isBatch(raw json.RawMessage) bool {
	for _, c := range raw {
		switch c {
		case ' ', '\t', '\n', '\r':
			// skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt)
			continue
		}
		return c == '['
	}
	return false
}
// parsePositionalArguments tries to parse the given args to an array of values with the
// given types. It returns the parsed values or an error when the args could not be
// parsed. Missing optional arguments are returned as reflect.Zero values.
//
// NOTE(review): the switch below is missing its case labels. The parseArgumentArray
// error check also has no body or closing brace before `default:`. Lines appear to have
// been lost; restore from upstream.
func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, error) {
var args []reflect.Value
switch {
var err error
if args, err = parseArgumentArray(rawArgs, types); err != nil {
default:
return nil, errors.New("non-array args")
}
// Fill missing trailing args with zero values; only pointer-typed parameters may be
// omitted, any other missing argument is an error.
for i := len(args); i < len(types); i++ {
if types[i].Kind() != reflect.Ptr {
return nil, fmt.Errorf("missing value for required argument %d", i)
}
args = append(args, reflect.Zero(types[i]))
}
return args, nil
}
func parseArgumentArray(p json.RawMessage, types []reflect.Type) ([]reflect.Value, error) {
dec := jzon.BorrowIterator(p)
defer jzon.ReturnIterator(dec)
if i >= len(types) {
return args, fmt.Errorf("too many arguments, want at most %d", len(types))
}
argval := reflect.New(types[i])
dec.ReadVal(argval.Interface())
if err := dec.Error; err != nil {
return args, fmt.Errorf("invalid argument %d: %v", i, err)
}
if argval.IsNil() && types[i].Kind() != reflect.Ptr {
return args, fmt.Errorf("missing value for required argument %d", i)
}
args = append(args, argval.Elem())
}