From 52ac4c1ce89c58c1cb3588204b6957272f4e52a7 Mon Sep 17 00:00:00 2001
From: a <a@a.a>
Date: Sat, 17 Sep 2022 07:45:21 -0500
Subject: [PATCH] jayzon

---
 client.go        |  9 +++---
 go.mod           |  1 +
 go.sum           |  2 ++
 http.go          |  4 +--
 json.go          | 18 ++++++++++--
 protocol.go      | 14 ++++-----
 server_test.go   |  5 ++--
 websocket.go     |  2 +-
 wire.go          | 18 +++++-------
 wsjson/wsjson.go | 76 ++++++++++++++++++++++++++++++++++++++++++++++++
 10 files changed, 116 insertions(+), 33 deletions(-)
 create mode 100644 wsjson/wsjson.go

diff --git a/client.go b/client.go
index 9d1151b..03f3394 100644
--- a/client.go
+++ b/client.go
@@ -27,7 +27,6 @@ import (
 	"time"
 
 	"git.tuxpa.in/a/zlog/log"
-	jsoniter "github.com/json-iterator/go"
 )
 
 var (
@@ -286,7 +285,7 @@ func (c *Client) call(ctx context.Context, result any, msg *jsonrpcMessage) erro
 	case len(resp.Result) == 0:
 		return ErrNoResult
 	default:
-		return jsoniter.Unmarshal(resp.Result, &result)
+		return jzon.Unmarshal(resp.Result, &result)
 	}
 }
 
@@ -403,7 +402,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
 			elem.Error = ErrNoResult
 			continue
 		}
-		elem.Error = jsoniter.Unmarshal(resp.Result, elem.Result)
+		elem.Error = jzon.Unmarshal(resp.Result, elem.Result)
 	}
 
 	return err
@@ -428,7 +427,7 @@ func (c *Client) newMessage(method string, paramsIn ...any) (*jsonrpcMessage, er
 	msg := &jsonrpcMessage{ID: c.nextID(), Method: method}
 	if paramsIn != nil { // prevent sending "params":null
 		var err error
-		if msg.Params, err = jsoniter.Marshal(paramsIn); err != nil {
+		if msg.Params, err = jzon.Marshal(paramsIn); err != nil {
 			return nil, err
 		}
 	}
@@ -438,7 +437,7 @@ func (c *Client) newMessageP(method string, paramIn any) (*jsonrpcMessage, error
 	msg := &jsonrpcMessage{ID: c.nextID(), Method: method}
 	if paramIn != nil { // prevent sending "params":null
 		var err error
-		if msg.Params, err = jsoniter.Marshal(paramIn); err != nil {
+		if msg.Params, err = jzon.Marshal(paramIn); err != nil {
 			return nil, err
 		}
 	}
diff --git a/go.mod b/go.mod
index 7991ced..601e2de 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module gfx.cafe/open/jrpc
 go 1.18
 
 require (
+	gfx.cafe/util/go/bufpool v0.0.0-20220917112702-95618babdf53
 	git.tuxpa.in/a/zlog v1.32.0
 	github.com/davecgh/go-spew v1.1.1
 	github.com/deckarep/golang-set v1.8.0
diff --git a/go.sum b/go.sum
index 93485aa..da5523e 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+gfx.cafe/util/go/bufpool v0.0.0-20220917112702-95618babdf53 h1:j45c1YN77NyWrO0dN+e7lKJctXpC5TlVZWmww/PpFA0=
+gfx.cafe/util/go/bufpool v0.0.0-20220917112702-95618babdf53/go.mod h1:+DiyiCOBGS9O9Ce4ewHQO3Y59h66WSWAbgZZ2O2AYYw=
 git.tuxpa.in/a/zlog v1.32.0 h1:KKXbRF1x8kJDSzUoGz/pivo+4TVY6xT5sVtdFZ6traY=
 git.tuxpa.in/a/zlog v1.32.0/go.mod h1:vUa2Qhu6DLPLqmfRy99FiPqaY2eb6/KQjtMekW3UNnA=
 github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 h1:fLjPD/aNc3UIOA6tDi6QXUemppXK3P9BI7mr2hd6gx8=
diff --git a/http.go b/http.go
index 5de6bc1..553de3d 100644
--- a/http.go
+++ b/http.go
@@ -30,8 +30,6 @@ import (
 	"strings"
 	"sync"
 	"time"
-
-	jsoniter "github.com/json-iterator/go"
 )
 
 const (
@@ -180,7 +178,7 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr
 }
 
 func (hc *httpConn) doRequest(ctx context.Context, msg any) (io.ReadCloser, error) {
-	body, err := jsoniter.Marshal(msg)
+	body, err := jzon.Marshal(msg)
 	if err != nil {
 		return nil, err
 	}
diff --git a/json.go b/json.go
index 044466a..37d15eb 100644
--- a/json.go
+++ b/json.go
@@ -31,7 +31,19 @@ import (
 	jsoniter "github.com/json-iterator/go"
 )
 
-var jzon = jsoniter.ConfigCompatibleWithStandardLibrary
+var jzon = jsoniter.Config{
+	IndentionStep:                 0,
+	MarshalFloatWith6Digits:       false,
+	EscapeHTML:                    true,
+	SortMapKeys:                   true,
+	UseNumber:                     false,
+	DisallowUnknownFields:         false,
+	TagKey:                        "",
+	OnlyTaggedField:               false,
+	ValidateJsonRawMessage:        true,
+	ObjectFieldMustBeSimpleString: false,
+	CaseSensitive:                 false,
+}.Froze()
 
 const (
 	vsn                      = "2.0"
@@ -206,7 +218,7 @@ func NewFuncCodec(conn deadlineCloser, encode, decode func(v any) error) ServerC
 // messages will use it to include the remote address of the connection.
 func NewCodec(conn Conn) ServerCodec {
 	enc := jzon.NewEncoder(conn)
-	dec := jzon.NewDecoder(conn)
+	dec := json.NewDecoder(conn)
 	dec.UseNumber()
 	return NewFuncCodec(conn, enc.Encode, dec.Decode)
 }
@@ -269,7 +281,7 @@ func (c *jsonCodec) closed() <-chan any {
 func parseMessage(raw json.RawMessage) ([]*jsonrpcMessage, bool) {
 	if !isBatch(raw) {
 		msgs := []*jsonrpcMessage{{}}
-		jsoniter.Unmarshal(raw, &msgs[0])
+		jzon.Unmarshal(raw, &msgs[0])
 		return msgs, false
 	}
 	dec := json.NewDecoder(bytes.NewReader(raw))
diff --git a/protocol.go b/protocol.go
index a8fd134..40a4580 100644
--- a/protocol.go
+++ b/protocol.go
@@ -4,8 +4,6 @@ import (
 	"context"
 	"encoding/json"
 	"io"
-
-	jsoniter "github.com/json-iterator/go"
 )
 
 type HandlerFunc func(w ResponseWriter, r *Request)
@@ -31,7 +29,7 @@ type Request struct {
 
 func NewRequest(ctx context.Context, id string, method string, params any) *Request {
 	r := &Request{ctx: ctx}
-	pms, _ := jsoniter.Marshal(params)
+	pms, _ := jzon.Marshal(params)
 	r.msg = jsonrpcMessage{
 		ID:     NewStringIDPtr(id),
 		Method: method,
@@ -50,16 +48,16 @@ func (r *Request) Params() json.RawMessage {
 
 func (r *Request) ParamSlice() []any {
 	var params []any
-	jsoniter.Unmarshal(r.msg.Params, &params)
+	jzon.Unmarshal(r.msg.Params, &params)
 	return params
 }
 
 func (r *Request) ParamArray(a ...any) error {
 	var params []json.RawMessage
-	jsoniter.Unmarshal(r.msg.Params, &params)
+	jzon.Unmarshal(r.msg.Params, &params)
 	for idx, v := range params {
 		if len(v) > idx {
-			err := jsoniter.Unmarshal(v, &a[idx])
+			err := jzon.Unmarshal(v, &a[idx])
 			if err != nil {
 				return err
 			}
@@ -71,7 +69,7 @@ func (r *Request) ParamArray(a ...any) error {
 }
 
 func (r *Request) ParamInto(v any) error {
-	return jsoniter.Unmarshal(r.msg.Params, &v)
+	return jzon.Unmarshal(r.msg.Params, &v)
 }
 
 func (r *Request) Context() context.Context {
@@ -115,7 +113,7 @@ func NewReaderResponseWriterIo(r *Request, w io.Writer) ResponseWriter {
 }
 
 func (w *ResponseWriterIo) Send(args any, e error) (err error) {
-	enc := jsoniter.ConfigCompatibleWithStandardLibrary.NewEncoder(w.w)
+	enc := jzon.NewEncoder(w.w)
 	if e != nil {
 		return enc.Encode(errorMessage(e))
 	}
diff --git a/server_test.go b/server_test.go
index 7bee4d7..2e3f77a 100644
--- a/server_test.go
+++ b/server_test.go
@@ -116,10 +116,9 @@ func TestServerShortLivedConn(t *testing.T) {
 		conn.Write([]byte(request))
 		conn.(*net.TCPConn).CloseWrite()
 		// Now try to get the response.
-		buf := make([]byte, 2000)
-		n, err := conn.Read(buf)
+		buf, err := io.ReadAll(conn)
+		n := len(buf)
 		conn.Close()
-
 		if err != nil {
 			t.Fatal("read error:", err)
 		}
diff --git a/websocket.go b/websocket.go
index ec48136..ff80b81 100644
--- a/websocket.go
+++ b/websocket.go
@@ -24,9 +24,9 @@ import (
 	"sync"
 	"time"
 
+	"gfx.cafe/open/jrpc/wsjson"
 	"git.tuxpa.in/a/zlog/log"
 	"nhooyr.io/websocket"
-	"nhooyr.io/websocket/wsjson"
 )
 
 const (
diff --git a/wire.go b/wire.go
index 08957b5..39216c8 100644
--- a/wire.go
+++ b/wire.go
@@ -3,8 +3,6 @@ package jrpc
 import (
 	"encoding/json"
 	"fmt"
-
-	jsoniter "github.com/json-iterator/go"
 )
 
 // Version represents a JSON-RPC version.
@@ -23,13 +21,13 @@ var (
 
 // MarshalJSON implements json.Marshaler.
 func (version) MarshalJSON() ([]byte, error) {
-	return jsoniter.Marshal(Version)
+	return jzon.Marshal(Version)
 }
 
 // UnmarshalJSON implements json.Unmarshaler.
 func (version) UnmarshalJSON(data []byte) error {
 	version := ""
-	if err := jsoniter.Unmarshal(data, &version); err != nil {
+	if err := jzon.Unmarshal(data, &version); err != nil {
 		return fmt.Errorf("failed to Unmarshal: %w", err)
 	}
 	if version != Version {
@@ -95,12 +93,12 @@ func (id *ID) RawMessage() json.RawMessage {
 		return null
 	}
 	if id.name != "" {
-		ans, err := jsoniter.Marshal(id.name)
+		ans, err := jzon.Marshal(id.name)
 		if err == nil {
 			return ans
 		}
 	}
-	ans, err := jsoniter.Marshal(id.number)
+	ans, err := jzon.Marshal(id.number)
 	if err == nil {
 		return ans
 	}
@@ -116,18 +114,18 @@ func (id *ID) MarshalJSON() ([]byte, error) {
 		return null, nil
 	}
 	if id.name != "" {
-		return jsoniter.Marshal(id.name)
+		return jzon.Marshal(id.name)
 	}
-	return jsoniter.Marshal(id.number)
+	return jzon.Marshal(id.number)
 }
 
 // UnmarshalJSON implements json.Unmarshaler.
 func (id *ID) UnmarshalJSON(data []byte) error {
 	*id = ID{}
-	if err := jsoniter.Unmarshal(data, &id.number); err == nil {
+	if err := jzon.Unmarshal(data, &id.number); err == nil {
 		return nil
 	}
-	if err := jsoniter.Unmarshal(data, &id.name); err == nil {
+	if err := jzon.Unmarshal(data, &id.name); err == nil {
 		return nil
 	}
 	id.null = true
diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go
new file mode 100644
index 0000000..7394181
--- /dev/null
+++ b/wsjson/wsjson.go
@@ -0,0 +1,76 @@
+package wsjson
+
+import (
+	"context"
+	"fmt"
+
+	"gfx.cafe/util/go/bufpool"
+	jsoniter "github.com/json-iterator/go"
+	"nhooyr.io/websocket"
+)
+
+var jzon = jsoniter.Config{
+	IndentionStep:                 0,
+	MarshalFloatWith6Digits:       false,
+	EscapeHTML:                    true,
+	SortMapKeys:                   true,
+	UseNumber:                     false,
+	DisallowUnknownFields:         false,
+	TagKey:                        "",
+	OnlyTaggedField:               false,
+	ValidateJsonRawMessage:        true,
+	ObjectFieldMustBeSimpleString: false,
+	CaseSensitive:                 false,
+}.Froze()
+
+// Read reads a JSON message from c into v.
+// It will reuse buffers in between calls to avoid allocations.
+func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
+	return read(ctx, c, v)
+}
+
+func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
+
+	_, r, err := c.Reader(ctx)
+	if err != nil {
+		return err
+	}
+
+	b := bufpool.Get(512)
+	defer bufpool.Put(b)
+
+	_, err = b.ReadFrom(r)
+	if err != nil {
+		return err
+	}
+
+	err = jzon.Unmarshal(b.Bytes(), v)
+	if err != nil {
+		return fmt.Errorf("failed to unmarshal JSON: %w", err)
+	}
+
+	return nil
+}
+
+// Write writes the JSON message v to c.
+// It will reuse buffers in between calls to avoid allocations.
+func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
+	return write(ctx, c, v)
+}
+
+func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
+
+	w, err := c.Writer(ctx, websocket.MessageText)
+	if err != nil {
+		return err
+	}
+
+	// json.Marshal cannot reuse buffers between calls as it has to return
+	// a copy of the byte slice but Encoder does as it directly writes to w.
+	err = jzon.NewEncoder(w).Encode(v)
+	if err != nil {
+		return fmt.Errorf("failed to marshal JSON: %w", err)
+	}
+
+	return w.Close()
+}
-- 
GitLab