From 983606bdc2ef1059fe1b254f83f2252a575a2609 Mon Sep 17 00:00:00 2001 From: a <a@tuxpa.in> Date: Sat, 10 Jun 2023 09:10:43 -0500 Subject: [PATCH] wip --- contrib/codecs/rdwr/client.go | 3 + contrib/codecs/rdwr/codec.go | 4 +- contrib/handlers/argreflect/json.go | 15 +- .../handlers/argreflect/reflect_handler.go | 4 + contrib/jmux/mux.go | 2 +- pkg/codec/errors.go | 4 +- pkg/codec/json.go | 64 ++++++- pkg/codec/wire.go | 18 +- pkg/jrpctest/services.go | 5 +- pkg/server/json_test.go | 8 + pkg/server/server.go | 161 +++++++++++------- pkg/server/server_test.go | 44 +++++ 12 files changed, 247 insertions(+), 85 deletions(-) create mode 100644 pkg/server/json_test.go create mode 100644 pkg/server/server_test.go diff --git a/contrib/codecs/rdwr/client.go b/contrib/codecs/rdwr/client.go index 0827a7a..6872e65 100644 --- a/contrib/codecs/rdwr/client.go +++ b/contrib/codecs/rdwr/client.go @@ -45,6 +45,9 @@ func (c *Client) listen() error { msgs, _ := codec.ParseMessage(msg) for i := range msgs { v := msgs[i] + if v == nil { + continue + } id := v.ID.Number() // messages without ids are notifications if id == 0 { diff --git a/contrib/codecs/rdwr/codec.go b/contrib/codecs/rdwr/codec.go index a292da3..8a71503 100644 --- a/contrib/codecs/rdwr/codec.go +++ b/contrib/codecs/rdwr/codec.go @@ -37,8 +37,9 @@ func NewCodec(rd io.Reader, wr io.Writer, onError func(error)) *Codec { } func (c *Codec) listen() error { - var msg json.RawMessage for { + var msg json.RawMessage + // reading a message err := json.NewDecoder(c.rd).Decode(&msg) if err != nil { c.cn() @@ -80,6 +81,7 @@ func (c *Codec) Write(p []byte) (n int, err error) { } func (c *Codec) Flush() (err error) { + c.wr.WriteByte('\n') return c.wr.Flush() } diff --git a/contrib/handlers/argreflect/json.go b/contrib/handlers/argreflect/json.go index a46fe42..8d1bdc3 100644 --- a/contrib/handlers/argreflect/json.go +++ b/contrib/handlers/argreflect/json.go @@ -2,10 +2,11 @@ package argreflect import ( "encoding/json" - "errors" "fmt" - "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" "reflect" + + "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" + "gfx.cafe/open/jrpc/pkg/codec" ) var jzon = wsjson.JZON @@ -26,12 +27,12 @@ func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([] case string(rawArgs) == "null": return nil, nil default: - return nil, errors.New("non-array args") + return nil, codec.NewInvalidParamsError("non-array args") } // Set any missing args to nil. for i := len(args); i < len(types); i++ { if types[i].Kind() != reflect.Ptr { - return nil, fmt.Errorf("missing value for required argument %d", i) + return nil, codec.NewInvalidParamsError(fmt.Sprintf("missing value for required argument %d", i)) } args = append(args, reflect.Zero(types[i])) } @@ -44,15 +45,15 @@ func parseArgumentArray(p json.RawMessage, types []reflect.Type) ([]reflect.Valu args := make([]reflect.Value, 0, len(types)) for i := 0; dec.ReadArray(); i++ { if i >= len(types) { - return args, fmt.Errorf("too many arguments, want at most %d", len(types)) + return args, codec.NewInvalidParamsError(fmt.Sprintf("too many arguments, want at most %d", len(types))) } argval := reflect.New(types[i]) dec.ReadVal(argval.Interface()) if err := dec.Error; err != nil { - return args, fmt.Errorf("invalid argument %d: %v", i, err) + return args, codec.NewInvalidParamsError(fmt.Sprintf("invalid argument %d: %v", i, err)) } if argval.IsNil() && types[i].Kind() != reflect.Ptr { - return args, fmt.Errorf("missing value for required argument %d", i) + return nil, codec.NewInvalidParamsError(fmt.Sprintf("missing value for required argument %d", i)) } args = append(args, argval.Elem()) } diff --git a/contrib/handlers/argreflect/reflect_handler.go b/contrib/handlers/argreflect/reflect_handler.go index a474a44..cf91be5 100644 --- a/contrib/handlers/argreflect/reflect_handler.go +++ b/contrib/handlers/argreflect/reflect_handler.go @@ -86,6 +86,10 @@ func (e *callback) ServeRPC(w codec.ResponseWriter, r *codec.Request) { w.Send(nil, err) return } + if len(results) == 0 { + w.Send(codec.Null, nil) + return + } w.Send(results[0].Interface(), nil) } diff --git a/contrib/jmux/mux.go b/contrib/jmux/mux.go index fb3ffb9..8fed2bc 100644 --- a/contrib/jmux/mux.go +++ b/contrib/jmux/mux.go @@ -423,5 +423,5 @@ func methodNotAllowedHandler(w codec.ResponseWriter, r *codec.Request) { } func NotFound(w codec.ResponseWriter, r *codec.Request) { - w.Send(nil, errors.New("not found: does not exist")) + w.Send(nil, codec.NewMethodNotFoundError(r.Method)) } diff --git a/pkg/codec/errors.go b/pkg/codec/errors.go index 4194623..f3736e6 100644 --- a/pkg/codec/errors.go +++ b/pkg/codec/errors.go @@ -154,8 +154,8 @@ func (e *ErrorInvalidMessage) ErrorCode() int { return -32700 } func (e *ErrorInvalidMessage) Error() string { return e.message } -func NewInvalidParamsError(message string) *ErrorInvalidMessage { - return &ErrorInvalidMessage{ +func NewInvalidParamsError(message string) *ErrorInvalidParams { + return &ErrorInvalidParams{ message: message, } } diff --git a/pkg/codec/json.go b/pkg/codec/json.go index f2f202f..8bcea93 100644 --- a/pkg/codec/json.go +++ b/pkg/codec/json.go @@ -1,17 +1,21 @@ package codec import ( - "bytes" "encoding/json" "strconv" "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" + "github.com/go-faster/jx" ) var jzon = wsjson.JZON var Null = json.RawMessage("null") +func NewNull() json.RawMessage { + return json.RawMessage("null") +} + // A value of this type can a JSON-RPC request, notification, successful response or // error response. Which one it is depends on the fields. type Message struct { @@ -24,6 +28,42 @@ type Message struct { Error *JsonError `json:"error,omitempty"` } +func (m *Message) MarshalJSON() ([]byte, error) { + var enc jx.Encoder + // use encoder + enc.Obj(func(e *jx.Encoder) { + e.Field("jsonrpc", func(e *jx.Encoder) { + e.Str("2.0") + }) + if m.ID != nil { + e.Field("id", func(e *jx.Encoder) { + e.Raw(m.ID.RawMessage()) + }) + } + e.Field("method", func(e *jx.Encoder) { + e.Str(m.Method) + }) + if m.Error == nil { + e.Field("error", func(e *jx.Encoder) { + xs, _ := json.Marshal(m.Error) + e.Raw(xs) + }) + } + if len(m.Params) != 0 { + e.Field("params", func(e *jx.Encoder) { + e.Raw(m.Params) + }) + } + if len(m.Result) != 0 { + e.Field("result", func(e *jx.Encoder) { + e.Raw(m.Result) + }) + } + }) + // output + return enc.Bytes(), nil +} + func MakeCall(id int, method string, params []any) *Message { return &Message{ ID: NewNumberIDPtr(int64(id)), @@ -139,13 +179,21 @@ func ParseMessage(raw json.RawMessage) ([]*Message, bool) { } // TODO: // for some reason other json decoders are incompatible with our test suite - // pretty sure its how we handle EOFs and stuff - dec := json.NewDecoder(bytes.NewReader(raw)) - dec.Token() // skip '[' + // pretty sure its how we horle EOFs and stuff + dec := jx.DecodeBytes(raw) var msgs []*Message - for dec.More() { - msgs = append(msgs, new(Message)) - dec.Decode(&msgs[len(msgs)-1]) - } + dec.Arr(func(d *jx.Decoder) error { + msg := new(Message) + raw, err := d.Raw() + if err != nil { + return nil + } + err = json.Unmarshal(raw, msg) + if err != nil { + msg = nil + } + msgs = append(msgs, msg) + return nil + }) return msgs, true } diff --git a/pkg/codec/wire.go b/pkg/codec/wire.go index 280935e..922e3af 100644 --- a/pkg/codec/wire.go +++ b/pkg/codec/wire.go @@ -122,5 +122,21 @@ func (id *ID) UnmarshalJSON(data []byte) error { return nil } *id = data - return nil + // now validate + if id.IsNull() { + return nil + } + // it has to be a string or number + var num int + err := json.Unmarshal(data, &num) + if err == nil { + return nil + } + var str string + err = json.Unmarshal(data, &str) + if err == nil { + return nil + } + *id = NewNullID() + return fmt.Errorf("invalid id") } diff --git a/pkg/jrpctest/services.go b/pkg/jrpctest/services.go index eb510c1..eab9194 100644 --- a/pkg/jrpctest/services.go +++ b/pkg/jrpctest/services.go @@ -3,10 +3,11 @@ package jrpctest import ( "context" "errors" - "gfx.cafe/open/jrpc/pkg/codec" - "gfx.cafe/open/jrpc/pkg/server" "strings" "time" + + "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/open/jrpc/pkg/server" ) type testService struct{} diff --git a/pkg/server/json_test.go b/pkg/server/json_test.go new file mode 100644 index 0000000..92eafae --- /dev/null +++ b/pkg/server/json_test.go @@ -0,0 +1,8 @@ +package server + +import ( + "testing" +) + +func TestJson(t *testing.T) { +} diff --git a/pkg/server/server.go b/pkg/server/server.go index af0bf46..bbdf0ac 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "context" "io" "net/http" @@ -49,6 +50,74 @@ func (s *Server) printError(remote codec.ReaderWriter, err error) { } } +func (s *Server) codecLoop(ctx context.Context, remote codec.ReaderWriter, responder *callResponder) error { + msgs, err := remote.ReadBatch(ctx) + if err != nil { + remote.Flush() + s.printError(remote, err) + return err + } + msg, batch := codec.ParseMessage(msgs) + env := &callEnv{ + batch: batch, + } + // check for empty batch + if batch && len(msg) == 0 { + // if it is empty batch, send the empty batch warning + responder.toSend <- &callEnv{ + responses: []*callRespWriter{{ + err: codec.NewInvalidRequestError("empty batch"), + }}, + batch: false, + } + return nil + } + + // populate the envelope + for _, v := range msg { + rw := &callRespWriter{ + notifications: responder.toNotify, + header: remote.PeerInfo().HTTP.Headers, + } + env.responses = append(env.responses, rw) + if v == nil { + continue + } + rw.msg = v + if v.ID != nil { + rw.id = *v.ID + } + } + + // create a waitgroup + wg := sync.WaitGroup{} + wg.Add(len(msg)) + for _, vv := range env.responses { + v := vv + // early respond to nil requests + if v.msg == nil || v.msg.ID == nil || v.msg.ID.IsNull() || len(v.msg.Method) == 0 { + v.err = codec.NewInvalidRequestError("invalid request") + wg.Done() + continue + } + go func() { + defer wg.Done() + s.services.ServeRPC(v, codec.NewRequestFromRaw( + ctx, + &codec.RequestMarshaling{ + ID: v.msg.ID, + Version: v.msg.Version, + Method: v.msg.Method, + Params: v.msg.Params, + Peer: remote.PeerInfo(), + })) + }() + } + wg.Wait() + responder.toSend <- env + return nil +} + // ServeCodec reads incoming requests from codec, calls the appropriate callback and writes // the response back using the given codec. It will block until the codec is closed or the // server is stopped. In either case the codec is closed. @@ -93,47 +162,11 @@ func (s *Server) ServeCodec(pctx context.Context, remote codec.ReaderWriter) { }() for { - msgs, err := remote.ReadBatch(ctx) + err := s.codecLoop(ctx, remote, responder) if err != nil { - remote.Flush() s.printError(remote, err) return } - msg, batch := codec.ParseMessage(msgs) - env := &callEnv{ - batch: batch, - } - for _, v := range msg { - rw := &callRespWriter{ - msg: v, - notifications: responder.toNotify, - header: remote.PeerInfo().HTTP.Headers, - } - env.responses = append(env.responses, rw) - } - wg := sync.WaitGroup{} - wg.Add(len(msg)) - for _, vv := range env.responses { - v := vv - go func() { - if v.msg.ID == nil { - wg.Done() - } else { - defer wg.Done() - } - s.services.ServeRPC(v, codec.NewRequestFromRaw( - ctx, - &codec.RequestMarshaling{ - ID: v.msg.ID, - Version: v.msg.Version, - Method: v.msg.Method, - Params: v.msg.Params, - Peer: remote.PeerInfo(), - })) - }() - } - wg.Wait() - responder.toSend <- env } } @@ -190,12 +223,11 @@ func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error { if err != nil { return err } + return nil } func (c *callResponder) send(ctx context.Context, env *callEnv) error { - buf := bufpool.GetStd() - defer bufpool.PutStd(buf) enc := jx.GetEncoder() enc.Reset() //enc.ResetWriter(c.remote) @@ -204,40 +236,42 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { enc.ArrStart() } for _, v := range env.responses { - if v.msg.ID == nil { + id := codec.Null + if v.id != nil { + id = v.id.RawMessage() + } + if v.skip { continue } - enc.ObjStart() - enc.FieldStart("jsonrpc") - enc.Str("2.0") - enc.FieldStart("id") - enc.Raw(v.msg.ID.RawMessage()) - err := v.err - if err == nil { - if v.dat != nil { - buf.Reset() - err = v.dat(buf) - if err == nil { - enc.FieldStart("result") - enc.Raw(buf.Bytes()) + enc.Obj(func(e *jx.Encoder) { + e.FieldStart("jsonrpc") + e.Str("2.0") + e.FieldStart("id") + e.Raw(id) + err := v.err + if err == nil { + if v.dat != nil { + buf := new(bytes.Buffer) + err = v.dat(buf) + if err == nil { + e.Field("result", func(e *jx.Encoder) { + e.Raw(bytes.TrimSpace(buf.Bytes())) + }) + } + } else { + err = codec.NewInvalidRequestError("invalid request") } - } else { - err = codec.NewMethodNotFoundError(v.msg.Method) } - } - if err != nil { - enc.FieldStart("error") - err := codec.EncodeError(enc, err) if err != nil { - return err + e.Field("error", func(e *jx.Encoder) { + codec.EncodeError(e, err) + }) } - } - enc.ObjEnd() + }) } if env.batch { enc.ArrEnd() } - //err := enc.Close() _, err := enc.WriteTo(c.remote) if err != nil { return err @@ -258,6 +292,7 @@ type notifyEnv struct { var _ codec.ResponseWriter = (*callRespWriter)(nil) type callRespWriter struct { + id codec.ID msg *codec.Message dat func(io.Writer) error err error diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go new file mode 100644 index 0000000..c1d3d85 --- /dev/null +++ b/pkg/server/server_test.go @@ -0,0 +1,44 @@ +package server_test + +import ( + "bufio" + "context" + "net" + "strings" + "testing" + "time" + + "gfx.cafe/open/jrpc/contrib/codecs/rdwr" + "gfx.cafe/open/jrpc/pkg/jrpctest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGoEthereumTestScripts(t *testing.T) { + for _, tf := range jrpctest.OriginalTestData.Files { + t.Run(tf.Name, func(t *testing.T) { + // create a net pipe + rd, wr := net.Pipe() + readbuf := bufio.NewReader(rd) + srv := jrpctest.NewServer() + c := rdwr.NewCodec(wr, wr, func(err error) { + require.NoError(t, err) + }) + go srv.ServeCodec(context.TODO(), c) + defer srv.Stop() + for _, act := range tf.Action { + switch act.Direction { + case jrpctest.DirectionRecv: + rd.SetReadDeadline(time.Now().Add(5 * time.Second)) + sent, err := readbuf.ReadString('\n') + require.NoError(t, err) + assert.EqualValues(t, string(act.Data), strings.TrimSpace(sent)) + case jrpctest.DirectionSend: + rd.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _, err := rd.Write(append(act.Data, ' ')) + require.NoError(t, err) + } + } + }) + } +} -- GitLab