From faadcc9613d9e663ef39dd9d71196e033f3f2901 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sat, 8 Feb 2020 23:14:03 -0500 Subject: [PATCH] Simplify tests --- assert_test.go | 112 ------------------------ compress_test.go | 25 +----- conn_test.go | 165 +++++++++++++++-------------------- dial.go | 1 + go.mod | 2 +- internal/test/cmp/cmp.go | 22 +++++ internal/test/doc.go | 2 + internal/test/wstest/pipe.go | 82 +++++++++++++++++ internal/test/xrand/xrand.go | 47 ++++++++++ ws_js_test.go | 12 ++- 10 files changed, 234 insertions(+), 236 deletions(-) delete mode 100644 assert_test.go create mode 100644 internal/test/cmp/cmp.go create mode 100644 internal/test/doc.go create mode 100644 internal/test/wstest/pipe.go create mode 100644 internal/test/xrand/xrand.go diff --git a/assert_test.go b/assert_test.go deleted file mode 100644 index a51b2c3..0000000 --- a/assert_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package websocket_test - -import ( - "context" - "crypto/rand" - "fmt" - "net/http" - "strings" - "testing" - "time" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "cdr.dev/slog/sloggers/slogtest/assert" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" -) - -func randBytes(t *testing.T, n int) []byte { - b := make([]byte, n) - _, err := rand.Reader.Read(b) - assert.Success(t, "readRandBytes", err) - return b -} - -func echoJSON(t *testing.T, c *websocket.Conn, n int) { - slog.Helper() - - s := randString(t, n) - go writeJSON(t, c, s) - readJSON(t, c, s) -} - -func writeJSON(t *testing.T, c *websocket.Conn, v interface{}) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - err := wsjson.Write(ctx, c, v) - assert.Success(t, "wsjson.Write", err) -} - -func readJSON(t *testing.T, c *websocket.Conn, exp interface{}) { - slog.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - var act interface{} - err := wsjson.Read(ctx, c, &act) - assert.Success(t, "wsjson.Read", err) - assert.Equal(t, "json", exp, act) -} - -func randString(t *testing.T, n int) string { - s := strings.ToValidUTF8(string(randBytes(t, n)), "_") - s = strings.ReplaceAll(s, "\x00", "_") - if len(s) > n { - return s[:n] - } - if len(s) < n { - // Pad with = - extra := n - len(s) - return s + strings.Repeat("=", extra) - } - return s -} - -func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { - slog.Helper() - - p := randBytes(t, n) - err := c.Write(ctx, typ, p) - assert.Success(t, "write", err) - - typ2, p2, err := c.Read(ctx) - assert.Success(t, "read", err) - - assert.Equal(t, "dataType", typ, typ2) - assert.Equal(t, "payload", p, p2) -} - -func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { - slog.Helper() - - assert.Equal(t, "subprotocol", exp, c.Subprotocol()) -} - -func assertCloseStatus(t testing.TB, exp websocket.StatusCode, err error) { - slog.Helper() - - if websocket.CloseStatus(err) == -1 { - slogtest.Fatal(t, "expected websocket.CloseError", slogType(err), slog.Error(err)) - } - if websocket.CloseStatus(err) != exp { - slogtest.Error(t, "unexpected close status", - slog.F("exp", exp), - slog.F("act", err), - ) - } - -} - -func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts *websocket.AcceptOptions) *websocket.Conn { - c, err := websocket.Accept(w, r, opts) - assert.Success(t, "websocket.Accept", err) - return c -} - -func slogType(v interface{}) slog.Field { - return slog.F("type", fmt.Sprintf("%T", v)) -} diff --git a/compress_test.go b/compress_test.go index 6edfcb1..15d334d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,13 +1,12 @@ package websocket import ( - "crypto/rand" - "encoding/base64" - "math/big" "strings" "testing" "cdr.dev/slog/sloggers/slogtest/assert" + + "nhooyr.io/websocket/internal/test/xrand" ) func Test_slidingWindow(t *testing.T) { @@ -16,8 +15,8 @@ func Test_slidingWindow(t *testing.T) { const testCount = 99 const maxWindow = 99999 for i := 0; i < testCount; i++ { - input := randStr(t, maxWindow) - windowLength := randInt(t, maxWindow) + input := xrand.String(maxWindow) + windowLength := xrand.Int(maxWindow) r := newSlidingWindow(windowLength) r.write([]byte(input)) @@ -27,19 +26,3 @@ func Test_slidingWindow(t *testing.T) { assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf))) } } - -func randStr(t *testing.T, max int) string { - n := randInt(t, max) - - b := make([]byte, n) - _, err := rand.Read(b) - assert.Success(t, "rand.Read", err) - - return base64.StdEncoding.EncodeToString(b) -} - -func randInt(t *testing.T, max int) int { - x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) - assert.Success(t, "rand.Int", err) - return int(x.Int64()) -} diff --git a/conn_test.go b/conn_test.go index f1361ad..d246f71 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,59 +3,96 @@ package websocket_test import ( - "bufio" "context" - "crypto/rand" "io" - "math/big" - "net" - "net/http" - "net/http/httptest" "testing" "time" - "cdr.dev/slog/sloggers/slogtest/assert" + "golang.org/x/xerrors" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/wstest" + "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/wsjson" ) -func goFn(fn func()) func() { - done := make(chan struct{}) +func goFn(fn func() error) chan error { + errs := make(chan error) go func() { - defer close(done) - fn() + defer close(errs) + errs <- fn() }() - return func() { - <-done - } + return errs } func TestConn(t *testing.T) { t.Parallel() - t.Run("json", func(t *testing.T) { + t.Run("data", func(t *testing.T) { t.Parallel() - for i := 0; i < 1; i++ { + for i := 0; i < 10; i++ { t.Run("", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - c1, c2 := websocketPipe(t) + copts := websocket.CompressionOptions{ + Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled))), + Threshold: xrand.Int(9999), + } + + c1, c2, err := wstest.Pipe(&websocket.DialOptions{ + CompressionOptions: copts, + }, &websocket.AcceptOptions{ + CompressionOptions: copts, + }) + if err != nil { + t.Fatal(err) + } + defer c1.Close(websocket.StatusInternalError, "") + defer c2.Close(websocket.StatusInternalError, "") - wait := goFn(func() { + echoLoopErr := goFn(func() error { err := echoLoop(ctx, c1) - assertCloseStatus(t, websocket.StatusNormalClosure, err) + return assertCloseStatus(websocket.StatusNormalClosure, err) }) - defer wait() + defer func() { + err := <-echoLoopErr + if err != nil { + t.Errorf("echo loop error: %v", err) + } + }() defer cancel() c2.SetReadLimit(1 << 30) for i := 0; i < 10; i++ { - n := randInt(t, 131_072) - echoJSON(t, c2, n) + n := xrand.Int(131_072) + + msg := xrand.String(n) + + writeErr := goFn(func() error { + return wsjson.Write(ctx, c2, msg) + }) + + var act interface{} + err := wsjson.Read(ctx, c2, &act) + if err != nil { + t.Fatal(err) + } + + err = <-writeErr + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(msg, act) { + t.Fatalf("unexpected msg read: %v", cmp.Diff(msg, act)) + } } c2.Close(websocket.StatusNormalClosure, "") @@ -64,6 +101,16 @@ func TestConn(t *testing.T) { }) } +func assertCloseStatus(exp websocket.StatusCode, err error) error { + if websocket.CloseStatus(err) == -1 { + return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) + } + if websocket.CloseStatus(err) != exp { + return xerrors.Errorf("unexpected close status (%v):%v", exp, err) + } + return nil +} + // echoLoop echos every msg received from c until an error // occurs or the context expires. // The read limit is set to 1 << 30. @@ -98,75 +145,3 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } } - -func randBool(t testing.TB) bool { - return randInt(t, 2) == 1 -} - -func randInt(t testing.TB, max int) int { - x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) - assert.Success(t, "rand.Int", err) - return int(x.Int64()) -} - -type testHijacker struct { - *httptest.ResponseRecorder - serverConn net.Conn - hijacked chan struct{} -} - -var _ http.Hijacker = testHijacker{} - -func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { - close(hj.hijacked) - return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil -} - -func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) { - var serverConn *websocket.Conn - tt := testTransport{ - h: func(w http.ResponseWriter, r *http.Request) { - serverConn = acceptWebSocket(t, r, w, nil) - }, - } - - dialOpts := &websocket.DialOptions{ - HTTPClient: &http.Client{ - Transport: tt, - }, - } - - clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) - assert.Success(t, "websocket.Dial", err) - - if randBool(t) { - return serverConn, clientConn - } - return clientConn, serverConn -} - -type testTransport struct { - h http.HandlerFunc -} - -func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) { - clientConn, serverConn := net.Pipe() - - hj := testHijacker{ - ResponseRecorder: httptest.NewRecorder(), - serverConn: serverConn, - hijacked: make(chan struct{}), - } - - done := make(chan struct{}) - t.h.ServeHTTP(hj, r) - - select { - case <-hj.hijacked: - resp := hj.ResponseRecorder.Result() - resp.Body = clientConn - return resp, nil - case <-done: - return hj.ResponseRecorder.Result(), nil - } -} diff --git a/dial.go b/dial.go index 4557602..a1509ab 100644 --- a/dial.go +++ b/dial.go @@ -35,6 +35,7 @@ type DialOptions struct { // CompressionOptions controls the compression options. // See docs on the CompressionOptions type. + // TODO make * CompressionOptions CompressionOptions } diff --git a/go.mod b/go.mod index ee1708a..fc4ebb9 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/gobwas/ws v1.0.2 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.3.3 - github.com/google/go-cmp v0.4.0 // indirect + github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 github.com/mattn/go-isatty v0.0.12 // indirect go.opencensus.io v0.22.3 // indirect diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go new file mode 100644 index 0000000..d0eee6d --- /dev/null +++ b/internal/test/cmp/cmp.go @@ -0,0 +1,22 @@ +package cmp + +import ( + "reflect" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +// Equal checks if v1 and v2 are equal with go-cmp. +func Equal(v1, v2 interface{}) bool { + return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { + return true + })) +} + +// Diff returns a human readable diff between v1 and v2 +func Diff(v1, v2 interface{}) string { + return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { + return true + })) +} diff --git a/internal/test/doc.go b/internal/test/doc.go new file mode 100644 index 0000000..94b2e82 --- /dev/null +++ b/internal/test/doc.go @@ -0,0 +1,2 @@ +// Package test contains subpackages only used in tests. +package test diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go new file mode 100644 index 0000000..f3d25f5 --- /dev/null +++ b/internal/test/wstest/pipe.go @@ -0,0 +1,82 @@ +package wstest + +import ( + "bufio" + "context" + "net" + "net/http" + "net/http/httptest" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/xrand" +) + +// Pipe is used to create an in memory connection +// between two websockets analogous to net.Pipe. +func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (_ *websocket.Conn, _ *websocket.Conn, err error) { + defer errd.Wrap(&err, "failed to create ws pipe") + + var serverConn *websocket.Conn + var acceptErr error + tt := fakeTransport{ + h: func(w http.ResponseWriter, r *http.Request) { + serverConn, acceptErr = websocket.Accept(w, r, acceptOpts) + }, + } + + if dialOpts == nil { + dialOpts = &websocket.DialOptions{} + } + dialOpts.HTTPClient = &http.Client{ + Transport: tt, + } + + clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) + if err != nil { + return nil, nil, xerrors.Errorf("failed to dial with fake transport: %w", err) + } + + if serverConn == nil { + return nil, nil, xerrors.Errorf("failed to get server conn from fake transport: %w", acceptErr) + } + + if xrand.True() { + return serverConn, clientConn, nil + } + return clientConn, serverConn, nil +} + +type fakeTransport struct { + h http.HandlerFunc +} + +func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) { + clientConn, serverConn := net.Pipe() + + hj := testHijacker{ + ResponseRecorder: httptest.NewRecorder(), + serverConn: serverConn, + } + + t.h.ServeHTTP(hj, r) + + resp := hj.ResponseRecorder.Result() + if resp.StatusCode == http.StatusSwitchingProtocols { + resp.Body = clientConn + } + return resp, nil +} + +type testHijacker struct { + *httptest.ResponseRecorder + serverConn net.Conn +} + +var _ http.Hijacker = testHijacker{} + +func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil +} diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go new file mode 100644 index 0000000..2f3ad30 --- /dev/null +++ b/internal/test/xrand/xrand.go @@ -0,0 +1,47 @@ +package xrand + +import ( + "crypto/rand" + "fmt" + "math/big" + "strings" +) + +// Bytes generates random bytes with length n. +func Bytes(n int) []byte { + b := make([]byte, n) + _, err := rand.Reader.Read(b) + if err != nil { + panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) + } + return b +} + +// String generates a random string with length n. +func String(n int) string { + s := strings.ToValidUTF8(string(Bytes(n)), "_") + s = strings.ReplaceAll(s, "\x00", "_") + if len(s) > n { + return s[:n] + } + if len(s) < n { + // Pad with = + extra := n - len(s) + return s + strings.Repeat("=", extra) + } + return s +} + +// True returns a randomly generated boolean. +func True() bool { + return Int(2) == 1 +} + +// Int returns a randomly generated integer between [0, max). +func Int(max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + if err != nil { + panic(fmt.Sprintf("failed to get random int: %v", err)) + } + return int(x.Int64()) +} diff --git a/ws_js_test.go b/ws_js_test.go index 9f725a5..65309bf 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -1,4 +1,4 @@ -package websocket_test +package websocket import ( "context" @@ -6,8 +6,6 @@ import ( "os" "testing" "time" - - "nhooyr.io/websocket" ) func TestEcho(t *testing.T) { @@ -16,17 +14,17 @@ func TestEcho(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ + c, resp, err := Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &DialOptions{ Subprotocols: []string{"echo"}, }) assert.Success(t, err) - defer c.Close(websocket.StatusInternalError, "") + defer c.Close(StatusInternalError, "") assertSubprotocol(t, c, "echo") assert.Equalf(t, &http.Response{}, resp, "http.Response") echoJSON(t, ctx, c, 1024) - assertEcho(t, ctx, c, websocket.MessageBinary, 1024) + assertEcho(t, ctx, c, MessageBinary, 1024) - err = c.Close(websocket.StatusNormalClosure, "") + err = c.Close(StatusNormalClosure, "") assert.Success(t, err) } -- GitLab