From 00d78f820129fbb2b7cd066e58a32b153dd3490d Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Sat, 28 Oct 2023 02:44:56 -0500
Subject: [PATCH] tests pass...

---
 pkg/codec/json.go            |  3 +++
 pkg/server/responsewriter.go | 18 ++++++++++++++----
 pkg/server/server.go         | 27 ++++++++++++++++++++-------
 3 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/pkg/codec/json.go b/pkg/codec/json.go
index 77f2c8a..84662e0 100644
--- a/pkg/codec/json.go
+++ b/pkg/codec/json.go
@@ -84,6 +84,9 @@ func UnmarshalMessage(m *Message, dec *jx.Decoder) error {
 			}
 			buf := bytes.NewBuffer(make(json.RawMessage, len(val)))
 			buf.Write(val)
+			if m.ExtraFields == nil {
+				m.ExtraFields = ExtraFields{}
+			}
 			m.ExtraFields[key] = buf.Bytes()
 		case "jsonrpc":
 			value, err := d.Str()
diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go
index b353781..e23a03e 100644
--- a/pkg/server/responsewriter.go
+++ b/pkg/server/responsewriter.go
@@ -49,7 +49,7 @@ func (c *callRespWriter) Send(v any, e error) (err error) {
 	// ultimately they need to be buffered. there's some optimistic multiplexing you can
 	// do, but that felt really complicated and not worth the time.
 	if c.noStream {
-		if e != nil {
+		if c.err == nil {
 			c.err = e
 		}
 		if v != nil {
@@ -66,12 +66,22 @@ func (c *callRespWriter) Send(v any, e error) (err error) {
 		return err
 	}
 	defer c.cr.mu.Release(1)
-	err = c.cr.send(c.ctx, &callEnv{
-		v:           &v,
+	if c.err != nil {
+		e = c.err
+	}
+	ce := &callEnv{
 		err:         e,
 		id:          c.msg.ID,
 		extrafields: c.msg.ExtraFields,
-	})
+	}
+	if v != nil {
+		ce.v = &v
+	}
+
+	err = c.cr.send(c.ctx, ce)
+	if err != nil {
+		return err
+	}
 	err = c.cr.remote.Flush()
 	if err != nil {
 		return err
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 1a187bf..26d65a4 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -86,15 +86,23 @@ func (s *Server) serveBatch(ctx context.Context,
 	// check for empty batch
 	if r.batch && len(incoming) == 0 {
 		// if it is empty batch, send the empty batch error and immediately return
-		err := r.send(ctx, &callEnv{
-			pkt: &codec.Message{
-				ID:    codec.NewNullIDPtr(),
-				Error: codec.NewInvalidRequestError("empty batch"),
-			},
+		err := r.mu.Acquire(ctx, 1)
+		if err != nil {
+			return err
+		}
+		defer r.mu.Release(1)
+		err = r.send(ctx, &callEnv{
+			id:  codec.NewNullIDPtr(),
+			err: codec.NewInvalidRequestError("empty batch"),
 		})
 		if err != nil {
 			return err
 		}
+		err = r.remote.Flush()
+		if err != nil {
+			return err
+		}
+		return nil
 	}
 
 	rs := []*callRespWriter{}
@@ -111,9 +119,15 @@ func (s *Server) serveBatch(ctx context.Context,
 		// a nil incoming message means an empty response
 		if v == nil {
 			v = &codec.Message{ID: codec.NewNullIDPtr()}
+			rw.err = codec.NewInvalidRequestError("invalid request")
 		}
 		rw.msg = v
+		rw.msg.ExtraFields = codec.ExtraFields{}
+		rw.msg.Error = nil
 		if len(v.Method) == 0 {
+			if v.ID == nil {
+				v.ID = codec.NewNullIDPtr()
+			}
 			rw.err = codec.NewInvalidRequestError("invalid request")
 		}
 		if v.ID != nil {
@@ -154,7 +168,7 @@ func (s *Server) serveBatch(ctx context.Context,
 			s.services.ServeRPC(v, req)
 		}()
 	}
-	if r.batch {
+	if r.batch && totalRequests > 0 {
 		err = doneMu.Acquire(ctx, int64(totalRequests))
 		if err != nil {
 			return err
@@ -221,7 +235,6 @@ type callResponder struct {
 type callEnv struct {
 	v           *any
 	err         error
-	pkt         *codec.Message
 	id          *codec.ID
 	extrafields codec.ExtraFields
 }
-- 
GitLab