From 78da35ec5b221d5ec664ee9cbf0a8fb034d46f4c Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Fri, 7 Feb 2020 00:58:57 -0600
Subject: [PATCH] Get test with multiple messages working

---
 README.md        |  2 +-
 assert_test.go   | 79 +++++++++++++++++++++++++++++++-----------
 autobahn_test.go | 16 +++++----
 conn_test.go     | 89 +++++++++++++++++-------------------------------
 example_test.go  |  2 --
 read.go          | 11 +++---
 write.go         |  1 +
 ws_js_test.go    |  2 +-
 8 files changed, 109 insertions(+), 93 deletions(-)

diff --git a/README.md b/README.md
index e958d2a..2569383 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ go get nhooyr.io/websocket
 - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper
 - [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API
 - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
-- Can target [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm)
+- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm)
 
 ## Roadmap
 
diff --git a/assert_test.go b/assert_test.go
index 6cfd926..3727d99 100644
--- a/assert_test.go
+++ b/assert_test.go
@@ -3,10 +3,15 @@ package websocket_test
 import (
 	"context"
 	"crypto/rand"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
 	"strings"
 	"testing"
+	"time"
 
 	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
 	"cdr.dev/slog/sloggers/slogtest/assert"
 
 	"nhooyr.io/websocket"
@@ -20,26 +25,31 @@ func randBytes(t *testing.T, n int) []byte {
 	return b
 }
 
-func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) {
-	t.Helper()
-	defer c.Close(websocket.StatusInternalError, "")
+func echoJSON(t *testing.T, c *websocket.Conn, n int) {
+	slog.Helper()
 
-	exp := randString(t, n)
-	err := wsjson.Write(ctx, c, exp)
-	assert.Success(t, "wsjson.Write", err)
+	s := randString(t, n)
+	writeJSON(t, c, s)
+	readJSON(t, c, s)
+}
 
-	assertJSONRead(t, ctx, c, exp)
+func writeJSON(t *testing.T, c *websocket.Conn, v interface{}) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
 
-	c.Close(websocket.StatusNormalClosure, "")
+	err := wsjson.Write(ctx, c, v)
+	assert.Success(t, "wsjson.Write", err)
 }
 
-func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) {
+func readJSON(t *testing.T, c *websocket.Conn, exp interface{}) {
 	slog.Helper()
 
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
+
 	var act interface{}
 	err := wsjson.Read(ctx, c, &act)
 	assert.Success(t, "wsjson.Read", err)
-
 	assert.Equal(t, "json", exp, act)
 }
 
@@ -58,7 +68,7 @@ func randString(t *testing.T, n int) string {
 }
 
 func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) {
-	t.Helper()
+	slog.Helper()
 
 	p := randBytes(t, n)
 	err := c.Write(ctx, typ, p)
@@ -72,17 +82,46 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc
 }
 
 func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) {
-	t.Helper()
+	slog.Helper()
 
 	assert.Equal(t, "subprotocol", exp, c.Subprotocol())
 }
 
-func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) {
-	t.Helper()
-	defer func() {
-		if t.Failed() {
-			t.Logf("error: %+v", err)
-		}
-	}()
-	assert.Equal(t, "closeStatus", exp, websocket.CloseStatus(err))
+func assertCloseStatus(t testing.TB, exp websocket.StatusCode, err error) {
+	slog.Helper()
+
+	if websocket.CloseStatus(err) == -1 {
+		slogtest.Fatal(t, "expected websocket.CloseError", slogType(err), slog.Error(err))
+	}
+	if websocket.CloseStatus(err) != exp {
+		slogtest.Error(t, "unexpected close status",
+			slog.F("exp", exp),
+			slog.F("act", err),
+		)
+	}
+
+}
+
+func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts *websocket.AcceptOptions) *websocket.Conn {
+	c, err := websocket.Accept(w, r, opts)
+	assert.Success(t, "websocket.Accept", err)
+	return c
+}
+
+func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
+
+	if opts == nil {
+		opts = &websocket.DialOptions{}
+	}
+	opts.HTTPClient = s.Client()
+
+	c, resp, err := websocket.Dial(ctx, wsURL(s), opts)
+	assert.Success(t, "websocket.Dial", err)
+	return c, resp
+}
+
+func slogType(v interface{}) slog.Field {
+	return slog.F("type", fmt.Sprintf("%T", v))
 }
diff --git a/autobahn_test.go b/autobahn_test.go
index bcbf867..dd9887f 100644
--- a/autobahn_test.go
+++ b/autobahn_test.go
@@ -9,6 +9,7 @@ import (
 	"io/ioutil"
 	"net"
 	"net/http"
+	"net/http/httptest"
 	"os"
 	"os/exec"
 	"strconv"
@@ -53,15 +54,18 @@ func TestAutobahn(t *testing.T) {
 func testServerAutobahn(t *testing.T) {
 	t.Parallel()
 
-	s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
-		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{
 			Subprotocols: []string{"echo"},
 		})
-		assert.Success(t, "accept", err)
-		err = echoLoop(r.Context(), c)
+		err := echoLoop(r.Context(), c)
 		assertCloseStatus(t, websocket.StatusNormalClosure, err)
-	}, false)
-	defer closeFn()
+	}))
+	closeFn := wsgrace(s.Config)
+	defer func() {
+		err := closeFn()
+		assert.Success(t, "closeFn", err)
+	}()
 
 	specFile, err := tempJSONFile(map[string]interface{}{
 		"outdir": "ci/out/wstestServerReports",
diff --git a/conn_test.go b/conn_test.go
index 4720cba..6f6b8d5 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -4,7 +4,9 @@ package websocket_test
 
 import (
 	"context"
+	"crypto/rand"
 	"io"
+	"math/big"
 	"net/http"
 	"net/http/httptest"
 	"strings"
@@ -18,77 +20,32 @@ import (
 	"nhooyr.io/websocket"
 )
 
-func TestFuzz(t *testing.T) {
-	t.Parallel()
-
-	s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
-		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-			CompressionOptions: websocket.CompressionOptions{
-				Mode: websocket.CompressionContextTakeover,
-			},
-		})
-		assert.Success(t, "accept", err)
-		defer c.Close(websocket.StatusInternalError, "")
-
-		err = echoLoop(r.Context(), c)
-		assertCloseStatus(t, websocket.StatusNormalClosure, err)
-	}, false)
-	defer closeFn()
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-	defer cancel()
-
-	opts := &websocket.DialOptions{
-		CompressionOptions: websocket.CompressionOptions{
-			Mode: websocket.CompressionContextTakeover,
-		},
-	}
-	opts.HTTPClient = s.Client()
-
-	c, _, err := websocket.Dial(ctx, wsURL(s), opts)
-	assert.Success(t, "dial", err)
-	assertJSONEcho(t, ctx, c, 8393)
-}
-
 func TestConn(t *testing.T) {
 	t.Parallel()
 
 	t.Run("json", func(t *testing.T) {
-		s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
-			c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-				Subprotocols: []string{"echo"},
-				CompressionOptions: websocket.CompressionOptions{
-					Mode: websocket.CompressionContextTakeover,
-				},
-			})
-			assert.Success(t, "accept", err)
-			defer c.Close(websocket.StatusInternalError, "")
-
-			err = echoLoop(r.Context(), c)
-			assertCloseStatus(t, websocket.StatusNormalClosure, err)
-		}, false)
+		t.Parallel()
+
+		s, closeFn := testEchoLoop(t)
 		defer closeFn()
 
-		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-		defer cancel()
+		c, _ := dialWebSocket(t, s, nil)
+		defer c.Close(websocket.StatusInternalError, "")
 
-		opts := &websocket.DialOptions{
-			Subprotocols: []string{"echo"},
-			CompressionOptions: websocket.CompressionOptions{
-				Mode: websocket.CompressionContextTakeover,
-			},
+		c.SetReadLimit(1 << 30)
+
+		for i := 0; i < 10; i++ {
+			n := randInt(t, 1_048_576)
+			echoJSON(t, c, n)
 		}
-		opts.HTTPClient = s.Client()
 
-		c, _, err := websocket.Dial(ctx, wsURL(s), opts)
-		assert.Success(t, "dial", err)
-		assertJSONEcho(t, ctx, c, 8393)
+		c.Close(websocket.StatusNormalClosure, "")
 	})
 }
 
-func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) {
+func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) {
 	h := http.HandlerFunc(fn)
-	if tls {
+	if randInt(tb, 2) == 1 {
 		s = httptest.NewTLSServer(h)
 	} else {
 		s = httptest.NewServer(h)
@@ -179,3 +136,19 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error {
 func wsURL(s *httptest.Server) string {
 	return strings.Replace(s.URL, "http", "ws", 1)
 }
+
+func testEchoLoop(t testing.TB) (*httptest.Server, func()) {
+	return testServer(t, func(w http.ResponseWriter, r *http.Request) {
+		c := acceptWebSocket(t, r, w, nil)
+		defer c.Close(websocket.StatusInternalError, "")
+
+		err := echoLoop(r.Context(), c)
+		assertCloseStatus(t, websocket.StatusNormalClosure, err)
+	})
+}
+
+func randInt(t testing.TB, max int) int {
+	x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
+	assert.Success(t, "rand.Int", err)
+	return int(x.Int64())
+}
diff --git a/example_test.go b/example_test.go
index bc603af..1842b76 100644
--- a/example_test.go
+++ b/example_test.go
@@ -33,8 +33,6 @@ func ExampleAccept() {
 			return
 		}
 
-		log.Printf("received: %v", v)
-
 		c.Close(websocket.StatusNormalClosure, "")
 	})
 
diff --git a/read.go b/read.go
index 73ec0b3..7e74894 100644
--- a/read.go
+++ b/read.go
@@ -95,6 +95,7 @@ func (mr *msgReader) ensureFlate() {
 		mr.flateReader = getFlateReader(readerFunc(mr.read), nil)
 	}
 	mr.limitReader.r = mr.flateReader
+	mr.flateTail.Reset(deflateMessageTail)
 }
 
 func (mr *msgReader) returnFlateReader() {
@@ -328,12 +329,12 @@ type msgReader struct {
 func (mr *msgReader) reset(ctx context.Context, h header) {
 	mr.ctx = ctx
 	mr.flate = h.rsv1
+	mr.limitReader.reset(readerFunc(mr.read))
+
 	if mr.flate {
 		mr.ensureFlate()
-		mr.flateTail.Reset(deflateMessageTail)
 	}
 
-	mr.limitReader.reset()
 	mr.setFrame(h)
 }
 
@@ -423,13 +424,13 @@ func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
 		c: c,
 	}
 	lr.limit.Store(limit)
-	lr.r = r
-	lr.reset()
+	lr.reset(r)
 	return lr
 }
 
-func (lr *limitReader) reset() {
+func (lr *limitReader) reset(r io.Reader) {
 	lr.n = lr.limit.Load()
+	lr.r = r
 }
 
 func (lr *limitReader) Read(p []byte) (int, error) {
diff --git a/write.go b/write.go
index 4a756fa..3454348 100644
--- a/write.go
+++ b/write.go
@@ -76,6 +76,7 @@ func (mw *msgWriter) ensureFlate() {
 				w: writerFunc(mw.write),
 			}
 		}
+		mw.trimWriter.reset()
 
 		mw.flateWriter = getFlateWriter(mw.trimWriter)
 		mw.flate = true
diff --git a/ws_js_test.go b/ws_js_test.go
index 6e87480..9f725a5 100644
--- a/ws_js_test.go
+++ b/ws_js_test.go
@@ -24,7 +24,7 @@ func TestEcho(t *testing.T) {
 
 	assertSubprotocol(t, c, "echo")
 	assert.Equalf(t, &http.Response{}, resp, "http.Response")
-	assertJSONEcho(t, ctx, c, 1024)
+	echoJSON(t, ctx, c, 1024)
 	assertEcho(t, ctx, c, websocket.MessageBinary, 1024)
 
 	err = c.Close(websocket.StatusNormalClosure, "")
-- 
GitLab