From 33ed508c046d6678a364ad4f7268d8f2cee59385 Mon Sep 17 00:00:00 2001
From: Garet Halliday <ghalliday@gfxlabs.io>
Date: Fri, 14 Jul 2023 02:19:48 +0000
Subject: [PATCH] Various test fixes

---
 contrib/codecs/http/client.go     | 11 +++--
 contrib/codecs/inproc/inproc.go   | 16 +++++--
 contrib/codecs/rdwr/codec.go      | 19 +++++---
 contrib/codecs/rdwr/codec_test.go |  2 +-
 contrib/codecs/rdwr/rdwr_test.go  |  2 +-
 pkg/clientutil/helper.go          | 16 ++++---
 pkg/clientutil/helper_test.go     | 79 +++++++++++++++++++++++++++++++
 pkg/clientutil/idreply_test.go    | 46 ++++++++++++++++++
 pkg/codec/json.go                 | 11 +++--
 pkg/jrpctest/suites.go            | 21 +++++---
 pkg/server/server.go              |  4 +-
 11 files changed, 190 insertions(+), 37 deletions(-)
 create mode 100644 pkg/clientutil/helper_test.go
 create mode 100644 pkg/clientutil/idreply_test.go

diff --git a/contrib/codecs/http/client.go b/contrib/codecs/http/client.go
index fa8f2c1..5709729 100644
--- a/contrib/codecs/http/client.go
+++ b/contrib/codecs/http/client.go
@@ -14,8 +14,9 @@ import (
 
 	"gfx.cafe/open/jrpc/pkg/codec"
 
-	"gfx.cafe/open/jrpc/pkg/clientutil"
 	"gfx.cafe/util/go/bufpool"
+
+	"gfx.cafe/open/jrpc/pkg/clientutil"
 )
 
 var (
@@ -107,7 +108,7 @@ func (c *Client) Do(ctx context.Context, result any, method string, params any)
 }
 
 func (c *Client) post(req *codec.Request) (*http.Response, error) {
-	//TODO: use buffer for this
+	// TODO: use buffer for this
 	buf := bufpool.GetStd()
 	defer bufpool.PutStd(buf)
 	buf.Reset()
@@ -139,13 +140,13 @@ func (c *Client) Notify(ctx context.Context, method string, params any) error {
 
 func (c *Client) BatchCall(ctx context.Context, b ...*codec.BatchElem) error {
 	reqs := make([]*codec.Request, len(b))
-	ids := make([]int, 0, len(b))
-	for _, v := range b {
+	ids := make(map[int]int, len(b))
+	for idx, v := range b {
 		if v.IsNotification {
 			reqs = append(reqs, codec.NewRequest(ctx, "", v.Method, v.Params))
 		} else {
 			id := int(c.id.Add(1))
-			ids = append(ids, id)
+			ids[idx] = id
 			reqs = append(reqs, codec.NewRequestInt(ctx, id, v.Method, v.Params))
 		}
 	}
diff --git a/contrib/codecs/inproc/inproc.go b/contrib/codecs/inproc/inproc.go
index 53b60de..5c35e02 100644
--- a/contrib/codecs/inproc/inproc.go
+++ b/contrib/codecs/inproc/inproc.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"io"
+	"sync"
 
 	"gfx.cafe/open/jrpc/pkg/codec"
 )
@@ -13,9 +14,10 @@ type Codec struct {
 	ctx context.Context
 	cn  func()
 
-	rd   io.Reader
-	wr   *bufio.Writer
-	msgs chan json.RawMessage
+	rd     io.Reader
+	wrLock sync.Mutex
+	wr     *bufio.Writer
+	msgs   chan json.RawMessage
 }
 
 func NewCodec() *Codec {
@@ -58,10 +60,14 @@ func (c *Codec) Close() error {
 }
 
 func (c *Codec) Write(p []byte) (n int, err error) {
+	c.wrLock.Lock()
+	defer c.wrLock.Unlock()
 	return c.wr.Write(p)
 }
 
 func (c *Codec) Flush() (err error) {
+	c.wrLock.Lock()
+	defer c.wrLock.Unlock()
 	return c.wr.Flush()
 }
 
@@ -76,7 +82,7 @@ func (c *Codec) RemoteAddr() string {
 }
 
 // DialInProc attaches an in-process connection to the given RPC server.
-//func DialInProc(handler *Server) *Client {
+// func DialInProc(handler *Server) *Client {
 //	initctx := context.Background()
 //	c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
 //		p1, p2 := net.Pipe()
@@ -84,4 +90,4 @@ func (c *Codec) RemoteAddr() string {
 //		return NewCodec(p2), nil
 //	})
 //	return c
-//}
+// }
diff --git a/contrib/codecs/rdwr/codec.go b/contrib/codecs/rdwr/codec.go
index 8a71503..a352b65 100644
--- a/contrib/codecs/rdwr/codec.go
+++ b/contrib/codecs/rdwr/codec.go
@@ -4,18 +4,21 @@ import (
 	"bufio"
 	"context"
 	"io"
+	"sync"
 
-	"gfx.cafe/open/jrpc/pkg/codec"
 	"github.com/goccy/go-json"
+
+	"gfx.cafe/open/jrpc/pkg/codec"
 )
 
 type Codec struct {
 	ctx context.Context
 	cn  func()
 
-	rd   io.Reader
-	wr   *bufio.Writer
-	msgs chan json.RawMessage
+	rd     io.Reader
+	wrLock sync.Mutex
+	wr     *bufio.Writer
+	msgs   chan json.RawMessage
 }
 
 func NewCodec(rd io.Reader, wr io.Writer, onError func(error)) *Codec {
@@ -77,10 +80,14 @@ func (c *Codec) Close() error {
 }
 
 func (c *Codec) Write(p []byte) (n int, err error) {
+	c.wrLock.Lock()
+	defer c.wrLock.Unlock()
 	return c.wr.Write(p)
 }
 
 func (c *Codec) Flush() (err error) {
+	c.wrLock.Lock()
+	defer c.wrLock.Unlock()
 	c.wr.WriteByte('\n')
 	return c.wr.Flush()
 }
@@ -96,7 +103,7 @@ func (c *Codec) RemoteAddr() string {
 }
 
 // Dialrdwr attaches an in-process connection to the given RPC server.
-//func Dialrdwr(handler *Server) *Client {
+// func Dialrdwr(handler *Server) *Client {
 //	initctx := context.Background()
 //	c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
 //		p1, p2 := net.Pipe()
@@ -104,4 +111,4 @@ func (c *Codec) RemoteAddr() string {
 //		return NewCodec(p2), nil
 //	})
 //	return c
-//}
+// }
diff --git a/contrib/codecs/rdwr/codec_test.go b/contrib/codecs/rdwr/codec_test.go
index e783d15..6dcadf1 100644
--- a/contrib/codecs/rdwr/codec_test.go
+++ b/contrib/codecs/rdwr/codec_test.go
@@ -24,7 +24,7 @@ func TestBasicSuite(t *testing.T) {
 				s.ServeCodec(context.Background(), clientCodec)
 			}()
 			return s, func() codec.Conn {
-				return rdwr.NewClient(rd_s, wr_c, nil)
+				return rdwr.NewClient(rd_s, wr_c)
 			}, func() {}
 		},
 	})
diff --git a/contrib/codecs/rdwr/rdwr_test.go b/contrib/codecs/rdwr/rdwr_test.go
index bd86dae..277d3b5 100644
--- a/contrib/codecs/rdwr/rdwr_test.go
+++ b/contrib/codecs/rdwr/rdwr_test.go
@@ -22,7 +22,7 @@ func TestRDWRSetup(t *testing.T) {
 	rd_c, wr_c := io.Pipe()
 
 	clientCodec := rdwr.NewCodec(rd_s, wr_c, nil)
-	client := rdwr.NewClient(rd_c, wr_s, nil)
+	client := rdwr.NewClient(rd_c, wr_s)
 	go func() {
 		srv.ServeCodec(ctx, clientCodec)
 	}()
diff --git a/pkg/clientutil/helper.go b/pkg/clientutil/helper.go
index b0f70cd..c139aa0 100644
--- a/pkg/clientutil/helper.go
+++ b/pkg/clientutil/helper.go
@@ -3,14 +3,19 @@ package clientutil
 import (
 	"encoding/json"
 	"fmt"
-	"gfx.cafe/open/jrpc/pkg/codec"
+
 	"gfx.cafe/util/go/generic"
+
+	"gfx.cafe/open/jrpc/pkg/codec"
 )
 
 var msgPool = generic.HookPool[*codec.Message]{
 	New: func() *codec.Message {
 		return &codec.Message{}
 	},
+	FnPut: func(msg *codec.Message) {
+		*msg = codec.Message{}
+	},
 }
 
 func GetMessage() *codec.Message {
@@ -21,14 +26,13 @@ func PutMessage(x *codec.Message) {
 	msgPool.Put(x)
 }
 
-func FillBatch(ids []int, msgs []*codec.Message, b []*codec.BatchElem) {
-	answers := map[int]*codec.Message{}
+func FillBatch(ids map[int]int, msgs []*codec.Message, b []*codec.BatchElem) {
+	answers := make(map[int]*codec.Message, len(msgs))
 	for _, v := range msgs {
 		answers[v.ID.Number()] = v
 	}
-	for i := range ids {
-		idx := i
-		ans, ok := answers[i]
+	for idx, id := range ids {
+		ans, ok := answers[id]
 		if !ok {
 			b[idx].Error = fmt.Errorf("No response found")
 			continue
diff --git a/pkg/clientutil/helper_test.go b/pkg/clientutil/helper_test.go
new file mode 100644
index 0000000..ed765ee
--- /dev/null
+++ b/pkg/clientutil/helper_test.go
@@ -0,0 +1,79 @@
+package clientutil
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"gfx.cafe/open/jrpc/pkg/codec"
+)
+
+func ptr[T any](v T) *T {
+	return &v
+}
+
+func TestFillBatch(t *testing.T) {
+	msgs := []*codec.Message{
+		{
+			ID:     ptr(codec.ID(`"5"`)),
+			Result: json.RawMessage(`["test", "abc", "123"]`),
+		},
+		{
+			ID:     ptr(codec.ID(`"6"`)),
+			Result: json.RawMessage(`12345`),
+		},
+		{},
+		{
+			ID:     ptr(codec.ID(`"7"`)),
+			Result: json.RawMessage(`"abcdefgh"`),
+		},
+	}
+	ids := map[int]int{
+		0: 5,
+		1: 6,
+		3: 7,
+	}
+	b := []*codec.BatchElem{
+		{
+			Result: new([]string),
+		},
+		{
+			Result: new(int),
+		},
+		{},
+		{
+			Result: new(string),
+		},
+	}
+
+	FillBatch(ids, msgs, b)
+
+	wantResult := []*codec.BatchElem{
+		{
+			Result: &[]string{
+				"test",
+				"abc",
+				"123",
+			},
+		},
+		{
+			Result: ptr(12345),
+		},
+		{},
+		{
+			Result: ptr("abcdefgh"),
+		},
+	}
+
+	require.EqualValues(t, len(b), len(wantResult))
+	for i := range b {
+		expected := wantResult[i]
+		actual := b[i]
+		assert.EqualValuesf(t, expected.Method, actual.Method, "item %d", i)
+		assert.EqualValuesf(t, expected.Result, actual.Result, "item %d", i)
+		assert.EqualValuesf(t, expected.Params, actual.Params, "item %d", i)
+		assert.EqualValuesf(t, expected.Error, actual.Error, "item %d", i)
+	}
+}
diff --git a/pkg/clientutil/idreply_test.go b/pkg/clientutil/idreply_test.go
new file mode 100644
index 0000000..40fa92a
--- /dev/null
+++ b/pkg/clientutil/idreply_test.go
@@ -0,0 +1,46 @@
+package clientutil
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"sync"
+	"testing"
+)
+
+const count = 1000
+
+func TestIdReply(t *testing.T) {
+	reply := NewIdReply()
+
+	testMessage := json.RawMessage("{\"test\": 123}")
+
+	var wg sync.WaitGroup
+
+	wg.Add(count)
+
+	for i := 0; i < count; i++ {
+		go func() {
+			defer wg.Done()
+			id := reply.NextId()
+			v, err := reply.Ask(context.Background(), id)
+			if err != nil {
+				t.Error(err)
+				return
+			}
+
+			if !bytes.Equal(v, testMessage) {
+				t.Error("expected contents to be equal")
+				return
+			}
+		}()
+	}
+
+	for i := 0; i < count; i++ {
+		go func(id int) {
+			reply.Resolve(id+1, testMessage, nil)
+		}(i)
+	}
+
+	wg.Wait()
+}
diff --git a/pkg/codec/json.go b/pkg/codec/json.go
index 1484569..c4988d1 100644
--- a/pkg/codec/json.go
+++ b/pkg/codec/json.go
@@ -4,8 +4,9 @@ import (
 	"encoding/json"
 	"strconv"
 
-	"gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson"
 	"github.com/go-faster/jx"
+
+	"gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson"
 )
 
 var jzon = wsjson.JZON
@@ -40,9 +41,11 @@ func (m *Message) MarshalJSON() ([]byte, error) {
 				e.Raw(m.ID.RawMessage())
 			})
 		}
-		e.Field("method", func(e *jx.Encoder) {
-			e.Str(m.Method)
-		})
+		if m.Method != "" {
+			e.Field("method", func(e *jx.Encoder) {
+				e.Str(m.Method)
+			})
+		}
 		if m.Error != nil {
 			e.Field("error", func(e *jx.Encoder) {
 				xs, _ := json.Marshal(m.Error)
diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go
index 4953e98..46df154 100644
--- a/pkg/jrpctest/suites.go
+++ b/pkg/jrpctest/suites.go
@@ -75,6 +75,11 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) {
 				Params: []any{"hello2", 11, &EchoArgs{"world"}},
 				Result: new(EchoResult),
 			},
+			{
+				Method:         "test_echo",
+				Params:         []any{"hello3", 12, &EchoArgs{"world"}},
+				IsNotification: true,
+			},
 			{
 				Method: "no_such_method",
 				Params: []any{1, 2, 3},
@@ -95,6 +100,10 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) {
 				Params: []any{"hello2", 11, &EchoArgs{"world"}},
 				Result: &EchoResult{"hello2", 11, &EchoArgs{"world"}},
 			},
+			{
+				Method: "test_echo",
+				Params: []any{"hello3", 12, &EchoArgs{"world"}},
+			},
 			{
 				Method: "no_such_method",
 				Params: []any{1, 2, 3},
@@ -105,13 +114,11 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) {
 		require.EqualValues(t, len(batch), len(wantResult))
 		for i := range batch {
 			a := batch[i]
-			b := batch[i]
+			b := wantResult[i]
 			assert.EqualValuesf(t, a.Method, b.Method, "item %d", i)
 			assert.EqualValuesf(t, a.Result, b.Result, "item %d", i)
 			assert.EqualValuesf(t, a.Params, b.Params, "item %d", i)
-			if a.Error != nil {
-				assert.EqualValuesf(t, a.Error, b.Error, "item %d", i)
-			}
+			assert.EqualValuesf(t, a.Error, b.Error, "item %d", i)
 		}
 	})
 
@@ -146,8 +153,10 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) {
 		}
 	})
 	makeTest("Notify", func(t *testing.T, server *server.Server, client codec.Conn) {
-		if err := client.Notify(context.Background(), "test_echo", []any{"hello", 10, &EchoArgs{"world"}}); err != nil {
-			t.Fatal(err)
+		if c, ok := client.(codec.StreamingConn); ok {
+			if err := c.Notify(context.Background(), "test_echo", []any{"hello", 10, &EchoArgs{"world"}}); err != nil {
+				t.Fatal(err)
+			}
 		}
 	})
 
diff --git a/pkg/server/server.go b/pkg/server/server.go
index af92dab..ae59fe5 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -249,7 +249,7 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error {
 	}
 	enc := jx.GetEncoder()
 	enc.Reset()
-	//enc.ResetWriter(c.remote)
+	// enc.ResetWriter(c.remote)
 	defer jx.PutEncoder(enc)
 	if env.batch {
 		enc.ArrStart()
@@ -267,8 +267,6 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error {
 			e.Str("2.0")
 			e.FieldStart("id")
 			e.Raw(id)
-			e.FieldStart("method")
-			e.Str(v.msg.Method)
 			err := v.err
 			if err == nil {
 				if v.dat != nil {
-- 
GitLab