From faadcc9613d9e663ef39dd9d71196e033f3f2901 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sat, 8 Feb 2020 23:14:03 -0500
Subject: [PATCH] Simplify tests

---
 assert_test.go               | 112 ------------------------
 compress_test.go             |  25 +-----
 conn_test.go                 | 165 +++++++++++++++--------------------
 dial.go                      |   1 +
 go.mod                       |   2 +-
 internal/test/cmp/cmp.go     |  22 +++++
 internal/test/doc.go         |   2 +
 internal/test/wstest/pipe.go |  82 +++++++++++++++++
 internal/test/xrand/xrand.go |  47 ++++++++++
 ws_js_test.go                |  12 ++-
 10 files changed, 234 insertions(+), 236 deletions(-)
 delete mode 100644 assert_test.go
 create mode 100644 internal/test/cmp/cmp.go
 create mode 100644 internal/test/doc.go
 create mode 100644 internal/test/wstest/pipe.go
 create mode 100644 internal/test/xrand/xrand.go

diff --git a/assert_test.go b/assert_test.go
deleted file mode 100644
index a51b2c3..0000000
--- a/assert_test.go
+++ /dev/null
@@ -1,112 +0,0 @@
-package websocket_test
-
-import (
-	"context"
-	"crypto/rand"
-	"fmt"
-	"net/http"
-	"strings"
-	"testing"
-	"time"
-
-	"cdr.dev/slog"
-	"cdr.dev/slog/sloggers/slogtest"
-	"cdr.dev/slog/sloggers/slogtest/assert"
-
-	"nhooyr.io/websocket"
-	"nhooyr.io/websocket/wsjson"
-)
-
-func randBytes(t *testing.T, n int) []byte {
-	b := make([]byte, n)
-	_, err := rand.Reader.Read(b)
-	assert.Success(t, "readRandBytes", err)
-	return b
-}
-
-func echoJSON(t *testing.T, c *websocket.Conn, n int) {
-	slog.Helper()
-
-	s := randString(t, n)
-	go writeJSON(t, c, s)
-	readJSON(t, c, s)
-}
-
-func writeJSON(t *testing.T, c *websocket.Conn, v interface{}) {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-	defer cancel()
-
-	err := wsjson.Write(ctx, c, v)
-	assert.Success(t, "wsjson.Write", err)
-}
-
-func readJSON(t *testing.T, c *websocket.Conn, exp interface{}) {
-	slog.Helper()
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-	defer cancel()
-
-	var act interface{}
-	err := wsjson.Read(ctx, c, &act)
-	assert.Success(t, "wsjson.Read", err)
-	assert.Equal(t, "json", exp, act)
-}
-
-func randString(t *testing.T, n int) string {
-	s := strings.ToValidUTF8(string(randBytes(t, n)), "_")
-	s = strings.ReplaceAll(s, "\x00", "_")
-	if len(s) > n {
-		return s[:n]
-	}
-	if len(s) < n {
-		// Pad with =
-		extra := n - len(s)
-		return s + strings.Repeat("=", extra)
-	}
-	return s
-}
-
-func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) {
-	slog.Helper()
-
-	p := randBytes(t, n)
-	err := c.Write(ctx, typ, p)
-	assert.Success(t, "write", err)
-
-	typ2, p2, err := c.Read(ctx)
-	assert.Success(t, "read", err)
-
-	assert.Equal(t, "dataType", typ, typ2)
-	assert.Equal(t, "payload", p, p2)
-}
-
-func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) {
-	slog.Helper()
-
-	assert.Equal(t, "subprotocol", exp, c.Subprotocol())
-}
-
-func assertCloseStatus(t testing.TB, exp websocket.StatusCode, err error) {
-	slog.Helper()
-
-	if websocket.CloseStatus(err) == -1 {
-		slogtest.Fatal(t, "expected websocket.CloseError", slogType(err), slog.Error(err))
-	}
-	if websocket.CloseStatus(err) != exp {
-		slogtest.Error(t, "unexpected close status",
-			slog.F("exp", exp),
-			slog.F("act", err),
-		)
-	}
-
-}
-
-func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts *websocket.AcceptOptions) *websocket.Conn {
-	c, err := websocket.Accept(w, r, opts)
-	assert.Success(t, "websocket.Accept", err)
-	return c
-}
-
-func slogType(v interface{}) slog.Field {
-	return slog.F("type", fmt.Sprintf("%T", v))
-}
diff --git a/compress_test.go b/compress_test.go
index 6edfcb1..15d334d 100644
--- a/compress_test.go
+++ b/compress_test.go
@@ -1,13 +1,12 @@
 package websocket
 
 import (
-	"crypto/rand"
-	"encoding/base64"
-	"math/big"
 	"strings"
 	"testing"
 
 	"cdr.dev/slog/sloggers/slogtest/assert"
+
+	"nhooyr.io/websocket/internal/test/xrand"
 )
 
 func Test_slidingWindow(t *testing.T) {
@@ -16,8 +15,8 @@ func Test_slidingWindow(t *testing.T) {
 	const testCount = 99
 	const maxWindow = 99999
 	for i := 0; i < testCount; i++ {
-		input := randStr(t, maxWindow)
-		windowLength := randInt(t, maxWindow)
+		input := xrand.String(maxWindow)
+		windowLength := xrand.Int(maxWindow)
 		r := newSlidingWindow(windowLength)
 		r.write([]byte(input))
 
@@ -27,19 +26,3 @@ func Test_slidingWindow(t *testing.T) {
 		assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf)))
 	}
 }
-
-func randStr(t *testing.T, max int) string {
-	n := randInt(t, max)
-
-	b := make([]byte, n)
-	_, err := rand.Read(b)
-	assert.Success(t, "rand.Read", err)
-
-	return base64.StdEncoding.EncodeToString(b)
-}
-
-func randInt(t *testing.T, max int) int {
-	x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
-	assert.Success(t, "rand.Int", err)
-	return int(x.Int64())
-}
diff --git a/conn_test.go b/conn_test.go
index f1361ad..d246f71 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -3,59 +3,96 @@
 package websocket_test
 
 import (
-	"bufio"
 	"context"
-	"crypto/rand"
 	"io"
-	"math/big"
-	"net"
-	"net/http"
-	"net/http/httptest"
 	"testing"
 	"time"
 
-	"cdr.dev/slog/sloggers/slogtest/assert"
+	"golang.org/x/xerrors"
 
 	"nhooyr.io/websocket"
+	"nhooyr.io/websocket/internal/test/cmp"
+	"nhooyr.io/websocket/internal/test/wstest"
+	"nhooyr.io/websocket/internal/test/xrand"
+	"nhooyr.io/websocket/wsjson"
 )
 
-func goFn(fn func()) func() {
-	done := make(chan struct{})
+func goFn(fn func() error) chan error {
+	errs := make(chan error)
 	go func() {
-		defer close(done)
-		fn()
+		defer close(errs)
+		errs <- fn()
 	}()
 
-	return func() {
-		<-done
-	}
+	return errs
 }
 
 func TestConn(t *testing.T) {
 	t.Parallel()
 
-	t.Run("json", func(t *testing.T) {
+	t.Run("data", func(t *testing.T) {
 		t.Parallel()
 
-		for i := 0; i < 1; i++ {
+		for i := 0; i < 10; i++ {
 			t.Run("", func(t *testing.T) {
-				ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+				t.Parallel()
+
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 				defer cancel()
 
-				c1, c2 := websocketPipe(t)
+				copts := websocket.CompressionOptions{
+					Mode:      websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled))),
+					Threshold: xrand.Int(9999),
+				}
+
+				c1, c2, err := wstest.Pipe(&websocket.DialOptions{
+					CompressionOptions: copts,
+				}, &websocket.AcceptOptions{
+					CompressionOptions: copts,
+				})
+				if err != nil {
+					t.Fatal(err)
+				}
+				defer c1.Close(websocket.StatusInternalError, "")
+				defer c2.Close(websocket.StatusInternalError, "")
 
-				wait := goFn(func() {
+				echoLoopErr := goFn(func() error {
 					err := echoLoop(ctx, c1)
-					assertCloseStatus(t, websocket.StatusNormalClosure, err)
+					return assertCloseStatus(websocket.StatusNormalClosure, err)
 				})
-				defer wait()
+				defer func() {
+					err := <-echoLoopErr
+					if err != nil {
+						t.Errorf("echo loop error: %v", err)
+					}
+				}()
 				defer cancel()
 
 				c2.SetReadLimit(1 << 30)
 
 				for i := 0; i < 10; i++ {
-					n := randInt(t, 131_072)
-					echoJSON(t, c2, n)
+					n := xrand.Int(131_072)
+
+					msg := xrand.String(n)
+
+					writeErr := goFn(func() error {
+						return wsjson.Write(ctx, c2, msg)
+					})
+
+					var act interface{}
+					err := wsjson.Read(ctx, c2, &act)
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					err = <-writeErr
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					if !cmp.Equal(msg, act) {
+						t.Fatalf("unexpected msg read: %v", cmp.Diff(msg, act))
+					}
 				}
 
 				c2.Close(websocket.StatusNormalClosure, "")
@@ -64,6 +101,16 @@ func TestConn(t *testing.T) {
 	})
 }
 
+func assertCloseStatus(exp websocket.StatusCode, err error) error {
+	if websocket.CloseStatus(err) == -1 {
+		return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err)
+	}
+	if websocket.CloseStatus(err) != exp {
+		return xerrors.Errorf("unexpected close status (%v):%v", exp, err)
+	}
+	return nil
+}
+
 // echoLoop echos every msg received from c until an error
 // occurs or the context expires.
 // The read limit is set to 1 << 30.
@@ -98,75 +145,3 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error {
 		}
 	}
 }
-
-func randBool(t testing.TB) bool {
-	return randInt(t, 2) == 1
-}
-
-func randInt(t testing.TB, max int) int {
-	x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
-	assert.Success(t, "rand.Int", err)
-	return int(x.Int64())
-}
-
-type testHijacker struct {
-	*httptest.ResponseRecorder
-	serverConn net.Conn
-	hijacked   chan struct{}
-}
-
-var _ http.Hijacker = testHijacker{}
-
-func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	close(hj.hijacked)
-	return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil
-}
-
-func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) {
-	var serverConn *websocket.Conn
-	tt := testTransport{
-		h: func(w http.ResponseWriter, r *http.Request) {
-			serverConn = acceptWebSocket(t, r, w, nil)
-		},
-	}
-
-	dialOpts := &websocket.DialOptions{
-		HTTPClient: &http.Client{
-			Transport: tt,
-		},
-	}
-
-	clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts)
-	assert.Success(t, "websocket.Dial", err)
-
-	if randBool(t) {
-		return serverConn, clientConn
-	}
-	return clientConn, serverConn
-}
-
-type testTransport struct {
-	h http.HandlerFunc
-}
-
-func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) {
-	clientConn, serverConn := net.Pipe()
-
-	hj := testHijacker{
-		ResponseRecorder: httptest.NewRecorder(),
-		serverConn:       serverConn,
-		hijacked:         make(chan struct{}),
-	}
-
-	done := make(chan struct{})
-	t.h.ServeHTTP(hj, r)
-
-	select {
-	case <-hj.hijacked:
-		resp := hj.ResponseRecorder.Result()
-		resp.Body = clientConn
-		return resp, nil
-	case <-done:
-		return hj.ResponseRecorder.Result(), nil
-	}
-}
diff --git a/dial.go b/dial.go
index 4557602..a1509ab 100644
--- a/dial.go
+++ b/dial.go
@@ -35,6 +35,7 @@ type DialOptions struct {
 
 	// CompressionOptions controls the compression options.
 	// See docs on the CompressionOptions type.
+	// TODO make *
 	CompressionOptions CompressionOptions
 }
 
diff --git a/go.mod b/go.mod
index ee1708a..fc4ebb9 100644
--- a/go.mod
+++ b/go.mod
@@ -11,7 +11,7 @@ require (
 	github.com/gobwas/ws v1.0.2
 	github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
 	github.com/golang/protobuf v1.3.3
-	github.com/google/go-cmp v0.4.0 // indirect
+	github.com/google/go-cmp v0.4.0
 	github.com/gorilla/websocket v1.4.1
 	github.com/mattn/go-isatty v0.0.12 // indirect
 	go.opencensus.io v0.22.3 // indirect
diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go
new file mode 100644
index 0000000..d0eee6d
--- /dev/null
+++ b/internal/test/cmp/cmp.go
@@ -0,0 +1,22 @@
+package cmp
+
+import (
+	"reflect"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+)
+
+// Equal checks if v1 and v2 are equal with go-cmp.
+func Equal(v1, v2 interface{}) bool {
+	return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool {
+		return true
+	}))
+}
+
+// Diff returns a human readable diff between v1 and v2
+func Diff(v1, v2 interface{}) string {
+	return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool {
+		return true
+	}))
+}
diff --git a/internal/test/doc.go b/internal/test/doc.go
new file mode 100644
index 0000000..94b2e82
--- /dev/null
+++ b/internal/test/doc.go
@@ -0,0 +1,2 @@
+// Package test contains subpackages only used in tests.
+package test
diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go
new file mode 100644
index 0000000..f3d25f5
--- /dev/null
+++ b/internal/test/wstest/pipe.go
@@ -0,0 +1,82 @@
+package wstest
+
+import (
+	"bufio"
+	"context"
+	"net"
+	"net/http"
+	"net/http/httptest"
+
+	"golang.org/x/xerrors"
+
+	"nhooyr.io/websocket"
+	"nhooyr.io/websocket/internal/errd"
+	"nhooyr.io/websocket/internal/test/xrand"
+)
+
+// Pipe is used to create an in memory connection
+// between two websockets analogous to net.Pipe.
+func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (_ *websocket.Conn, _ *websocket.Conn, err error) {
+	defer errd.Wrap(&err, "failed to create ws pipe")
+
+	var serverConn *websocket.Conn
+	var acceptErr error
+	tt := fakeTransport{
+		h: func(w http.ResponseWriter, r *http.Request) {
+			serverConn, acceptErr = websocket.Accept(w, r, acceptOpts)
+		},
+	}
+
+	if dialOpts == nil {
+		dialOpts = &websocket.DialOptions{}
+	}
+	dialOpts.HTTPClient = &http.Client{
+		Transport: tt,
+	}
+
+	clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts)
+	if err != nil {
+		return nil, nil, xerrors.Errorf("failed to dial with fake transport: %w", err)
+	}
+
+	if serverConn == nil {
+		return nil, nil, xerrors.Errorf("failed to get server conn from fake transport: %w", acceptErr)
+	}
+
+	if xrand.True() {
+		return serverConn, clientConn, nil
+	}
+	return clientConn, serverConn, nil
+}
+
+type fakeTransport struct {
+	h http.HandlerFunc
+}
+
+func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+	clientConn, serverConn := net.Pipe()
+
+	hj := testHijacker{
+		ResponseRecorder: httptest.NewRecorder(),
+		serverConn:       serverConn,
+	}
+
+	t.h.ServeHTTP(hj, r)
+
+	resp := hj.ResponseRecorder.Result()
+	if resp.StatusCode == http.StatusSwitchingProtocols {
+		resp.Body = clientConn
+	}
+	return resp, nil
+}
+
+type testHijacker struct {
+	*httptest.ResponseRecorder
+	serverConn net.Conn
+}
+
+var _ http.Hijacker = testHijacker{}
+
+func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil
+}
diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go
new file mode 100644
index 0000000..2f3ad30
--- /dev/null
+++ b/internal/test/xrand/xrand.go
@@ -0,0 +1,47 @@
+package xrand
+
+import (
+	"crypto/rand"
+	"fmt"
+	"math/big"
+	"strings"
+)
+
+// Bytes generates random bytes with length n.
+func Bytes(n int) []byte {
+	b := make([]byte, n)
+	_, err := rand.Reader.Read(b)
+	if err != nil {
+		panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
+	}
+	return b
+}
+
+// String generates a random string with length n.
+func String(n int) string {
+	s := strings.ToValidUTF8(string(Bytes(n)), "_")
+	s = strings.ReplaceAll(s, "\x00", "_")
+	if len(s) > n {
+		return s[:n]
+	}
+	if len(s) < n {
+		// Pad with =
+		extra := n - len(s)
+		return s + strings.Repeat("=", extra)
+	}
+	return s
+}
+
+// True returns a randomly generated boolean.
+func True() bool {
+	return Int(2) == 1
+}
+
+// Int returns a randomly generated integer between [0, max).
+func Int(max int) int {
+	x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
+	if err != nil {
+		panic(fmt.Sprintf("failed to get random int: %v", err))
+	}
+	return int(x.Int64())
+}
diff --git a/ws_js_test.go b/ws_js_test.go
index 9f725a5..65309bf 100644
--- a/ws_js_test.go
+++ b/ws_js_test.go
@@ -1,4 +1,4 @@
-package websocket_test
+package websocket
 
 import (
 	"context"
@@ -6,8 +6,6 @@ import (
 	"os"
 	"testing"
 	"time"
-
-	"nhooyr.io/websocket"
 )
 
 func TestEcho(t *testing.T) {
@@ -16,17 +14,17 @@ func TestEcho(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 	defer cancel()
 
-	c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{
+	c, resp, err := Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &DialOptions{
 		Subprotocols: []string{"echo"},
 	})
 	assert.Success(t, err)
-	defer c.Close(websocket.StatusInternalError, "")
+	defer c.Close(StatusInternalError, "")
 
 	assertSubprotocol(t, c, "echo")
 	assert.Equalf(t, &http.Response{}, resp, "http.Response")
 	echoJSON(t, ctx, c, 1024)
-	assertEcho(t, ctx, c, websocket.MessageBinary, 1024)
+	assertEcho(t, ctx, c, MessageBinary, 1024)
 
-	err = c.Close(websocket.StatusNormalClosure, "")
+	err = c.Close(StatusNormalClosure, "")
 	assert.Success(t, err)
 }
-- 
GitLab