From 9dfe16e1133ab6dbc7b86915cce6b5464189bd0c Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Mon, 25 Mar 2024 13:04:44 -0500
Subject: [PATCH] a

---
 pkg/jjson/json.go     | 41 ++++++++++++++++++++++++++++++++---------
 pkg/jsonrpc/encode.go |  8 --------
 2 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/pkg/jjson/json.go b/pkg/jjson/json.go
index 5bda960..8d685b6 100644
--- a/pkg/jjson/json.go
+++ b/pkg/jjson/json.go
@@ -2,6 +2,7 @@ package jjson
 
 import (
 	"bytes"
+	"encoding/json"
 	"io"
 
 	jsoniter "github.com/json-iterator/go"
@@ -29,26 +30,48 @@ func MarshalAndEncode(w io.Writer, v any) error {
 func Encode(w io.Writer, v any) error {
 	s := jConfig.BorrowStream(w)
 	defer jConfig.ReturnStream(s)
-	s.WriteVal(v)
-	return s.Flush()
+	switch cast := (v).(type) {
+	case func(e *jsoniter.Stream):
+		cast(s)
+		return s.Flush()
+	case json.Marshaler:
+		s.WriteVal(v)
+		return s.Flush()
+	case io.Reader:
+		_, err := io.Copy(w, cast)
+		if err != nil {
+			return err
+		}
+		return nil
+	default:
+		s.WriteVal(v)
+		return s.Flush()
+	}
 }
 
 func Decode(r io.Reader, v any) error {
 	d := jConfig.NewDecoder(r)
-	return d.Decode(v)
+	switch cast := (v).(type) {
+	case json.Unmarshaler:
+		return d.Decode(v)
+	case io.Writer:
+		_, err := io.Copy(cast, r)
+		if err != nil {
+			return err
+		}
+		return nil
+	default:
+		return d.Decode(v)
+	}
 }
 
 func Unmarshal(xs []byte, v any) error {
-	d := jConfig.NewDecoder(bytes.NewBuffer(xs))
-	return d.Decode(v)
+	return Decode(bytes.NewBuffer(xs), v)
 }
 
 func Marshal(v any) ([]byte, error) {
 	out := &bytes.Buffer{}
-	s := jConfig.BorrowStream(out)
-	defer jConfig.ReturnStream(s)
-	s.WriteVal(v)
-	err := s.Flush()
+	err := Encode(out, v)
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/jsonrpc/encode.go b/pkg/jsonrpc/encode.go
index 7553fff..e7327c3 100644
--- a/pkg/jsonrpc/encode.go
+++ b/pkg/jsonrpc/encode.go
@@ -31,14 +31,6 @@ func EncodeObject(wr io.Writer, dat any) error {
 			}
 		}
 		return nil
-	case json.Marshaler:
-		return jjson.Encode(wr, cast)
-	case io.Reader:
-		_, err := io.Copy(wr, cast)
-		if err != nil {
-			return err
-		}
-		return nil
 	default:
 		return jjson.Encode(wr, cast)
 	}
-- 
GitLab