From 0e788123439ef0bc7a4e0736b31956df5044ad26 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sat, 30 Mar 2019 23:04:10 -0500
Subject: [PATCH] Improve JSON API

Closes #50
---
 README.md          | 14 +++++++++++---
 datatype_string.go |  4 ++--
 example_test.go    | 12 ++++++++++--
 json.go            | 43 ++++++++++++++++++++++++++++++++-----------
 websocket_test.go  | 12 ++++++++++--
 5 files changed, 65 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 34bd480..d2261bb 100644
--- a/README.md
+++ b/README.md
@@ -45,13 +45,17 @@ fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	}
 	defer c.Close(websocket.StatusInternalError, "")
 
+	jc := websocket.JSONConn{
+		Conn: c,
+	}
+
 	ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
 	defer cancel()
 
 	v := map[string]interface{}{
 		"my_field": "foo",
 	}
-	err = websocket.WriteJSON(ctx, c, v)
+	err = jc.Write(ctx, v)
 	if err != nil {
 		log.Printf("failed to write json: %v", err)
 		return
@@ -73,7 +77,7 @@ For a production quality example that shows off the low level API, see the [echo
 
 ```go
 ctx := context.Background()
-ctx, cancel := context.WithTimeout(ctx, time.Second*10)
+ctx, cancel := context.WithTimeout(ctx, time.Minute)
 defer cancel()
 
 c, _, err := websocket.Dial(ctx, "ws://localhost:8080",
@@ -84,8 +88,12 @@ if err != nil {
 }
 defer c.Close(websocket.StatusInternalError, "")
 
+jc := websocket.JSONConn{
+	Conn: c,
+}
+
 var v interface{}
-err = websocket.ReadJSON(ctx, c, v)
+err = jc.Read(ctx, v)
 if err != nil {
 	log.Fatalf("failed to read json: %v", err)
 }
diff --git a/datatype_string.go b/datatype_string.go
index 60a85c3..1b4aaba 100644
--- a/datatype_string.go
+++ b/datatype_string.go
@@ -12,9 +12,9 @@ func _() {
 	_ = x[DataBinary-2]
 }
 
-const _DataType_name = "TextBinary"
+const _DataType_name = "DataTextDataBinary"
 
-var _DataType_index = [...]uint8{0, 4, 10}
+var _DataType_index = [...]uint8{0, 8, 18}
 
 func (i DataType) String() string {
 	i -= 1
diff --git a/example_test.go b/example_test.go
index 0b15fab..5e6d072 100644
--- a/example_test.go
+++ b/example_test.go
@@ -88,13 +88,17 @@ func ExampleAccept() {
 		}
 		defer c.Close(websocket.StatusInternalError, "")
 
+		jc := websocket.JSONConn{
+			Conn: c,
+		}
+
 		ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
 		defer cancel()
 
 		v := map[string]interface{}{
 			"my_field": "foo",
 		}
-		err = websocket.WriteJSON(ctx, c, v)
+		err = jc.Write(ctx, v)
 		if err != nil {
 			log.Printf("failed to write json: %v", err)
 			return
@@ -123,8 +127,12 @@ func ExampleDial() {
 	}
 	defer c.Close(websocket.StatusInternalError, "")
 
+	jc := websocket.JSONConn{
+		Conn: c,
+	}
+
 	var v interface{}
-	err = websocket.ReadJSON(ctx, c, v)
+	err = jc.Read(ctx, v)
 	if err != nil {
 		log.Fatalf("failed to read json: %v", err)
 	}
diff --git a/json.go b/json.go
index ca4ac92..ebe0dfd 100644
--- a/json.go
+++ b/json.go
@@ -7,15 +7,28 @@ import (
 	"golang.org/x/xerrors"
 )
 
-// ReadJSON reads a json message from c into v.
-func ReadJSON(ctx context.Context, c *Conn, v interface{}) error {
-	typ, r, err := c.ReadMessage(ctx)
+// JSONConn wraps around a Conn with JSON helpers.
+type JSONConn struct {
+	Conn *Conn
+}
+
+// Read reads a json message into v.
+func (jc JSONConn) Read(ctx context.Context, v interface{}) error {
+	err := jc.read(ctx, v)
 	if err != nil {
 		return xerrors.Errorf("failed to read json: %w", err)
 	}
+	return nil
+}
+
+func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
+	typ, r, err := jc.Conn.ReadMessage(ctx)
+	if err != nil {
+		return err
+	}
 
-	if typ != Text {
-		return xerrors.Errorf("unexpected frame type for json (expected TextFrame): %v", typ)
+	if typ != DataText {
+		return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ)
 	}
 
 	r.Limit(131072)
@@ -24,25 +37,33 @@ func ReadJSON(ctx context.Context, c *Conn, v interface{}) error {
 	d := json.NewDecoder(r)
 	err = d.Decode(v)
 	if err != nil {
-		return xerrors.Errorf("failed to read json: %w", err)
+		return xerrors.Errorf("failed to decode json: %w", err)
 	}
 	return nil
 }
 
-// WriteJSON writes the json message v into c.
-func WriteJSON(ctx context.Context, c *Conn, v interface{}) error {
-	w := c.MessageWriter(Text)
+// Write writes the json message v.
+func (jc JSONConn) Write(ctx context.Context, v interface{}) error {
+	err := jc.write(ctx, v)
+	if err != nil {
+		return xerrors.Errorf("failed to write json: %w", err)
+	}
+	return nil
+}
+
+func (jc JSONConn) write(ctx context.Context, v interface{}) error {
+	w := jc.Conn.MessageWriter(DataText)
 	w.SetContext(ctx)
 
 	e := json.NewEncoder(w)
 	err := e.Encode(v)
 	if err != nil {
-		return xerrors.Errorf("failed to write json: %w", err)
+		return xerrors.Errorf("failed to encode json: %w", err)
 	}
 
 	err = w.Close()
 	if err != nil {
-		return xerrors.Errorf("failed to write json: %w", err)
+		return err
 	}
 	return nil
 }
diff --git a/websocket_test.go b/websocket_test.go
index 61384af..e91e5b2 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -173,10 +173,14 @@ func TestHandshake(t *testing.T) {
 				ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
 				defer cancel()
 
+				jc := websocket.JSONConn{
+					Conn: c,
+				}
+
 				v := map[string]interface{}{
 					"anmol": "wowow",
 				}
-				err = websocket.WriteJSON(ctx, c, v)
+				err = jc.Write(ctx, v)
 				if err != nil {
 					return err
 				}
@@ -191,8 +195,12 @@ func TestHandshake(t *testing.T) {
 				}
 				defer c.Close(websocket.StatusInternalError, "")
 
+				jc := websocket.JSONConn{
+					Conn: c,
+				}
+
 				var v interface{}
-				err = websocket.ReadJSON(ctx, c, &v)
+				err = jc.Read(ctx, &v)
 				if err != nil {
 					return err
 				}
-- 
GitLab