From 33ed508c046d6678a364ad4f7268d8f2cee59385 Mon Sep 17 00:00:00 2001 From: Garet Halliday <ghalliday@gfxlabs.io> Date: Fri, 14 Jul 2023 02:19:48 +0000 Subject: [PATCH] Various test fixes --- contrib/codecs/http/client.go | 11 +++-- contrib/codecs/inproc/inproc.go | 16 +++++-- contrib/codecs/rdwr/codec.go | 19 +++++--- contrib/codecs/rdwr/codec_test.go | 2 +- contrib/codecs/rdwr/rdwr_test.go | 2 +- pkg/clientutil/helper.go | 16 ++++--- pkg/clientutil/helper_test.go | 79 +++++++++++++++++++++++++++++++ pkg/clientutil/idreply_test.go | 46 ++++++++++++++++++ pkg/codec/json.go | 11 +++-- pkg/jrpctest/suites.go | 21 +++++--- pkg/server/server.go | 4 +- 11 files changed, 190 insertions(+), 37 deletions(-) create mode 100644 pkg/clientutil/helper_test.go create mode 100644 pkg/clientutil/idreply_test.go diff --git a/contrib/codecs/http/client.go b/contrib/codecs/http/client.go index fa8f2c1..5709729 100644 --- a/contrib/codecs/http/client.go +++ b/contrib/codecs/http/client.go @@ -14,8 +14,9 @@ import ( "gfx.cafe/open/jrpc/pkg/codec" - "gfx.cafe/open/jrpc/pkg/clientutil" "gfx.cafe/util/go/bufpool" + + "gfx.cafe/open/jrpc/pkg/clientutil" ) var ( @@ -107,7 +108,7 @@ func (c *Client) Do(ctx context.Context, result any, method string, params any) } func (c *Client) post(req *codec.Request) (*http.Response, error) { - //TODO: use buffer for this + // TODO: use buffer for this buf := bufpool.GetStd() defer bufpool.PutStd(buf) buf.Reset() @@ -139,13 +140,13 @@ func (c *Client) Notify(ctx context.Context, method string, params any) error { func (c *Client) BatchCall(ctx context.Context, b ...*codec.BatchElem) error { reqs := make([]*codec.Request, len(b)) - ids := make([]int, 0, len(b)) - for _, v := range b { + ids := make(map[int]int, len(b)) + for idx, v := range b { if v.IsNotification { reqs = append(reqs, codec.NewRequest(ctx, "", v.Method, v.Params)) } else { id := int(c.id.Add(1)) - ids = append(ids, id) + ids[idx] = id reqs = append(reqs, codec.NewRequestInt(ctx, id, v.Method, v.Params)) } } diff --git a/contrib/codecs/inproc/inproc.go b/contrib/codecs/inproc/inproc.go index 53b60de..5c35e02 100644 --- a/contrib/codecs/inproc/inproc.go +++ b/contrib/codecs/inproc/inproc.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "io" + "sync" "gfx.cafe/open/jrpc/pkg/codec" ) @@ -13,9 +14,10 @@ type Codec struct { ctx context.Context cn func() - rd io.Reader - wr *bufio.Writer - msgs chan json.RawMessage + rd io.Reader + wrLock sync.Mutex + wr *bufio.Writer + msgs chan json.RawMessage } func NewCodec() *Codec { @@ -58,10 +60,14 @@ func (c *Codec) Close() error { } func (c *Codec) Write(p []byte) (n int, err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() return c.wr.Write(p) } func (c *Codec) Flush() (err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() return c.wr.Flush() } @@ -76,7 +82,7 @@ func (c *Codec) RemoteAddr() string { } // DialInProc attaches an in-process connection to the given RPC server. -//func DialInProc(handler *Server) *Client { +// func DialInProc(handler *Server) *Client { // initctx := context.Background() // c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { // p1, p2 := net.Pipe() @@ -84,4 +90,4 @@ func (c *Codec) RemoteAddr() string { // return NewCodec(p2), nil // }) // return c -//} +// } diff --git a/contrib/codecs/rdwr/codec.go b/contrib/codecs/rdwr/codec.go index 8a71503..a352b65 100644 --- a/contrib/codecs/rdwr/codec.go +++ b/contrib/codecs/rdwr/codec.go @@ -4,18 +4,21 @@ import ( "bufio" "context" "io" + "sync" - "gfx.cafe/open/jrpc/pkg/codec" "github.com/goccy/go-json" + + "gfx.cafe/open/jrpc/pkg/codec" ) type Codec struct { ctx context.Context cn func() - rd io.Reader - wr *bufio.Writer - msgs chan json.RawMessage + rd io.Reader + wrLock sync.Mutex + wr *bufio.Writer + msgs chan json.RawMessage } func NewCodec(rd io.Reader, wr io.Writer, onError func(error)) *Codec { @@ -77,10 +80,14 @@ func (c *Codec) Close() error { } func (c *Codec) Write(p []byte) (n int, err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() return c.wr.Write(p) } func (c *Codec) Flush() (err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() c.wr.WriteByte('\n') return c.wr.Flush() } @@ -96,7 +103,7 @@ func (c *Codec) RemoteAddr() string { } // Dialrdwr attaches an in-process connection to the given RPC server. -//func Dialrdwr(handler *Server) *Client { +// func Dialrdwr(handler *Server) *Client { // initctx := context.Background() // c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { // p1, p2 := net.Pipe() @@ -104,4 +111,4 @@ func (c *Codec) RemoteAddr() string { // return NewCodec(p2), nil // }) // return c -//} +// } diff --git a/contrib/codecs/rdwr/codec_test.go b/contrib/codecs/rdwr/codec_test.go index e783d15..6dcadf1 100644 --- a/contrib/codecs/rdwr/codec_test.go +++ b/contrib/codecs/rdwr/codec_test.go @@ -24,7 +24,7 @@ func TestBasicSuite(t *testing.T) { s.ServeCodec(context.Background(), clientCodec) }() return s, func() codec.Conn { - return rdwr.NewClient(rd_s, wr_c, nil) + return rdwr.NewClient(rd_s, wr_c) }, func() {} }, }) diff --git a/contrib/codecs/rdwr/rdwr_test.go b/contrib/codecs/rdwr/rdwr_test.go index bd86dae..277d3b5 100644 --- a/contrib/codecs/rdwr/rdwr_test.go +++ b/contrib/codecs/rdwr/rdwr_test.go @@ -22,7 +22,7 @@ func TestRDWRSetup(t *testing.T) { rd_c, wr_c := io.Pipe() clientCodec := rdwr.NewCodec(rd_s, wr_c, nil) - client := rdwr.NewClient(rd_c, wr_s, nil) + client := rdwr.NewClient(rd_c, wr_s) go func() { srv.ServeCodec(ctx, clientCodec) }() diff --git a/pkg/clientutil/helper.go b/pkg/clientutil/helper.go index b0f70cd..c139aa0 100644 --- a/pkg/clientutil/helper.go +++ b/pkg/clientutil/helper.go @@ -3,14 +3,19 @@ package clientutil import ( "encoding/json" "fmt" - "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/util/go/generic" + + "gfx.cafe/open/jrpc/pkg/codec" ) var msgPool = generic.HookPool[*codec.Message]{ New: func() *codec.Message { return &codec.Message{} }, + FnPut: func(msg *codec.Message) { + *msg = codec.Message{} + }, } func GetMessage() *codec.Message { @@ -21,14 +26,13 @@ func PutMessage(x *codec.Message) { msgPool.Put(x) } -func FillBatch(ids []int, msgs []*codec.Message, b []*codec.BatchElem) { - answers := map[int]*codec.Message{} +func FillBatch(ids map[int]int, msgs []*codec.Message, b []*codec.BatchElem) { + answers := make(map[int]*codec.Message, len(msgs)) for _, v := range msgs { answers[v.ID.Number()] = v } - for i := range ids { - idx := i - ans, ok := answers[i] + for idx, id := range ids { + ans, ok := answers[id] if !ok { b[idx].Error = fmt.Errorf("No response found") continue diff --git a/pkg/clientutil/helper_test.go b/pkg/clientutil/helper_test.go new file mode 100644 index 0000000..ed765ee --- /dev/null +++ b/pkg/clientutil/helper_test.go @@ -0,0 +1,79 @@ +package clientutil + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gfx.cafe/open/jrpc/pkg/codec" +) + +func ptr[T any](v T) *T { + return &v +} + +func TestFillBatch(t *testing.T) { + msgs := []*codec.Message{ + { + ID: ptr(codec.ID(`"5"`)), + Result: json.RawMessage(`["test", "abc", "123"]`), + }, + { + ID: ptr(codec.ID(`"6"`)), + Result: json.RawMessage(`12345`), + }, + {}, + { + ID: ptr(codec.ID(`"7"`)), + Result: json.RawMessage(`"abcdefgh"`), + }, + } + ids := map[int]int{ + 0: 5, + 1: 6, + 3: 7, + } + b := []*codec.BatchElem{ + { + Result: new([]string), + }, + { + Result: new(int), + }, + {}, + { + Result: new(string), + }, + } + + FillBatch(ids, msgs, b) + + wantResult := []*codec.BatchElem{ + { + Result: &[]string{ + "test", + "abc", + "123", + }, + }, + { + Result: ptr(12345), + }, + {}, + { + Result: ptr("abcdefgh"), + }, + } + + require.EqualValues(t, len(b), len(wantResult)) + for i := range b { + expected := wantResult[i] + actual := b[i] + assert.EqualValuesf(t, expected.Method, actual.Method, "item %d", i) + assert.EqualValuesf(t, expected.Result, actual.Result, "item %d", i) + assert.EqualValuesf(t, expected.Params, actual.Params, "item %d", i) + assert.EqualValuesf(t, expected.Error, actual.Error, "item %d", i) + } +} diff --git a/pkg/clientutil/idreply_test.go b/pkg/clientutil/idreply_test.go new file mode 100644 index 0000000..40fa92a --- /dev/null +++ b/pkg/clientutil/idreply_test.go @@ -0,0 +1,46 @@ +package clientutil + +import ( + "bytes" + "context" + "encoding/json" + "sync" + "testing" +) + +const count = 1000 + +func TestIdReply(t *testing.T) { + reply := NewIdReply() + + testMessage := json.RawMessage("{\"test\": 123}") + + var wg sync.WaitGroup + + wg.Add(count) + + for i := 0; i < count; i++ { + go func() { + defer wg.Done() + id := reply.NextId() + v, err := reply.Ask(context.Background(), id) + if err != nil { + t.Error(err) + return + } + + if !bytes.Equal(v, testMessage) { + t.Error("expected contents to be equal") + return + } + }() + } + + for i := 0; i < count; i++ { + go func(id int) { + reply.Resolve(id+1, testMessage, nil) + }(i) + } + + wg.Wait() +} diff --git a/pkg/codec/json.go b/pkg/codec/json.go index 1484569..c4988d1 100644 --- a/pkg/codec/json.go +++ b/pkg/codec/json.go @@ -4,8 +4,9 @@ import ( "encoding/json" "strconv" - "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" "github.com/go-faster/jx" + + "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" ) var jzon = wsjson.JZON @@ -40,9 +41,11 @@ func (m *Message) MarshalJSON() ([]byte, error) { e.Raw(m.ID.RawMessage()) }) } - e.Field("method", func(e *jx.Encoder) { - e.Str(m.Method) - }) + if m.Method != "" { + e.Field("method", func(e *jx.Encoder) { + e.Str(m.Method) + }) + } if m.Error != nil { e.Field("error", func(e *jx.Encoder) { xs, _ := json.Marshal(m.Error) diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go index 4953e98..46df154 100644 --- a/pkg/jrpctest/suites.go +++ b/pkg/jrpctest/suites.go @@ -75,6 +75,11 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { Params: []any{"hello2", 11, &EchoArgs{"world"}}, Result: new(EchoResult), }, + { + Method: "test_echo", + Params: []any{"hello3", 12, &EchoArgs{"world"}}, + IsNotification: true, + }, { Method: "no_such_method", Params: []any{1, 2, 3}, @@ -95,6 +100,10 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { Params: []any{"hello2", 11, &EchoArgs{"world"}}, Result: &EchoResult{"hello2", 11, &EchoArgs{"world"}}, }, + { + Method: "test_echo", + Params: []any{"hello3", 12, &EchoArgs{"world"}}, + }, { Method: "no_such_method", Params: []any{1, 2, 3}, @@ -105,13 +114,11 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { require.EqualValues(t, len(batch), len(wantResult)) for i := range batch { a := batch[i] - b := batch[i] + b := wantResult[i] assert.EqualValuesf(t, a.Method, b.Method, "item %d", i) assert.EqualValuesf(t, a.Result, b.Result, "item %d", i) assert.EqualValuesf(t, a.Params, b.Params, "item %d", i) - if a.Error != nil { - assert.EqualValuesf(t, a.Error, b.Error, "item %d", i) - } + assert.EqualValuesf(t, a.Error, b.Error, "item %d", i) } }) @@ -146,8 +153,10 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { } }) makeTest("Notify", func(t *testing.T, server *server.Server, client codec.Conn) { - if err := client.Notify(context.Background(), "test_echo", []any{"hello", 10, &EchoArgs{"world"}}); err != nil { - t.Fatal(err) + if c, ok := client.(codec.StreamingConn); ok { + if err := c.Notify(context.Background(), "test_echo", []any{"hello", 10, &EchoArgs{"world"}}); err != nil { + t.Fatal(err) + } } }) diff --git a/pkg/server/server.go b/pkg/server/server.go index af92dab..ae59fe5 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -249,7 +249,7 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { } enc := jx.GetEncoder() enc.Reset() - //enc.ResetWriter(c.remote) + // enc.ResetWriter(c.remote) defer jx.PutEncoder(enc) if env.batch { enc.ArrStart() @@ -267,8 +267,6 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { e.Str("2.0") e.FieldStart("id") e.Raw(id) - e.FieldStart("method") - e.Str(v.msg.Method) err := v.err if err == nil { if v.dat != nil { -- GitLab