diff --git a/contrib/codecs/http/client.go b/contrib/codecs/http/client.go index fa8f2c14f1a62abb60097e201d1267091fdab39a..57097298bb29e2f81f330e030e6bcf79d9fe4265 100644 --- a/contrib/codecs/http/client.go +++ b/contrib/codecs/http/client.go @@ -14,8 +14,9 @@ import ( "gfx.cafe/open/jrpc/pkg/codec" - "gfx.cafe/open/jrpc/pkg/clientutil" "gfx.cafe/util/go/bufpool" + + "gfx.cafe/open/jrpc/pkg/clientutil" ) var ( @@ -107,7 +108,7 @@ func (c *Client) Do(ctx context.Context, result any, method string, params any) } func (c *Client) post(req *codec.Request) (*http.Response, error) { - //TODO: use buffer for this + // TODO: use buffer for this buf := bufpool.GetStd() defer bufpool.PutStd(buf) buf.Reset() @@ -139,13 +140,13 @@ func (c *Client) Notify(ctx context.Context, method string, params any) error { func (c *Client) BatchCall(ctx context.Context, b ...*codec.BatchElem) error { reqs := make([]*codec.Request, len(b)) - ids := make([]int, 0, len(b)) - for _, v := range b { + ids := make(map[int]int, len(b)) + for idx, v := range b { if v.IsNotification { reqs = append(reqs, codec.NewRequest(ctx, "", v.Method, v.Params)) } else { id := int(c.id.Add(1)) - ids = append(ids, id) + ids[idx] = id reqs = append(reqs, codec.NewRequestInt(ctx, id, v.Method, v.Params)) } } diff --git a/contrib/codecs/inproc/inproc.go b/contrib/codecs/inproc/inproc.go index 53b60dea772b39e7766da3d65398e38797be87cf..5c35e02ed731ff867f7f31c7c152d94bc56d28d5 100644 --- a/contrib/codecs/inproc/inproc.go +++ b/contrib/codecs/inproc/inproc.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "io" + "sync" "gfx.cafe/open/jrpc/pkg/codec" ) @@ -13,9 +14,10 @@ type Codec struct { ctx context.Context cn func() - rd io.Reader - wr *bufio.Writer - msgs chan json.RawMessage + rd io.Reader + wrLock sync.Mutex + wr *bufio.Writer + msgs chan json.RawMessage } func NewCodec() *Codec { @@ -58,10 +60,14 @@ func (c *Codec) Close() error { } func (c *Codec) Write(p []byte) (n int, err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() return c.wr.Write(p) } func (c *Codec) Flush() (err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() return c.wr.Flush() } @@ -76,7 +82,7 @@ func (c *Codec) RemoteAddr() string { } // DialInProc attaches an in-process connection to the given RPC server. -//func DialInProc(handler *Server) *Client { +// func DialInProc(handler *Server) *Client { // initctx := context.Background() // c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { // p1, p2 := net.Pipe() @@ -84,4 +90,4 @@ func (c *Codec) RemoteAddr() string { // return NewCodec(p2), nil // }) // return c -//} +// } diff --git a/contrib/codecs/rdwr/codec.go b/contrib/codecs/rdwr/codec.go index 8a71503b8fff4ceeea7dac296b47c66fa5e5f080..a352b659790fef01dd87499132ade2d58015fe9b 100644 --- a/contrib/codecs/rdwr/codec.go +++ b/contrib/codecs/rdwr/codec.go @@ -4,18 +4,21 @@ import ( "bufio" "context" "io" + "sync" - "gfx.cafe/open/jrpc/pkg/codec" "github.com/goccy/go-json" + + "gfx.cafe/open/jrpc/pkg/codec" ) type Codec struct { ctx context.Context cn func() - rd io.Reader - wr *bufio.Writer - msgs chan json.RawMessage + rd io.Reader + wrLock sync.Mutex + wr *bufio.Writer + msgs chan json.RawMessage } func NewCodec(rd io.Reader, wr io.Writer, onError func(error)) *Codec { @@ -77,10 +80,14 @@ func (c *Codec) Close() error { } func (c *Codec) Write(p []byte) (n int, err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() return c.wr.Write(p) } func (c *Codec) Flush() (err error) { + c.wrLock.Lock() + defer c.wrLock.Unlock() c.wr.WriteByte('\n') return c.wr.Flush() } @@ -96,7 +103,7 @@ func (c *Codec) RemoteAddr() string { } // Dialrdwr attaches an in-process connection to the given RPC server. -//func Dialrdwr(handler *Server) *Client { +// func Dialrdwr(handler *Server) *Client { // initctx := context.Background() // c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { // p1, p2 := net.Pipe() @@ -104,4 +111,4 @@ func (c *Codec) RemoteAddr() string { // return NewCodec(p2), nil // }) // return c -//} +// } diff --git a/contrib/codecs/rdwr/codec_test.go b/contrib/codecs/rdwr/codec_test.go index e783d154f5e85985d72702387d33d031df11950f..6dcadf1494360280f226a8af7805bb16944a8636 100644 --- a/contrib/codecs/rdwr/codec_test.go +++ b/contrib/codecs/rdwr/codec_test.go @@ -24,7 +24,7 @@ func TestBasicSuite(t *testing.T) { s.ServeCodec(context.Background(), clientCodec) }() return s, func() codec.Conn { - return rdwr.NewClient(rd_s, wr_c, nil) + return rdwr.NewClient(rd_s, wr_c) }, func() {} }, }) diff --git a/contrib/codecs/rdwr/rdwr_test.go b/contrib/codecs/rdwr/rdwr_test.go index bd86dae16d4384b9524ff6db7b2d4439a96061df..277d3b57bda81b11f32a656ac37f2d5f475632c2 100644 --- a/contrib/codecs/rdwr/rdwr_test.go +++ b/contrib/codecs/rdwr/rdwr_test.go @@ -22,7 +22,7 @@ func TestRDWRSetup(t *testing.T) { rd_c, wr_c := io.Pipe() clientCodec := rdwr.NewCodec(rd_s, wr_c, nil) - client := rdwr.NewClient(rd_c, wr_s, nil) + client := rdwr.NewClient(rd_c, wr_s) go func() { srv.ServeCodec(ctx, clientCodec) }() diff --git a/pkg/clientutil/helper.go b/pkg/clientutil/helper.go index b0f70cd4fe601d2db84a4a278931f5d1dd9be852..c139aa02cbadce8e5cf0a7b5df3b9a4d43b57e90 100644 --- a/pkg/clientutil/helper.go +++ b/pkg/clientutil/helper.go @@ -3,14 +3,19 @@ package clientutil import ( "encoding/json" "fmt" - "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/util/go/generic" + + "gfx.cafe/open/jrpc/pkg/codec" ) var msgPool = generic.HookPool[*codec.Message]{ New: func() *codec.Message { return &codec.Message{} }, + FnPut: func(msg *codec.Message) { + *msg = codec.Message{} + }, } func GetMessage() *codec.Message { @@ -21,14 +26,13 @@ func PutMessage(x *codec.Message) { msgPool.Put(x) } -func FillBatch(ids []int, msgs []*codec.Message, b []*codec.BatchElem) { - answers := map[int]*codec.Message{} +func FillBatch(ids map[int]int, msgs []*codec.Message, b []*codec.BatchElem) { + answers := make(map[int]*codec.Message, len(msgs)) for _, v := range msgs { answers[v.ID.Number()] = v } - for i := range ids { - idx := i - ans, ok := answers[i] + for idx, id := range ids { + ans, ok := answers[id] if !ok { b[idx].Error = fmt.Errorf("No response found") continue diff --git a/pkg/clientutil/helper_test.go b/pkg/clientutil/helper_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ed765eea1022a07e87a400eb2e22faed840c08ea --- /dev/null +++ b/pkg/clientutil/helper_test.go @@ -0,0 +1,79 @@ +package clientutil + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gfx.cafe/open/jrpc/pkg/codec" +) + +func ptr[T any](v T) *T { + return &v +} + +func TestFillBatch(t *testing.T) { + msgs := []*codec.Message{ + { + ID: ptr(codec.ID(`"5"`)), + Result: json.RawMessage(`["test", "abc", "123"]`), + }, + { + ID: ptr(codec.ID(`"6"`)), + Result: json.RawMessage(`12345`), + }, + {}, + { + ID: ptr(codec.ID(`"7"`)), + Result: json.RawMessage(`"abcdefgh"`), + }, + } + ids := map[int]int{ + 0: 5, + 1: 6, + 3: 7, + } + b := []*codec.BatchElem{ + { + Result: new([]string), + }, + { + Result: new(int), + }, + {}, + { + Result: new(string), + }, + } + + FillBatch(ids, msgs, b) + + wantResult := []*codec.BatchElem{ + { + Result: &[]string{ + "test", + "abc", + "123", + }, + }, + { + Result: ptr(12345), + }, + {}, + { + Result: ptr("abcdefgh"), + }, + } + + require.EqualValues(t, len(b), len(wantResult)) + for i := range b { + expected := wantResult[i] + actual := b[i] + assert.EqualValuesf(t, expected.Method, actual.Method, "item %d", i) + assert.EqualValuesf(t, expected.Result, actual.Result, "item %d", i) + assert.EqualValuesf(t, expected.Params, actual.Params, "item %d", i) + assert.EqualValuesf(t, expected.Error, actual.Error, "item %d", i) + } +} diff --git a/pkg/clientutil/idreply_test.go b/pkg/clientutil/idreply_test.go new file mode 100644 index 0000000000000000000000000000000000000000..40fa92a6f65144f7b438c321e2bae43c3d6d3a34 --- /dev/null +++ b/pkg/clientutil/idreply_test.go @@ -0,0 +1,46 @@ +package clientutil + +import ( + "bytes" + "context" + "encoding/json" + "sync" + "testing" +) + +const count = 1000 + +func TestIdReply(t *testing.T) { + reply := NewIdReply() + + testMessage := json.RawMessage("{\"test\": 123}") + + var wg sync.WaitGroup + + wg.Add(count) + + for i := 0; i < count; i++ { + go func() { + defer wg.Done() + id := reply.NextId() + v, err := reply.Ask(context.Background(), id) + if err != nil { + t.Error(err) + return + } + + if !bytes.Equal(v, testMessage) { + t.Error("expected contents to be equal") + return + } + }() + } + + for i := 0; i < count; i++ { + go func(id int) { + reply.Resolve(id+1, testMessage, nil) + }(i) + } + + wg.Wait() +} diff --git a/pkg/codec/json.go b/pkg/codec/json.go index 1484569a366afe91af1da84457fc612be22011f1..c4988d113c7c58868d8e7b00f9b8004a833cafdc 100644 --- a/pkg/codec/json.go +++ b/pkg/codec/json.go @@ -4,8 +4,9 @@ import ( "encoding/json" "strconv" - "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" "github.com/go-faster/jx" + + "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" ) var jzon = wsjson.JZON @@ -40,9 +41,11 @@ func (m *Message) MarshalJSON() ([]byte, error) { e.Raw(m.ID.RawMessage()) }) } - e.Field("method", func(e *jx.Encoder) { - e.Str(m.Method) - }) + if m.Method != "" { + e.Field("method", func(e *jx.Encoder) { + e.Str(m.Method) + }) + } if m.Error != nil { e.Field("error", func(e *jx.Encoder) { xs, _ := json.Marshal(m.Error) diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go index 4953e983fafc2c0a7d298955270d8eea0e8f0918..46df154bfbedff9074f50ce874379fdfc5ae08b2 100644 --- a/pkg/jrpctest/suites.go +++ b/pkg/jrpctest/suites.go @@ -75,6 +75,11 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { Params: []any{"hello2", 11, &EchoArgs{"world"}}, Result: new(EchoResult), }, + { + Method: "test_echo", + Params: []any{"hello3", 12, &EchoArgs{"world"}}, + IsNotification: true, + }, { Method: "no_such_method", Params: []any{1, 2, 3}, @@ -95,6 +100,10 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { Params: []any{"hello2", 11, &EchoArgs{"world"}}, Result: &EchoResult{"hello2", 11, &EchoArgs{"world"}}, }, + { + Method: "test_echo", + Params: []any{"hello3", 12, &EchoArgs{"world"}}, + }, { Method: "no_such_method", Params: []any{1, 2, 3}, @@ -105,13 +114,11 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { require.EqualValues(t, len(batch), len(wantResult)) for i := range batch { a := batch[i] - b := batch[i] + b := wantResult[i] assert.EqualValuesf(t, a.Method, b.Method, "item %d", i) assert.EqualValuesf(t, a.Result, b.Result, "item %d", i) assert.EqualValuesf(t, a.Params, b.Params, "item %d", i) - if a.Error != nil { - assert.EqualValuesf(t, a.Error, b.Error, "item %d", i) - } + assert.EqualValuesf(t, a.Error, b.Error, "item %d", i) } }) @@ -146,8 +153,10 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { } }) makeTest("Notify", func(t *testing.T, server *server.Server, client codec.Conn) { - if err := client.Notify(context.Background(), "test_echo", []any{"hello", 10, &EchoArgs{"world"}}); err != nil { - t.Fatal(err) + if c, ok := client.(codec.StreamingConn); ok { + if err := c.Notify(context.Background(), "test_echo", []any{"hello", 10, &EchoArgs{"world"}}); err != nil { + t.Fatal(err) + } } }) diff --git a/pkg/server/server.go b/pkg/server/server.go index af92dab22170de15dd11b4b6c0864acf58b4cc6c..ae59fe5c0a2561eb95412e8ecb07d1f0cbed6ab8 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -249,7 +249,7 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { } enc := jx.GetEncoder() enc.Reset() - //enc.ResetWriter(c.remote) + // enc.ResetWriter(c.remote) defer jx.PutEncoder(enc) if env.batch { enc.ArrStart() @@ -267,8 +267,6 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { e.Str("2.0") e.FieldStart("id") e.Raw(id) - e.FieldStart("method") - e.Str(v.msg.Method) err := v.err if err == nil { if v.dat != nil {