From 27ec18a7ff72c12db3ebb074bd8ddc2b4ac1fda7 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Wed, 24 Apr 2019 19:48:32 -0400
Subject: [PATCH] Add JSON and ProtoBuf helpers and improve docs

---
 README.md              |   3 +-
 accept.go              |   4 +-
 ci/bench/entrypoint.sh |   2 +-
 dial.go                |   1 +
 doc.go                 |   2 +-
 example_echo_test.go   |  40 ++------
 go.mod                 |   1 +
 go.sum                 |   2 +
 websocket.go           |   5 +-
 websocket_test.go      | 216 ++++++++++++++++++++++++++---------------
 wsjson/wsjson.go       |  71 ++++++++++++++
 wspb/wspb.go           |  80 +++++++++++++++
 12 files changed, 313 insertions(+), 114 deletions(-)
 create mode 100644 wsjson/wsjson.go
 create mode 100644 wspb/wspb.go

diff --git a/README.md b/README.md
index db96290..11309bc 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,8 @@ go get nhooyr.io/websocket
 - First class context.Context support
 - Thoroughly tested, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
 - Concurrent writes
-- Zero dependencies outside of the stdlib
+- Zero dependencies outside of the stdlib for the core library
+- JSON and ProtoBuf helpers in the wsjson and wspb subpackages
 
 ## Roadmap
 
diff --git a/accept.go b/accept.go
index 4fb9808..9cf546f 100644
--- a/accept.go
+++ b/accept.go
@@ -34,8 +34,8 @@ type AcceptOptions struct {
 	// The only time you need this is if your javascript is running on a different domain
 	// than your WebSocket server.
 	// Please think carefully about whether you really need this option before you use it.
-	// If you do, remember if you store secure data in cookies, you wil need to verify the
-	// Origin header.
+	// If you do, remember that if you store secure data in cookies, you wil need to verify the
+	// Origin header yourself otherwise you are exposing yourself to a CSRF attack.
 	InsecureSkipVerify bool
 }
 
diff --git a/ci/bench/entrypoint.sh b/ci/bench/entrypoint.sh
index 0e32cd4..5f7dcf7 100755
--- a/ci/bench/entrypoint.sh
+++ b/ci/bench/entrypoint.sh
@@ -9,7 +9,7 @@ go test --vet=off --run=^$ -bench=. \
 	-memprofile=profs/mem \
 	-blockprofile=profs/block \
 	-mutexprofile=profs/mutex \
-	./...
+	.
 
 set +x
 echo
diff --git a/dial.go b/dial.go
index 909990c..e2eacb2 100644
--- a/dial.go
+++ b/dial.go
@@ -23,6 +23,7 @@ type DialOptions struct {
 	HTTPClient *http.Client
 
 	// Header specifies the HTTP headers included in the handshake request.
+	// TODO rename to HTTPHeader
 	Header http.Header
 
 	// Subprotocols lists the subprotocols to negotiate with the server.
diff --git a/doc.go b/doc.go
index 246170a..8a4d040 100644
--- a/doc.go
+++ b/doc.go
@@ -2,7 +2,7 @@
 //
 // See https://tools.ietf.org/html/rfc6455
 //
-// Please see https://nhooyr.io/websocket for thorough overview docs and a
+// Please see https://nhooyr.io/websocket for overview docs and a
 // comparison with existing implementations.
 //
 // Conn, Dial, and Accept are the main entrypoints into this package. Use Dial to dial
diff --git a/example_echo_test.go b/example_echo_test.go
index d2867a4..f424eef 100644
--- a/example_echo_test.go
+++ b/example_echo_test.go
@@ -2,10 +2,8 @@ package websocket_test
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"log"
 	"net"
 	"net/http"
@@ -15,6 +13,7 @@ import (
 	"golang.org/x/xerrors"
 
 	"nhooyr.io/websocket"
+	"nhooyr.io/websocket/wsjson"
 )
 
 // Example_echo starts a WebSocket echo server and
@@ -58,11 +57,11 @@ func Example_echo() {
 	}
 
 	// Output:
-	// {"i":0}
-	// {"i":1}
-	// {"i":2}
-	// {"i":3}
-	// {"i":4}
+	// 0
+	// 1
+	// 2
+	// 3
+	// 4
 }
 
 // echoServer is the WebSocket echo server implementation.
@@ -142,39 +141,20 @@ func client(url string) error {
 	defer c.Close(websocket.StatusInternalError, "")
 
 	for i := 0; i < 5; i++ {
-		w, err := c.Writer(ctx, websocket.MessageText)
-		if err != nil {
-			return err
-		}
-
-		e := json.NewEncoder(w)
-		err = e.Encode(map[string]int{
+		err = wsjson.Write(ctx, c, map[string]int{
 			"i": i,
 		})
 		if err != nil {
 			return err
 		}
 
-		err = w.Close()
-		if err != nil {
-			return err
-		}
-
-		typ, r, err := c.Reader(ctx)
-		if err != nil {
-			return err
-		}
-
-		if typ != websocket.MessageText {
-			return xerrors.Errorf("expected text message but got %v", typ)
-		}
-
-		msg2, err := ioutil.ReadAll(r)
+		v := map[string]int{}
+		err = wsjson.Read(ctx, c, &v)
 		if err != nil {
 			return err
 		}
 
-		fmt.Printf("%s", msg2)
+		fmt.Printf("%v\n", v["i"])
 	}
 
 	c.Close(websocket.StatusNormalClosure, "")
diff --git a/go.mod b/go.mod
index 928137e..f39fe6f 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module nhooyr.io/websocket
 go 1.12
 
 require (
+	github.com/golang/protobuf v1.3.1
 	github.com/google/go-cmp v0.2.0
 	github.com/kr/pretty v0.1.0 // indirect
 	go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16
diff --git a/go.sum b/go.sum
index 0e10a2c..3d455a2 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
+github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
 github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
diff --git a/websocket.go b/websocket.go
index 21c4ef3..287bf3e 100644
--- a/websocket.go
+++ b/websocket.go
@@ -489,9 +489,12 @@ func (w messageWriter) Close() error {
 // Reader will wait until there is a WebSocket data message to read from the connection.
 // It returns the type of the message and a reader to read it.
 // The passed context will also bound the reader.
+//
 // Your application must keep reading messages for the Conn to automatically respond to ping
 // and close frames and not become stuck waiting for a data message to be read.
-// Please ensure to read the full message from io.Reader.
+// Please ensure to read the full message from io.Reader. If you do not read till
+// io.EOF, the connection will break unless the next read would have yielded io.EOF.
+//
 // You can only read a single message at a time so do not call this method
 // concurrently.
 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
diff --git a/websocket_test.go b/websocket_test.go
index be03cb0..f0d58ac 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -12,16 +12,21 @@ import (
 	"net/url"
 	"os"
 	"os/exec"
+	"reflect"
 	"strconv"
 	"strings"
 	"sync/atomic"
 	"testing"
 	"time"
 
+	"github.com/golang/protobuf/ptypes"
+	"github.com/golang/protobuf/ptypes/duration"
 	"github.com/google/go-cmp/cmp"
 	"golang.org/x/xerrors"
 
 	"nhooyr.io/websocket"
+	"nhooyr.io/websocket/wsjson"
+	"nhooyr.io/websocket/wspb"
 )
 
 func TestHandshake(t *testing.T) {
@@ -201,84 +206,139 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 		},
-		// {
-		// 	name: "echo",
-		// 	server: func(w http.ResponseWriter, r *http.Request) error {
-		// 		c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-		// 		if err != nil {
-		// 			return err
-		// 		}
-		// 		defer c.Close(websocket.StatusInternalError, "")
-		//
-		// 		ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
-		// 		defer cancel()
-		//
-		// 		write := func() error {
-		// 			jc := websocket.JSONConn{
-		// 				C: c,
-		// 			}
-		//
-		// 			v := map[string]interface{}{
-		// 				"anmol": "wowow",
-		// 			}
-		// 			err = jc.Write(ctx, v)
-		// 			if err != nil {
-		// 				return err
-		// 			}
-		// 			return nil
-		// 		}
-		// 		err = write()
-		// 		if err != nil {
-		// 			return err
-		// 		}
-		// 		err = write()
-		// 		if err != nil {
-		// 			return err
-		// 		}
-		//
-		// 		c.Close(websocket.StatusNormalClosure, "")
-		// 		return nil
-		// 	},
-		// 	client: func(ctx context.Context, u string) error {
-		// 		c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
-		// 		if err != nil {
-		// 			return err
-		// 		}
-		// 		defer c.Close(websocket.StatusInternalError, "")
-		//
-		// 		jc := websocket.JSONConn{
-		// 			C: c,
-		// 		}
-		//
-		// 		read := func() error {
-		// 			var v interface{}
-		// 			err = jc.Read(ctx, &v)
-		// 			if err != nil {
-		// 				return err
-		// 			}
-		//
-		// 			exp := map[string]interface{}{
-		// 				"anmol": "wowow",
-		// 			}
-		// 			if !reflect.DeepEqual(exp, v) {
-		// 				return xerrors.Errorf("expected %v but got %v", exp, v)
-		// 			}
-		// 			return nil
-		// 		}
-		// 		err = read()
-		// 		if err != nil {
-		// 			return err
-		// 		}
-		// 		// Read twice to ensure the un EOFed previous reader works correctly.
-		// 		err = read()
-		// 		if err != nil {
-		// 			return err
-		// 		}
-		//
-		// 		c.Close(websocket.StatusNormalClosure, "")
-		// 		return nil
-		// 	},
-		// },
+		{
+			name: "jsonEcho",
+			server: func(w http.ResponseWriter, r *http.Request) error {
+				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				if err != nil {
+					return err
+				}
+				defer c.Close(websocket.StatusInternalError, "")
+
+				ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
+				defer cancel()
+
+				write := func() error {
+					v := map[string]interface{}{
+						"anmol": "wowow",
+					}
+					err := wsjson.Write(ctx, c, v)
+					return err
+				}
+				err = write()
+				if err != nil {
+					return err
+				}
+				err = write()
+				if err != nil {
+					return err
+				}
+
+				c.Close(websocket.StatusNormalClosure, "")
+				return nil
+			},
+			client: func(ctx context.Context, u string) error {
+				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
+				if err != nil {
+					return err
+				}
+				defer c.Close(websocket.StatusInternalError, "")
+
+				read := func() error {
+					var v interface{}
+					err := wsjson.Read(ctx, c, &v)
+					if err != nil {
+						return err
+					}
+
+					exp := map[string]interface{}{
+						"anmol": "wowow",
+					}
+					if !reflect.DeepEqual(exp, v) {
+						return xerrors.Errorf("expected %v but got %v", exp, v)
+					}
+					return nil
+				}
+				err = read()
+				if err != nil {
+					return err
+				}
+				// Read twice to ensure the un EOFed previous reader works correctly.
+				err = read()
+				if err != nil {
+					return err
+				}
+
+				c.Close(websocket.StatusNormalClosure, "")
+				return nil
+			},
+		},
+		{
+			name: "protobufEcho",
+			server: func(w http.ResponseWriter, r *http.Request) error {
+				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				if err != nil {
+					return err
+				}
+				defer c.Close(websocket.StatusInternalError, "")
+
+				ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
+				defer cancel()
+
+				write := func() error {
+					err := wspb.Write(ctx, c, ptypes.DurationProto(100))
+					return err
+				}
+				err = write()
+				if err != nil {
+					return err
+				}
+				err = write()
+				if err != nil {
+					return err
+				}
+
+				c.Close(websocket.StatusNormalClosure, "")
+				return nil
+			},
+			client: func(ctx context.Context, u string) error {
+				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
+				if err != nil {
+					return err
+				}
+				defer c.Close(websocket.StatusInternalError, "")
+
+				read := func() error {
+					var v duration.Duration
+					err := wspb.Read(ctx, c, &v)
+					if err != nil {
+						return err
+					}
+
+					d, err := ptypes.Duration(&v)
+					if err != nil {
+						return xerrors.Errorf("failed to convert duration.Duration to time.Duration: %w", err)
+					}
+					const exp = time.Duration(100)
+					if !reflect.DeepEqual(exp, d) {
+						return xerrors.Errorf("expected %v but got %v", exp, d)
+					}
+					return nil
+				}
+				err = read()
+				if err != nil {
+					return err
+				}
+				// Read twice to ensure the un EOFed previous reader works correctly.
+				err = read()
+				if err != nil {
+					return err
+				}
+
+				c.Close(websocket.StatusNormalClosure, "")
+				return nil
+			},
+		},
 		{
 			name: "cookies",
 			server: func(w http.ResponseWriter, r *http.Request) error {
diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go
new file mode 100644
index 0000000..df67cf9
--- /dev/null
+++ b/wsjson/wsjson.go
@@ -0,0 +1,71 @@
+// Package wsjson provides helpers for JSON messages.
+package wsjson
+
+import (
+	"context"
+	"encoding/json"
+	"io"
+
+	"golang.org/x/xerrors"
+
+	"nhooyr.io/websocket"
+)
+
+// Read reads a json message from c into v.
+// It will read a message up to 32768 bytes in length.
+func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
+	err := read(ctx, c, v)
+	if err != nil {
+		return xerrors.Errorf("failed to read json: %w", err)
+	}
+	return nil
+}
+
+func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
+	typ, r, err := c.Reader(ctx)
+	if err != nil {
+		return err
+	}
+
+	if typ != websocket.MessageText {
+		return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ)
+	}
+
+	r = io.LimitReader(r, 32768)
+
+	d := json.NewDecoder(r)
+	err = d.Decode(v)
+	if err != nil {
+		return xerrors.Errorf("failed to decode json: %w", err)
+	}
+
+	return nil
+}
+
+// Write writes the json message v to c.
+func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
+	err := write(ctx, c, v)
+	if err != nil {
+		return xerrors.Errorf("failed to write json: %w", err)
+	}
+	return nil
+}
+
+func write(ctx context.Context, c *websocket.Conn, v interface{}) error {
+	w, err := c.Writer(ctx, websocket.MessageText)
+	if err != nil {
+		return err
+	}
+
+	e := json.NewEncoder(w)
+	err = e.Encode(v)
+	if err != nil {
+		return xerrors.Errorf("failed to encode json: %w", err)
+	}
+
+	err = w.Close()
+	if err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/wspb/wspb.go b/wspb/wspb.go
new file mode 100644
index 0000000..159e92d
--- /dev/null
+++ b/wspb/wspb.go
@@ -0,0 +1,80 @@
+// Package wspb provides helpers for protobuf messages.
+package wspb
+
+import (
+	"context"
+	"io"
+	"io/ioutil"
+
+	"github.com/golang/protobuf/proto"
+	"golang.org/x/xerrors"
+
+	"nhooyr.io/websocket"
+)
+
+// Read reads a protobuf message from c into v.
+// It will read a message up to 32768 bytes in length.
+func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
+	err := read(ctx, c, v)
+	if err != nil {
+		return xerrors.Errorf("failed to read protobuf: %w", err)
+	}
+	return nil
+}
+
+func read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
+	typ, r, err := c.Reader(ctx)
+	if err != nil {
+		return err
+	}
+
+	if typ != websocket.MessageBinary {
+		return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ)
+	}
+
+	r = io.LimitReader(r, 32768)
+
+	b, err := ioutil.ReadAll(r)
+	if err != nil {
+		return xerrors.Errorf("failed to read message: %w", err)
+	}
+
+	err = proto.Unmarshal(b, v)
+	if err != nil {
+		return xerrors.Errorf("failed to unmarshal protobuf: %w", err)
+	}
+
+	return nil
+}
+
+// Write writes the protobuf message v to c.
+func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
+	err := write(ctx, c, v)
+	if err != nil {
+		return xerrors.Errorf("failed to write protobuf: %w", err)
+	}
+	return nil
+}
+
+func write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
+	b, err := proto.Marshal(v)
+	if err != nil {
+		return xerrors.Errorf("failed to marshal protobuf: %w", err)
+	}
+
+	w, err := c.Writer(ctx, websocket.MessageBinary)
+	if err != nil {
+		return err
+	}
+
+	_, err = w.Write(b)
+	if err != nil {
+		return err
+	}
+
+	err = w.Close()
+	if err != nil {
+		return err
+	}
+	return nil
+}
-- 
GitLab