diff --git a/contrib/codecs/http/client.go b/contrib/codecs/http/client.go index 57097298bb29e2f81f330e030e6bcf79d9fe4265..8a36277dadce0a13642062a2c4ced28a56e0fe7e 100644 --- a/contrib/codecs/http/client.go +++ b/contrib/codecs/http/client.go @@ -74,7 +74,10 @@ func (c *Client) SetHeader(key string, value string) { } func (c *Client) Do(ctx context.Context, result any, method string, params any) error { - req := codec.NewRequestInt(ctx, int(c.id.Add(1)), method, params) + req, err := codec.NewRequest(ctx, codec.NewId(c.id.Add(1)), method, params) + if err != nil { + return err + } resp, err := c.post(req) if err != nil { return err @@ -129,7 +132,10 @@ func (c *Client) post(req *codec.Request) (*http.Response, error) { } func (c *Client) Notify(ctx context.Context, method string, params any) error { - req := codec.NewNotification(ctx, method, params) + req, err := codec.NewNotification(ctx, method, params) + if err != nil { + return err + } resp, err := c.post(req) if err != nil { return err @@ -142,13 +148,18 @@ func (c *Client) BatchCall(ctx context.Context, b ...*codec.BatchElem) error { reqs := make([]*codec.Request, len(b)) ids := make(map[int]int, len(b)) for idx, v := range b { + var rid *codec.ID if v.IsNotification { - reqs = append(reqs, codec.NewRequest(ctx, "", v.Method, v.Params)) } else { id := int(c.id.Add(1)) ids[idx] = id - reqs = append(reqs, codec.NewRequestInt(ctx, id, v.Method, v.Params)) + rid = codec.NewNumberIDPtr(int64(id)) + } + req, err := codec.NewRequest(ctx, rid, v.Method, v.Params) + if err != nil { + return err } + reqs = append(reqs, req) } dat, err := json.Marshal(reqs) if err != nil { diff --git a/contrib/codecs/http/codec.go b/contrib/codecs/http/codec.go index fee62d2565437c32f69696158c04c331d5d11d1d..0b4d8a61ef4e0b15bc1eca33b05f54288a9df13e 100644 --- a/contrib/codecs/http/codec.go +++ b/contrib/codecs/http/codec.go @@ -88,7 +88,8 @@ func (r *Codec) doReadGet() (msgs json.RawMessage, err error) { if id == "" { id = "1" } - req := codec.NewRequest(r.ctx, id, method_up, json.RawMessage(param)) + + req := codec.NewRawRequest(r.ctx, codec.NewId(id), method_up, json.RawMessage(param)) return req.MarshalJSON() } @@ -105,7 +106,7 @@ func (r *Codec) doReadRPC() (msgs json.RawMessage, err error) { if err != nil { return nil, err } - req := codec.NewRequest(r.ctx, id, method_up, json.RawMessage(data)) + req := codec.NewRawRequest(r.ctx, codec.NewId(id), method_up, json.RawMessage(data)) return req.MarshalJSON() } diff --git a/contrib/codecs/inproc/client.go b/contrib/codecs/inproc/client.go index fac266ad0fe94c0243ce3c96cb1e516f5326c313..f6a77ea8e7177b8f41ccd543acff92fa069f4011 100644 --- a/contrib/codecs/inproc/client.go +++ b/contrib/codecs/inproc/client.go @@ -51,14 +51,16 @@ func (c *Client) listen() error { id := v.ID.Number() if id == 0 { if c.handler != nil { - c.handler.ServeRPC(nil, codec.NewRequestFromRaw(c.c.ctx, &codec.RequestMarshaling{ - Method: v.Method, - Params: v.Params, - Peer: codec.PeerInfo{ - Transport: "ipc", - RemoteAddr: "", - }, - })) + req := codec.NewRawRequest(c.c.ctx, + nil, + v.Method, + v.Params, + ) + req.Peer = codec.PeerInfo{ + Transport: "ipc", + RemoteAddr: "", + } + c.handler.ServeRPC(nil, req) } continue } @@ -73,19 +75,19 @@ func (c *Client) listen() error { } func (c *Client) Do(ctx context.Context, result any, method string, params any) error { - if ctx == nil { - ctx = context.Background() - } id := c.p.NextId() - req := codec.NewRequestInt(ctx, id, method, params) + req, err := codec.NewRequest(ctx, codec.NewId(id), method, params) + if err != nil { + return err + } fwd, err := json.Marshal(req) if err != nil { return err } select { case c.c.msgs <- fwd: - case <-ctx.Done(): - return ctx.Err() + case <-req.Context().Done(): + return req.Context().Err() } ans, err := c.p.Ask(req.Context(), id) if err != nil { @@ -101,16 +103,16 @@ func (c *Client) Do(ctx context.Context, result any, method string, params any) } func (c *Client) BatchCall(ctx context.Context, b ...*codec.BatchElem) error { - if ctx == nil { - ctx = context.Background() - } buf := new(bytes.Buffer) enc := json.NewEncoder(buf) reqs := make([]*codec.Request, 0, len(b)) ids := make([]int, 0, len(b)) for _, v := range b { id := c.p.NextId() - req := codec.NewRequestInt(ctx, id, v.Method, v.Params) + req, err := codec.NewRequest(ctx, codec.NewId(id), v.Method, v.Params) + if err != nil { + return err + } ids = append(ids, id) reqs = append(reqs, req) } @@ -156,7 +158,10 @@ func (c *Client) Notify(ctx context.Context, method string, params any) error { if ctx == nil { ctx = context.Background() } - req := codec.NewRequest(ctx, "", method, params) + req, err := codec.NewRequest(ctx, nil, method, params) + if err != nil { + return err + } fwd, err := json.Marshal(req) if err != nil { return err diff --git a/contrib/codecs/rdwr/client.go b/contrib/codecs/rdwr/client.go index 2188a4e39521aa470c441b7007e67b96dd0e78a2..c3d01abd214a528beb4d59269102f37edb323815 100644 --- a/contrib/codecs/rdwr/client.go +++ b/contrib/codecs/rdwr/client.go @@ -79,11 +79,13 @@ func (c *Client) listen() error { // writer should only be allowed to send notifications // reader should contain the message above // the context is the client context - handler.ServeRPC(nil, codec.NewRequestFromRaw(c.ctx, &codec.RequestMarshaling{ - Method: v.Method, - Params: v.Result, - Peer: c.handlerPeer, - })) + req := codec.NewRawRequest(c.ctx, + nil, + v.Method, + v.Params, + ) + req.Peer = c.handlerPeer + handler.ServeRPC(nil, req) continue } var err error @@ -97,16 +99,16 @@ func (c *Client) listen() error { } func (c *Client) Do(ctx context.Context, result any, method string, params any) error { - if ctx == nil { - ctx = context.Background() - } id := c.p.NextId() - req := codec.NewRequestInt(ctx, id, method, params) + req, err := codec.NewRequest(ctx, codec.NewId(id), method, params) + if err != nil { + return err + } fwd, err := json.Marshal(req) if err != nil { return err } - err = c.writeContext(ctx, fwd) + err = c.writeContext(req.Context(), fwd) if err != nil { return err } @@ -133,7 +135,10 @@ func (c *Client) BatchCall(ctx context.Context, b ...*codec.BatchElem) error { ids := make([]int, 0, len(b)) for _, v := range b { id := c.p.NextId() - req := codec.NewRequestInt(ctx, id, v.Method, v.Params) + req, err := codec.NewRequest(ctx, codec.NewId(id), v.Method, v.Params) + if err != nil { + return err + } ids = append(ids, id) reqs = append(reqs, req) } @@ -174,7 +179,10 @@ func (c *Client) Notify(ctx context.Context, method string, params any) error { if ctx == nil { ctx = context.Background() } - req := codec.NewRequest(ctx, "", method, params) + req, err := codec.NewRequest(ctx, nil, method, params) + if err != nil { + return err + } fwd, err := json.Marshal(req) if err != nil { return err diff --git a/contrib/openrpc/out/example/main.go b/contrib/openrpc/out/example/main.go index b143a321495ec306810cc01594f00824c9d1023c..dca31e123e0ed5a68d7ffee8b92517c3c647c651 100644 --- a/contrib/openrpc/out/example/main.go +++ b/contrib/openrpc/out/example/main.go @@ -8,7 +8,7 @@ import ( "gfx.cafe/open/jrpc" "gfx.cafe/open/jrpc/contrib/codecs" "gfx.cafe/open/jrpc/contrib/jmux" - "gfx.cafe/open/jrpc/openrpc/out" + "gfx.cafe/open/jrpc/contrib/openrpc/out" ) func main() { diff --git a/pkg/codec/errors.go b/pkg/codec/errors.go index 5efd2313a3b49bef61151426170da3dcf5ac596d..6f13598f44ad94dfe016d8c7117b80d8ae28a98b 100644 --- a/pkg/codec/errors.go +++ b/pkg/codec/errors.go @@ -2,6 +2,7 @@ package codec import ( "encoding/json" + "errors" "fmt" "github.com/go-faster/jx" @@ -24,6 +25,10 @@ const ( ErrorCodeJrpc = -42000 ) +var ( + ErrIllegalExtraField = errors.New("invalid extra field") +) + // Error wraps RPC errors, which contain an error code in addition to the message. type Error interface { Error() string // returns the message diff --git a/pkg/codec/json.go b/pkg/codec/json.go index 646fece7b06dffa044ebcd85ca46ff79a44239ae..706219ada95a78e5b71ba3f7dfb293b6fd395311 100644 --- a/pkg/codec/json.go +++ b/pkg/codec/json.go @@ -1,12 +1,12 @@ package codec import ( + "bytes" "encoding/json" + "fmt" "strconv" "github.com/go-faster/jx" - - gojson "github.com/goccy/go-json" ) var Null = json.RawMessage("null") @@ -15,22 +15,109 @@ func NewNull() json.RawMessage { return json.RawMessage("null") } +// RequestField is an idea borrowed from sourcegraphs implementation. +type RequestField struct { + Name string + Value json.RawMessage +} + // A value of this type can a JSON-RPC request, notification, successful response or // error response. Which one it is depends on the fields. type Message struct { - Version Version `json:"jsonrpc,omitempty"` - ID *ID `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Params json.RawMessage `json:"params,omitempty"` - Result json.RawMessage `json:"result,omitempty"` + ID *ID `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error error `json:"error,omitempty"` - Error *JsonError `json:"error,omitempty"` + ExtraFields []RequestField `json:"-"` } -func (m *Message) MarshalJSON() ([]byte, error) { - var enc jx.Encoder +func (m *Message) UnmarshalJSON(xs []byte) error { + var dec jx.Decoder + dec.ResetBytes(xs) + err := dec.Obj(func(d *jx.Decoder, key string) error { + switch key { + default: + val, err := d.Raw() + if err != nil { + return err + } + xs := make(json.RawMessage, len(val)) + copy(xs, val) + m.ExtraFields = append(m.ExtraFields, RequestField{ + Name: key, + Value: xs, + }) + case "jsonrpc": + value, err := d.Str() + if err != nil { + return err + } + if value != VersionString { + return NewInvalidRequestError("Invalid Version") + } + case "id": + raw, err := d.Raw() + if err != nil { + return err + } + id := &ID{} + err = id.UnmarshalJSON(raw) + m.ID = id + if err != nil { + return err + } + case "method": + value, err := d.Str() + if err != nil { + return err + } + m.Method = value + case "params": + val, err := d.Raw() + if err != nil { + return err + } + if len(m.Params) >= len(val) { + m.Params = m.Params[len(val):] + } else { + m.Params = make(json.RawMessage, len(val)) + } + copy(m.Params, val) + case "result": + val, err := d.Raw() + if err != nil { + return err + } + if len(m.Result) >= len(val) { + m.Result = m.Result[len(val):] + } else { + m.Result = make(json.RawMessage, len(val)) + } + copy(m.Result, val) + case "error": + val, err := d.Raw() + if err != nil { + return err + } + m.Error = &JsonError{} + err = json.Unmarshal(val, m.Error) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return err + } + return nil +} + +func MarshalMessage(m *Message, enc *jx.Encoder) error { // use encoder - enc.Obj(func(e *jx.Encoder) { + fail := enc.Obj(func(e *jx.Encoder) { e.Field("jsonrpc", func(e *jx.Encoder) { e.Str("2.0") }) @@ -44,11 +131,16 @@ func (m *Message) MarshalJSON() ([]byte, error) { e.Str(m.Method) }) } + for _, v := range m.ExtraFields { + e.Field(v.Name, func(e *jx.Encoder) { + e.Raw(v.Value) + }) + } if m.Error != nil { e.Field("error", func(e *jx.Encoder) { - xs, _ := json.Marshal(m.Error) - e.Raw(xs) + EncodeError(e, m.Error) }) + return } if len(m.Params) != 0 { e.Field("params", func(e *jx.Encoder) { @@ -60,15 +152,27 @@ func (m *Message) MarshalJSON() ([]byte, error) { e.Raw(m.Result) }) } + }) + if fail { + return fmt.Errorf("jx encoding error") + } // output - return enc.Bytes(), nil + return nil } -func MakeCall(id int, method string, params []any) *Message { - return &Message{ - ID: NewNumberIDPtr(int64(id)), +func (m *Message) MarshalJSON() ([]byte, error) { + buf := &bytes.Buffer{} + enc := jx.NewStreamingEncoder(buf, 4096) + err := MarshalMessage(m, enc) + if err != nil { + return nil, err + } + err = enc.Close() + if err != nil { + return nil, err } + return buf.Bytes(), nil } func (msg *Message) isNotification() bool { @@ -88,28 +192,17 @@ func (msg *Message) hasValidID() bool { } func (msg *Message) String() string { - b, _ := json.Marshal(msg) + b, _ := msg.MarshalJSON() return string(b) } -func (msg *Message) ErrorResponse(err error) *Message { - resp := ErrorMessage(err) +func (msg *Message) ErrorResponse(id *ID, err error) *Message { + resp := ErrorMessage(id, err) if resp.ID != nil { resp.ID = msg.ID } return resp } -func (msg *Message) response(result any) *Message { - // do a funny marshaling - enc, err := gojson.Marshal(result) - if err != nil { - return msg.ErrorResponse(err) - } - if len(enc) == 0 { - enc = []byte("null") - } - return &Message{ID: msg.ID, Result: enc} -} // encapsulate json rpc error into struct type JsonError struct { @@ -134,23 +227,25 @@ func (err *JsonError) ErrorData() any { } // error message produces json rpc message with error message -func ErrorMessage(err error) *Message { +func ErrorMessage(id *ID, err error) *Message { if err == nil { return nil } + je := &JsonError{ + Code: ErrorCodeDefault, + Message: err.Error(), + } msg := &Message{ - ID: NewNullIDPtr(), - Error: &JsonError{ - Code: ErrorCodeDefault, - Message: err.Error(), - }} - ec, ok := err.(Error) + ID: id, + Error: je, + } + ec, ok := err.(*JsonError) if ok { - msg.Error.Code = ec.ErrorCode() + je.Code = ec.ErrorCode() } de, ok := err.(DataError) if ok { - msg.Error.Data = de.ErrorData() + je.Data = de.ErrorData() } return msg } @@ -175,7 +270,7 @@ func IsBatchMessage(raw json.RawMessage) bool { func ParseMessage(raw json.RawMessage) ([]*Message, bool) { if !IsBatchMessage(raw) { msgs := []*Message{{}} - gojson.Unmarshal(raw, &msgs[0]) + msgs[0].UnmarshalJSON(raw) return msgs, false } // TODO: @@ -198,3 +293,19 @@ func ParseMessage(raw json.RawMessage) ([]*Message, bool) { }) return msgs, true } + +func (m *Message) SetExtraField(name string, v any) error { + switch name { + case "id", "jsonrpc", "method", "params", "result", "error": + return fmt.Errorf("%w: %q", ErrIllegalExtraField, name) + } + val, err := json.Marshal(v) + if err != nil { + return err + } + m.ExtraFields = append(m.ExtraFields, RequestField{ + Name: name, + Value: val, + }) + return nil +} diff --git a/pkg/codec/reqresp.go b/pkg/codec/reqresp.go index c9a697094d98bb287daa9c14d8ec7941912a1554..9f4fe6b561c9d5c517f5765497235d9cce8e3373 100644 --- a/pkg/codec/reqresp.go +++ b/pkg/codec/reqresp.go @@ -10,9 +10,10 @@ import ( // http.ResponseWriter interface, but for jrpc type ResponseWriter interface { Send(v any, err error) error - Option(k string, v any) Header() http.Header + SetExtraField(k string, v any) error + Notify(method string, v any) error } @@ -32,90 +33,51 @@ type BatchElem struct { Error error } -type Response struct { - Version Version `json:"jsonrpc,omitempty"` - ID *ID `json:"id,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *JsonError `json:"error,omitempty"` -} - -func (r *Response) Msg() *Message { - out := &Message{} - if r.ID != nil { - out.ID = r.ID - } - if r.Error != nil { - out.Error = r.Error - } else { - out.Result = r.Result - } - return out -} - type Request struct { - RequestMarshaling - ctx context.Context -} + ctx context.Context + Peer PeerInfo `json:"-"` -func NewRequestFromRaw(ctx context.Context, req *RequestMarshaling) *Request { - return &Request{ctx: ctx, RequestMarshaling: *req} + Message } func (r *Request) UnmarshalJSON(xs []byte) error { - return json.Unmarshal(xs, &r.RequestMarshaling) + return json.Unmarshal(xs, &r.Message) } func (r *Request) MarshalJSON() ([]byte, error) { - return json.Marshal(r.RequestMarshaling) + return json.Marshal(r.Message) } -type RequestMarshaling struct { - Version Version `json:"jsonrpc"` - ID *ID `json:"id,omitempty"` - Method string `json:"method"` - Params json.RawMessage `json:"params"` - Peer PeerInfo `json:"-"` -} - -func NewRequestInt(ctx context.Context, id int, method string, params any) *Request { +func NewRequestFromMessage(ctx context.Context, message *Message) (r *Request) { if ctx == nil { ctx = context.Background() } - r := &Request{ctx: ctx} - pms, _ := json.Marshal(params) - r.ID = NewNumberIDPtr(int64(id)) - r.Method = method - r.Params = pms + r = &Request{ctx: ctx, Message: *message} return r } -func NewRequest(ctx context.Context, id string, method string, params any) *Request { +func NewRawRequest(ctx context.Context, id *ID, method string, params json.RawMessage) (r *Request) { if ctx == nil { ctx = context.Background() } - r := &Request{ctx: ctx} - pms, _ := json.Marshal(params) - r.ID = NewStringIDPtr(id) + r = &Request{ctx: ctx} + r.ID = id r.Method = method - r.Params = pms + r.Params = params return r } -func NewNotification(ctx context.Context, method string, params any) *Request { - if ctx == nil { - ctx = context.Background() +// NewRequest makes a new request +func NewRequest(ctx context.Context, id *ID, method string, params any) (r *Request, err error) { + raw, err := json.Marshal(params) + if err != nil { + return nil, err } - r := &Request{ctx: ctx} - pms, _ := json.Marshal(params) - r.ID = nil - r.Method = method - r.Params = pms - return r + return NewRawRequest(ctx, id, method, raw), nil } -func (r *Request) makeError(err error) *Message { - m := r.Msg() - return m.ErrorResponse(err) +func NewNotification(ctx context.Context, method string, params any) (*Request, error) { + return NewRequest(ctx, nil, method, params) } func (r *Request) isNotification() bool { @@ -132,7 +94,10 @@ func (r *Request) hasValidID() bool { func (r *Request) ParamArray(a ...any) error { var params []json.RawMessage - json.Unmarshal(r.Params, ¶ms) + err := json.Unmarshal(r.Params, ¶ms) + if err != nil { + return err + } for idx, v := range params { if len(v) > idx { err := json.Unmarshal(v, &a[idx]) @@ -173,6 +138,8 @@ func (r *Request) WithContext(ctx context.Context) *Request { r2.ID = r.ID r2.Method = r.Method r2.Params = r.Params + r2.Error = r.Error + r2.ExtraFields = r.ExtraFields r2.Peer = r.Peer return r2 } diff --git a/pkg/codec/wire.go b/pkg/codec/wire.go index 194ebe556c493f3c0ed3a3c96a49a7eb3b854dd1..6d96b9e7bc0084d73cd9b0d494f5ceea31147c16 100644 --- a/pkg/codec/wire.go +++ b/pkg/codec/wire.go @@ -3,6 +3,7 @@ package codec import ( "bytes" "fmt" + "reflect" "strconv" json "github.com/goccy/go-json" @@ -11,34 +12,6 @@ import ( // Version represents a JSON-RPC version. const VersionString = "2.0" -// version is a special 0 sized struct that encodes as the jsonrpc version tag. -// -// It will fail during decode if it is not the correct version tag in the stream. -type Version struct{} - -// compile time check whether the version implements a json.Marshaler and json.Unmarshaler interfaces. -var ( - _ json.Marshaler = (*Version)(nil) - _ json.Unmarshaler = (*Version)(nil) -) - -// MarshalJSON implements json.Marshaler. -func (Version) MarshalJSON() ([]byte, error) { - return []byte(`"` + VersionString + `"`), nil -} - -// UnmarshalJSON implements json.Unmarshaler. -func (Version) UnmarshalJSON(data []byte) error { - version := "" - if err := json.Unmarshal(data, &version); err != nil { - return fmt.Errorf("failed to Unmarshal: %w", err) - } - if version != VersionString { - return fmt.Errorf("invalid RPC version %v", version) - } - return nil -} - // ID is a Request identifier. // // alternatively, ID can be null @@ -59,6 +32,41 @@ var ( _ json.Unmarshaler = (*ID)(nil) ) +func NewId(v any) *ID { + switch cast := v.(type) { + case uint8: + return NewNumberIDPtr(int64(cast)) + case int8: + return NewNumberIDPtr(int64(cast)) + case uint16: + return NewNumberIDPtr(int64(cast)) + case int16: + return NewNumberIDPtr(int64(cast)) + case uint32: + return NewNumberIDPtr(int64(cast)) + case int32: + return NewNumberIDPtr(int64(cast)) + case uint64: + return NewNumberIDPtr(int64(cast)) + case int64: + return NewNumberIDPtr(int64(cast)) + case int: + return NewNumberIDPtr(int64(cast)) + case uint: + return NewNumberIDPtr(int64(cast)) + case string: + return NewStringIDPtr(cast) + case []byte: + r := ID(cast) + return &r + case json.RawMessage: + r := ID(cast) + return &r + default: + panic(fmt.Sprintf("invalid id: %s %+v", reflect.TypeOf(v), v)) + } +} + // NewNumberID returns a new number request ID. func NewNumberID(v int64) ID { return *NewNumberIDPtr(v) } diff --git a/pkg/codec/wire_test.go b/pkg/codec/wire_test.go index 7a4ee3272c31d19635dc8f757f0bd98ab4ed57e9..900a90ad4c9c4b641a6d9703d193e010f272472a 100644 --- a/pkg/codec/wire_test.go +++ b/pkg/codec/wire_test.go @@ -7,23 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestVersion(t *testing.T) { - var v Version - - t.Run("encoding", func(t *testing.T) { - ans, err := json.Marshal(v) - assert.NoError(t, err) - assert.Equal(t, []byte(`"2.0"`), ans) - }) - - t.Run("decoding", func(t *testing.T) { - err := json.Unmarshal([]byte(`"2.0"`), &v) - assert.NoError(t, err) - err = json.Unmarshal([]byte("not"), &v) - assert.Error(t, err) - }) -} - func TestIDMarshal(t *testing.T) { var v ID diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go new file mode 100644 index 0000000000000000000000000000000000000000..4a3ae382f9863ef60b000901905bea34d6a1ed5a --- /dev/null +++ b/pkg/server/responsewriter.go @@ -0,0 +1,48 @@ +package server + +import ( + "net/http" + + "gfx.cafe/open/jrpc/pkg/codec" +) + +var _ codec.ResponseWriter = (*callRespWriter)(nil) + +type callRespWriter struct { + msg *codec.Message + + pkt *codec.Message + + dat any + skip bool + header http.Header + + notifications chan *notifyEnv +} + +func (c *callRespWriter) Send(v any, err error) error { + if err != nil { + c.pkt.Error = err + return nil + } + c.dat = v + return nil +} + +func (c *callRespWriter) SetExtraField(k string, v any) error { + c.pkt.SetExtraField(k, v) + return nil +} + +func (c *callRespWriter) Header() http.Header { + return c.header +} + +func (c *callRespWriter) Notify(method string, v any) error { + c.notifications <- ¬ifyEnv{ + method: method, + dat: v, + extra: c.pkt.ExtraFields, + } + return nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index bbebdefdbd7902f52feeccb4e6587a7073d35ad2..ac440ea89c433c54d03c52166d1781756784a7d3 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,8 +3,6 @@ package server import ( "bytes" "context" - "io" - "net/http" "sync" "sync/atomic" @@ -57,16 +55,20 @@ func (s *Server) codecLoop(ctx context.Context, remote codec.ReaderWriter, respo s.printError(remote, err) return err } - msg, batch := codec.ParseMessage(msgs) + incoming, batch := codec.ParseMessage(msgs) env := &callEnv{ batch: batch, } + // check for empty batch - if batch && len(msg) == 0 { + if batch && len(incoming) == 0 { // if it is empty batch, send the empty batch warning responder.toSend <- &callEnv{ responses: []*callRespWriter{{ - err: codec.NewInvalidRequestError("empty batch"), + pkt: &codec.Message{ + ID: codec.NewNullIDPtr(), + Error: codec.NewInvalidRequestError("empty batch"), + }, }}, batch: false, } @@ -74,29 +76,34 @@ func (s *Server) codecLoop(ctx context.Context, remote codec.ReaderWriter, respo } // populate the envelope - for _, v := range msg { + for _, v := range incoming { rw := &callRespWriter{ + pkt: &codec.Message{ + ID: codec.NewNullIDPtr(), + }, + msg: &codec.Message{ + ID: codec.NewNullIDPtr(), + }, notifications: responder.toNotify, header: remote.PeerInfo().HTTP.Headers, } - env.responses = append(env.responses, rw) - if v == nil { - continue - } - rw.msg = v - if v.ID != nil { - rw.id = *v.ID + if v != nil { + rw.msg = v + if v.ID != nil { + rw.pkt.ID = v.ID + } } + env.responses = append(env.responses, rw) } // create a waitgroup wg := sync.WaitGroup{} wg.Add(len(env.responses)) - for _, vv := range env.responses { - v := vv + for _, vRef := range env.responses { + v := vRef // early respond to nil requests if v.msg == nil || len(v.msg.Method) == 0 { - v.err = codec.NewInvalidRequestError("invalid request") + v.pkt.Error = codec.NewInvalidRequestError("invalid request") wg.Done() continue } @@ -108,15 +115,12 @@ func (s *Server) codecLoop(ctx context.Context, remote codec.ReaderWriter, respo } go func() { defer wg.Done() - s.services.ServeRPC(v, codec.NewRequestFromRaw( + r := codec.NewRequestFromMessage( ctx, - &codec.RequestMarshaling{ - ID: v.msg.ID, - Version: v.msg.Version, - Method: v.msg.Method, - Params: v.msg.Params, - Peer: remote.PeerInfo(), - })) + v.msg, + ) + r.Peer = remote.PeerInfo() + s.services.ServeRPC(v, r) }() } wg.Wait() @@ -215,39 +219,45 @@ func (c *callResponder) run(ctx context.Context) error { } } } + +type notifyEnv struct { + method string + dat any + extra []codec.RequestField +} + func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error { + enc := jx.NewStreamingEncoder(c.remote, 4096) + msg := &codec.Message{} + var err error + // allocate a temp buffer for this packet buf := bufpool.GetStd() defer bufpool.PutStd(buf) - enc := jx.GetEncoder() - enc.Reset() - defer jx.PutEncoder(enc) - buf.Reset() - enc.ObjStart() - enc.FieldStart("jsonrpc") - enc.Str("2.0") - enc.FieldStart("method") - enc.Str(env.method) - err := env.dat(buf) + err = json.NewEncoder(buf).Encode(env.dat) if err != nil { - enc.FieldStart("error") - err := codec.EncodeError(enc, err) - if err != nil { - return err - } + msg.Error = err } else { - enc.FieldStart("result") - enc.Raw(buf.Bytes()) + msg.Result = buf.Bytes() } - enc.ObjEnd() - _, err = enc.WriteTo(c.remote) + msg.ExtraFields = env.extra + // add the method + msg.Method = env.method + err = codec.MarshalMessage(msg, enc) if err != nil { return err } - return nil + return enc.Close() + +} + +type callEnv struct { + responses []*callRespWriter + batch bool } -func (c *callResponder) send(ctx context.Context, env *callEnv) error { +func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { // notification gets nothing + // if all msgs in batch are notification, we trigger an allSkip and write nothing if env.batch { allSkip := true for _, v := range env.responses { @@ -259,105 +269,43 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { return nil } } - enc := jx.GetEncoder() - enc.Reset() - // enc.ResetWriter(c.remote) - defer jx.PutEncoder(enc) + // create the streaming encoder + enc := jx.NewStreamingEncoder(c.remote, 4096) if env.batch { enc.ArrStart() } for _, v := range env.responses { - id := codec.Null - if v.id != nil { - id = v.id.RawMessage() - } + msg := v.pkt + // if we are a batch AND we are supposed to skip, then continue + // this means that for a non-batch notification, we do not skip! if env.batch && v.skip { continue } - enc.Obj(func(e *jx.Encoder) { - e.FieldStart("jsonrpc") - e.Str("2.0") - e.FieldStart("id") - e.Raw(id) - err := v.err - if err == nil { - if v.dat != nil { - buf := new(bytes.Buffer) - err = v.dat(buf) - if err == nil { - e.Field("result", func(e *jx.Encoder) { - e.Raw(bytes.TrimSpace(buf.Bytes())) - }) - } - } else { - err = codec.NewInvalidRequestError("invalid request") - } - } + // if there is no error, we try to marshal the result + if msg.Error == nil { + buf := bufpool.GetStd() + defer bufpool.PutStd(buf) + je := json.NewEncoder(buf) + err = je.EncodeWithOption(v.dat) if err != nil { - e.Field("error", func(e *jx.Encoder) { - codec.EncodeError(e, err) - }) + msg.Error = err + } else { + msg.Result = buf.Bytes() + msg.Result = bytes.TrimSuffix(msg.Result, []byte{'\n'}) } - }) + } + // then marshal the whole message into the stream + err := codec.MarshalMessage(msg, enc) + if err != nil { + return err + } } if env.batch { enc.ArrEnd() } - _, err := enc.WriteTo(c.remote) + err = enc.Close() if err != nil { return err } return nil } - -type notifyEnv struct { - method string - dat func(io.Writer) error -} - -type callEnv struct { - responses []*callRespWriter - batch bool -} - -var _ codec.ResponseWriter = (*callRespWriter)(nil) - -type callRespWriter struct { - id codec.ID - msg *codec.Message - dat func(io.Writer) error - err error - skip bool - header http.Header - - notifications chan *notifyEnv -} - -func (c *callRespWriter) Send(v any, err error) error { - if err != nil { - c.err = err - return nil - } - c.dat = func(w io.Writer) error { - return json.NewEncoder(w).Encode(v) - } - return nil -} - -func (c *callRespWriter) Option(k string, v any) { - // no options for now -} - -func (c *callRespWriter) Header() http.Header { - return c.header -} - -func (c *callRespWriter) Notify(method string, v any) error { - c.notifications <- ¬ifyEnv{ - method: method, - dat: func(w io.Writer) error { - return json.NewEncoder(w).Encode(v) - }, - } - return nil -} diff --git a/readme.md b/readme.md index cd8a2fc798e376395cdaa4d370a09dd6e421a5d8..6385cfd6990cbc0b6c112ea486e5a3784c98b22b 100644 --- a/readme.md +++ b/readme.md @@ -1,19 +1,40 @@ ## jrpc +```go get gfx.cafe/open/jrpc``` + this is a bottom up implementation of jsonrpc2, primarily made for hosting eth-like jsonrpc requests. -we extend the eth-rpc reflect based handler with go-http style request/response. +we extend the eth-rpc reflect based handler with go-http style request/response. + +we also make things like subscriptions additional extensions, so they are no longer baked into the rpc package. + +most users should mostly access the `jrpc` packages, along with a variety of things in `contrib` + +see examples in `examples` folder for usage + +it is currently being used in the oku.trade api in proxy, client, and server applications. -we also make things like subscriptions additional extensions, so they are no longer baked into the rpc package. +## features -most users should only ever need to access the "jrpc" and "pkg/codec" packages + - full jsonrpc2 protocol + - batch requests + notifications + - http.Request/http.ResponseWriter style semantics + - simple but powerful middleware framework + - subscription framework used by go-ethereum/rpc is implemented as middleware. + - http (with rest-like access via RPC verb), websocket, io.Reader/io.Writer (tcp, any net.Conn, etc), inproc codecs. + - using faster json packages (goccy/go-json and jx) + - extensions, which allow setting arbitrary fields on the parent object, like in sourcegraph jsonrpc2 + - jmux, which allows for http-like routing, implemented like `go-chi/v5`, except for jsonrpc2 paths + - argreflect, which allows mounting methods on structs to the rpc engine, like go-ethereum/rpc + - openrpc schema parser and code generator -it is currently being used in the oku.trade api +## maybe outdated but somewhat useful contribution info +basic structure ``` -exports.go - export things in subpackages to jrpc namespace, cleaning up the public use package. +exports.go - export things in subpackages to jrpc namespace, cleaning up the public use package. pkg/ - packages for implementing jrpc clientutil/ - common utilities for client implementations to use idreply.go - generalizes making a request with an incrementing id, then waiting on it @@ -54,9 +75,9 @@ contrib/ - packages that add to jrpc jmux/ - a chi based router which satisfies the jrpc.Handler interface handlers/ - special jrpc handlers argreflect/ - go-ethereum style struct reflection + middleware/ - pre implemented middleware extension/ - extensions to the protocol - middleware/ - pre implemented middleware - subscription/ - WIP: subscription engine for go-ethereum style subs + subscription/ - WIP: subscription engine for go-ethereum style subs ```