diff --git a/client.go b/client.go index 485e4cff3771aacb85768dd23372b5f5310feaa2..5102e9d09a917ad31db11b022a58fd924820db21 100644 --- a/client.go +++ b/client.go @@ -23,7 +23,6 @@ import ( "fmt" "net/url" "reflect" - "strconv" "sync/atomic" "time" @@ -223,9 +222,9 @@ func initClient(conn ServerCodec, r Router) *Client { return c } -func (c *Client) nextID() json.RawMessage { +func (c *Client) nextID() *ID { id := atomic.AddUint32(&c.idCounter, 1) - return strconv.AppendUint(nil, uint64(id), 10) + return NewNumberIDPtr(int32(id)) } // SupportedModules calls the rpc_modules method, retrieving the list of @@ -286,7 +285,7 @@ func (c *Client) CallContext(ctx context.Context, result any, method string, arg if err != nil { return err } - op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} + op := &requestOp{ids: []json.RawMessage{msg.ID.RawMessage()}, resp: make(chan *jsonrpcMessage, 1)} if c.isHTTP { err = c.sendHTTP(ctx, op, msg) @@ -347,8 +346,8 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { return err } msgs[i] = msg - op.ids[i] = msg.ID - byID[string(msg.ID)] = i + op.ids[i] = msg.ID.RawMessage() + byID[string(msg.ID.RawMessage())] = i } var err error @@ -368,7 +367,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { // Find the element corresponding to this response. // The element is guaranteed to be present because dispatch // only sends valid IDs to our channel. - elem := &b[byID[string(resp.ID)]] + elem := &b[byID[string(resp.ID.RawMessage())]] if resp.Error != nil { elem.Error = resp.Error continue @@ -399,7 +398,7 @@ func (c *Client) Notify(ctx context.Context, method string, args ...any) error { } func (c *Client) newMessage(method string, paramsIn ...any) (*jsonrpcMessage, error) { - msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method} + msg := &jsonrpcMessage{ID: c.nextID(), Method: method} if paramsIn != nil { // prevent sending "params":null var err error if msg.Params, err = json.Marshal(paramsIn); err != nil { diff --git a/handler.go b/handler.go index 1c4265f6483bf465df77d8c60550dbaad138a131..31aa5724a2c00f724ee7c52846e13266a7c84834 100644 --- a/handler.go +++ b/handler.go @@ -32,21 +32,20 @@ import ( // // The entry points for incoming messages are: // -// h.handleMsg(message) -// h.handleBatch(message) +// h.handleMsg(message) +// h.handleBatch(message) // // Outgoing calls use the requestOp struct. Register the request before sending it // on the connection: // -// op := &requestOp{ids: ...} -// h.addRequestOp(op) +// op := &requestOp{ids: ...} +// h.addRequestOp(op) // // Now send the request, then wait for the reply to be delivered through handleMsg: // -// if err := op.wait(...); err != nil { -// h.removeRequestOp(op) // timeout, etc. -// } -// +// if err := op.wait(...); err != nil { +// h.removeRequestOp(op) // timeout, etc. +// } type handler struct { reg Router respWait map[string]*requestOp // active client requests @@ -191,7 +190,7 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { return true case msg.isResponse(): h.handleResponse(msg) - h.log.Trace().Str("reqid", string(msg.ID)).Dur("duration", start.Since(start)).Msg("Handled RPC response") + h.log.Trace().Str("reqid", string(msg.ID.RawMessage())).Dur("duration", start.Since(start)).Msg("Handled RPC response") return true default: return false @@ -200,17 +199,17 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { // handleResponse processes method call responses. func (h *handler) handleResponse(msg *jsonrpcMessage) { - op := h.respWait[string(msg.ID)] + op := h.respWait[string(msg.ID.RawMessage())] if op == nil { - h.log.Debug().Str("reqid", string(msg.ID)).Msg("Unsolicited RPC response") + h.log.Debug().Str("reqid", string(msg.ID.RawMessage())).Msg("Unsolicited RPC response") return } - delete(h.respWait, string(msg.ID)) + delete(h.respWait, string(msg.ID.RawMessage())) op.resp <- msg } // handleCallMsg executes a call message and returns the answer. -// TODO: export prometheus metrics maybe? +// TODO: export prometheus metrics maybe? also fix logging func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { // start := NewTimer() switch { diff --git a/http.go b/http.go index e17a97d137baab55da1728d6007f96f72b548811..cbfa3f5277fc88740b9db6b14d497c1aa9174156 100644 --- a/http.go +++ b/http.go @@ -242,10 +242,9 @@ func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { buf := new(bytes.Buffer) buf.Grow(64) json.NewEncoder(buf).Encode(jsonrpcMessage{ - Version: "2.0", - ID: []byte(id), - Method: method_up, - Params: param, + ID: NewStringIDPtr(id), + Method: method_up, + Params: param, }) conn.Reader = buf } else { diff --git a/json.go b/json.go index c42a580ba784bfca530707db40f43d40cd200274..a22414bd697e4ddf927bf2f36a27bf7511257271 100644 --- a/json.go +++ b/json.go @@ -24,7 +24,6 @@ import ( "fmt" "io" "reflect" - "strconv" "strings" "sync" "time" @@ -44,16 +43,11 @@ const ( var null = json.RawMessage("null") -type subscriptionResult struct { - ID string `json:"subscription"` - Result json.RawMessage `json:"result,omitempty"` -} - // A value of this type can a JSON-RPC request, notification, successful response or // error response. Which one it is depends on the fields. type jsonrpcMessage struct { - Version string `json:"jsonrpc,omitempty"` - ID json.RawMessage `json:"id,omitempty"` + Version version `json:"jsonrpc,omitempty"` + ID *ID `json:"id,omitempty"` Method string `json:"method,omitempty"` Params json.RawMessage `json:"params,omitempty"` Error *jsonError `json:"error,omitempty"` @@ -62,8 +56,7 @@ type jsonrpcMessage struct { func MakeCall(id int, method string, params []any) *JsonRpcMessage { return &JsonRpcMessage{ - Version: vsn, - ID: []byte(strconv.Itoa(id)), + ID: NewNumberIDPtr(int32(id)), } } @@ -82,7 +75,7 @@ func (msg *jsonrpcMessage) isResponse() bool { } func (msg *jsonrpcMessage) hasValidID() bool { - return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '[' + return msg.ID != nil && !msg.ID.null } func (msg *jsonrpcMessage) isSubscribe() bool { @@ -105,7 +98,9 @@ func (msg *jsonrpcMessage) String() string { func (msg *jsonrpcMessage) errorResponse(err error) *jsonrpcMessage { resp := errorMessage(err) - resp.ID = msg.ID + if resp.ID != nil { + resp.ID = msg.ID + } return resp } @@ -113,16 +108,14 @@ func (msg *jsonrpcMessage) response(result any) *jsonrpcMessage { // do a funny marshaling enc, err := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(result) if err != nil { - // TODO: wrap with 'internal server error' return msg.errorResponse(err) } - return &jsonrpcMessage{Version: vsn, ID: msg.ID, Result: enc} + return &jsonrpcMessage{ID: msg.ID, Result: enc} } func errorMessage(err error) *jsonrpcMessage { msg := &jsonrpcMessage{ - ID: null, - Version: vsn, + ID: NewNullIDPtr(), Error: &jsonError{ Code: defaultErrorCode, Message: err.Error(), diff --git a/protocol.go b/protocol.go index 5053dc41ee77cccf38a191be80f9b43081262ca7..9fa8c7e1410039515a67e1737e55b21138e86b22 100644 --- a/protocol.go +++ b/protocol.go @@ -10,14 +10,18 @@ import ( type HandlerFunc func(w ResponseWriter, r *Request) -func (fn HandlerFunc) ServeRPC(w ResponseWriter, r *Request) { - (fn)(w, r) -} - type Handler interface { ServeRPC(w ResponseWriter, r *Request) } +type ResponseWriter interface { + Send(v any, err error) error +} + +func (fn HandlerFunc) ServeRPC(w ResponseWriter, r *Request) { + (fn)(w, r) +} + type Request struct { ctx context.Context msg jsonrpcMessage @@ -29,10 +33,9 @@ func NewRequest(ctx context.Context, id string, method string, params any) *Requ r := &Request{ctx: ctx} pms, _ := json.Marshal(params) r.msg = jsonrpcMessage{ - Version: "2.0", - ID: []byte(id), - Method: method, - Params: pms, + ID: NewStringIDPtr(id), + Method: method, + Params: pms, } return r } @@ -99,10 +102,6 @@ func (r *Request) Msg() jsonrpcMessage { return r.msg } -type ResponseWriter interface { - Send(v any, err error) error -} - type ResponseWriterIo struct { r *Request w io.Writer diff --git a/wire.go b/wire.go new file mode 100644 index 0000000000000000000000000000000000000000..8358d5331e4c2bd52237355a25cdc56c8ef0ca2b --- /dev/null +++ b/wire.go @@ -0,0 +1,175 @@ +package jrpc + +import ( + "encoding/json" + "fmt" +) + +// Version represents a JSON-RPC version. +const Version = "2.0" + +// version is a special 0 sized struct that encodes as the jsonrpc version tag. +// +// It will fail during decode if it is not the correct version tag in the stream. +type version struct{} + +// compile time check whether the version implements a json.Marshaler and json.Unmarshaler interfaces. +var ( + _ json.Marshaler = (*version)(nil) + _ json.Unmarshaler = (*version)(nil) +) + +// MarshalJSON implements json.Marshaler. +func (version) MarshalJSON() ([]byte, error) { + return json.Marshal(Version) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (version) UnmarshalJSON(data []byte) error { + version := "" + if err := json.Unmarshal(data, &version); err != nil { + return fmt.Errorf("failed to Unmarshal: %w", err) + } + if version != Version { + return fmt.Errorf("invalid RPC version %v", version) + } + return nil +} + +// ID is a Request identifier. +// +// Only one of either the Name or Number members will be set, using the +// number form if the Name is the empty string. +// alternatively, ID can be null +type ID struct { + name string + number int32 + + null bool +} + +// compile time check whether the ID implements a fmt.Formatter, json.Marshaler and json.Unmarshaler interfaces. +var ( + _ fmt.Formatter = (*ID)(nil) + _ json.Marshaler = (*ID)(nil) + _ json.Unmarshaler = (*ID)(nil) +) + +// NewNumberID returns a new number request ID. +func NewNumberID(v int32) ID { return *NewNumberIDPtr(v) } + +// NewStringID returns a new string request ID. +func NewStringID(v string) ID { return *NewStringIDPtr(v) } + +// NewStringID returns a new string request ID. +func NewNullID() ID { return *NewNullIDPtr() } + +func NewNumberIDPtr(v int32) *ID { return &ID{number: v} } +func NewStringIDPtr(v string) *ID { return &ID{name: v} } +func NewNullIDPtr() *ID { return &ID{null: true} } + +// Format writes the ID to the formatter. +// +// If the rune is q the representation is non ambiguous, +// string forms are quoted, number forms are preceded by a #. +func (id ID) Format(f fmt.State, r rune) { + numF, strF := `%d`, `%s` + if r == 'q' { + numF, strF = `#%d`, `%q` + } + + switch { + case id.name != "": + fmt.Fprintf(f, strF, id.name) + default: + fmt.Fprintf(f, numF, id.number) + } + id.null = false +} + +// get the raw message +func (id *ID) RawMessage() json.RawMessage { + if id.null { + return null + } + if id.name != "" { + ans, err := json.Marshal(id.name) + if err == nil { + return ans + } + } + ans, err := json.Marshal(id.number) + if err == nil { + return ans + } + return nil +} + +// MarshalJSON implements json.Marshaler. +func (id *ID) MarshalJSON() ([]byte, error) { + if id == nil { + return null, nil + } + if id.null { + return null, nil + } + if id.name != "" { + return json.Marshal(id.name) + } + return json.Marshal(id.number) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (id *ID) UnmarshalJSON(data []byte) error { + *id = ID{} + if err := json.Unmarshal(data, &id.number); err == nil { + return nil + } + if err := json.Unmarshal(data, &id.name); err == nil { + return nil + } + id.null = true + return nil +} + +// wireRequest is sent to a server to represent a Call or Notify operaton. +type wireRequest struct { + // VersionTag is always encoded as the string "2.0" + VersionTag version `json:"jsonrpc"` + // Method is a string containing the method name to invoke. + Method string `json:"method"` + // Params is either a struct or an array with the parameters of the method. + Params *json.RawMessage `json:"params,omitempty"` + // The id of this request, used to tie the Response back to the request. + // Will be either a string or a number. If not set, the Request is a notify, + // and no response is possible. + ID *ID `json:"id,omitempty"` +} + +// wireResponse is a reply to a Request. +// +// It will always have the ID field set to tie it back to a request, and will +// have either the Result or Error fields set depending on whether it is a +// success or failure wireResponse. +type wireResponse struct { + // VersionTag is always encoded as the string "2.0" + VersionTag version `json:"jsonrpc"` + // Result is the response value, and is required on success. + Result *json.RawMessage `json:"result,omitempty"` + // Error is a structured error response if the call fails. + Error *Error `json:"error,omitempty"` + // ID must be set and is the identifier of the Request this is a response to. + ID *ID `json:"id,omitempty"` +} + +// combined has all the fields of both Request and Response. +// +// We can decode this and then work out which it is. +type combined struct { + VersionTag version `json:"jsonrpc"` + ID *ID `json:"id,omitempty"` + Method string `json:"method"` + Params *json.RawMessage `json:"params,omitempty"` + Result *json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` +}