From 78da35ec5b221d5ec664ee9cbf0a8fb034d46f4c Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Feb 2020 00:58:57 -0600 Subject: [PATCH] Get test with multiple messages working --- README.md | 2 +- assert_test.go | 79 +++++++++++++++++++++++++++++++----------- autobahn_test.go | 16 +++++---- conn_test.go | 89 +++++++++++++++++------------------------------- example_test.go | 2 -- read.go | 11 +++--- write.go | 1 + ws_js_test.go | 2 +- 8 files changed, 109 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index e958d2a..2569383 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ go get nhooyr.io/websocket - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper - [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression -- Can target [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) +- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) ## Roadmap diff --git a/assert_test.go b/assert_test.go index 6cfd926..3727d99 100644 --- a/assert_test.go +++ b/assert_test.go @@ -3,10 +3,15 @@ package websocket_test import ( "context" "crypto/rand" + "fmt" + "net/http" + "net/http/httptest" "strings" "testing" + "time" "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest/assert" "nhooyr.io/websocket" @@ -20,26 +25,31 @@ func randBytes(t *testing.T, n int) []byte { return b } -func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { - t.Helper() - defer c.Close(websocket.StatusInternalError, "") +func echoJSON(t *testing.T, c *websocket.Conn, n int) { + slog.Helper() - exp := randString(t, n) - err := wsjson.Write(ctx, c, exp) - assert.Success(t, "wsjson.Write", err) + s := randString(t, n) + writeJSON(t, c, s) + readJSON(t, c, s) +} - assertJSONRead(t, ctx, c, exp) +func writeJSON(t *testing.T, c *websocket.Conn, v interface{}) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() - c.Close(websocket.StatusNormalClosure, "") + err := wsjson.Write(ctx, c, v) + assert.Success(t, "wsjson.Write", err) } -func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { +func readJSON(t *testing.T, c *websocket.Conn, exp interface{}) { slog.Helper() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + var act interface{} err := wsjson.Read(ctx, c, &act) assert.Success(t, "wsjson.Read", err) - assert.Equal(t, "json", exp, act) } @@ -58,7 +68,7 @@ func randString(t *testing.T, n int) string { } func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { - t.Helper() + slog.Helper() p := randBytes(t, n) err := c.Write(ctx, typ, p) @@ -72,17 +82,46 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc } func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { - t.Helper() + slog.Helper() assert.Equal(t, "subprotocol", exp, c.Subprotocol()) } -func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { - t.Helper() - defer func() { - if t.Failed() { - t.Logf("error: %+v", err) - } - }() - assert.Equal(t, "closeStatus", exp, websocket.CloseStatus(err)) +func assertCloseStatus(t testing.TB, exp websocket.StatusCode, err error) { + slog.Helper() + + if websocket.CloseStatus(err) == -1 { + slogtest.Fatal(t, "expected websocket.CloseError", slogType(err), slog.Error(err)) + } + if websocket.CloseStatus(err) != exp { + slogtest.Error(t, "unexpected close status", + slog.F("exp", exp), + slog.F("act", err), + ) + } + +} + +func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts *websocket.AcceptOptions) *websocket.Conn { + c, err := websocket.Accept(w, r, opts) + assert.Success(t, "websocket.Accept", err) + return c +} + +func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + if opts == nil { + opts = &websocket.DialOptions{} + } + opts.HTTPClient = s.Client() + + c, resp, err := websocket.Dial(ctx, wsURL(s), opts) + assert.Success(t, "websocket.Dial", err) + return c, resp +} + +func slogType(v interface{}) slog.Field { + return slog.F("type", fmt.Sprintf("%T", v)) } diff --git a/autobahn_test.go b/autobahn_test.go index bcbf867..dd9887f 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httptest" "os" "os/exec" "strconv" @@ -53,15 +54,18 @@ func TestAutobahn(t *testing.T) { func testServerAutobahn(t *testing.T) { t.Parallel() - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) - assert.Success(t, "accept", err) - err = echoLoop(r.Context(), c) + err := echoLoop(r.Context(), c) assertCloseStatus(t, websocket.StatusNormalClosure, err) - }, false) - defer closeFn() + })) + closeFn := wsgrace(s.Config) + defer func() { + err := closeFn() + assert.Success(t, "closeFn", err) + }() specFile, err := tempJSONFile(map[string]interface{}{ "outdir": "ci/out/wstestServerReports", diff --git a/conn_test.go b/conn_test.go index 4720cba..6f6b8d5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,7 +4,9 @@ package websocket_test import ( "context" + "crypto/rand" "io" + "math/big" "net/http" "net/http/httptest" "strings" @@ -18,77 +20,32 @@ import ( "nhooyr.io/websocket" ) -func TestFuzz(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, - }) - assert.Success(t, "accept", err) - defer c.Close(websocket.StatusInternalError, "") - - err = echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - }, false) - defer closeFn() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - opts := &websocket.DialOptions{ - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, - } - opts.HTTPClient = s.Client() - - c, _, err := websocket.Dial(ctx, wsURL(s), opts) - assert.Success(t, "dial", err) - assertJSONEcho(t, ctx, c, 8393) -} - func TestConn(t *testing.T) { t.Parallel() t.Run("json", func(t *testing.T) { - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, - }) - assert.Success(t, "accept", err) - defer c.Close(websocket.StatusInternalError, "") - - err = echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - }, false) + t.Parallel() + + s, closeFn := testEchoLoop(t) defer closeFn() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() + c, _ := dialWebSocket(t, s, nil) + defer c.Close(websocket.StatusInternalError, "") - opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, + c.SetReadLimit(1 << 30) + + for i := 0; i < 10; i++ { + n := randInt(t, 1_048_576) + echoJSON(t, c, n) } - opts.HTTPClient = s.Client() - c, _, err := websocket.Dial(ctx, wsURL(s), opts) - assert.Success(t, "dial", err) - assertJSONEcho(t, ctx, c, 8393) + c.Close(websocket.StatusNormalClosure, "") }) } -func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) { +func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) { h := http.HandlerFunc(fn) - if tls { + if randInt(tb, 2) == 1 { s = httptest.NewTLSServer(h) } else { s = httptest.NewServer(h) @@ -179,3 +136,19 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { func wsURL(s *httptest.Server) string { return strings.Replace(s.URL, "http", "ws", 1) } + +func testEchoLoop(t testing.TB) (*httptest.Server, func()) { + return testServer(t, func(w http.ResponseWriter, r *http.Request) { + c := acceptWebSocket(t, r, w, nil) + defer c.Close(websocket.StatusInternalError, "") + + err := echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }) +} + +func randInt(t testing.TB, max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + assert.Success(t, "rand.Int", err) + return int(x.Int64()) +} diff --git a/example_test.go b/example_test.go index bc603af..1842b76 100644 --- a/example_test.go +++ b/example_test.go @@ -33,8 +33,6 @@ func ExampleAccept() { return } - log.Printf("received: %v", v) - c.Close(websocket.StatusNormalClosure, "") }) diff --git a/read.go b/read.go index 73ec0b3..7e74894 100644 --- a/read.go +++ b/read.go @@ -95,6 +95,7 @@ func (mr *msgReader) ensureFlate() { mr.flateReader = getFlateReader(readerFunc(mr.read), nil) } mr.limitReader.r = mr.flateReader + mr.flateTail.Reset(deflateMessageTail) } func (mr *msgReader) returnFlateReader() { @@ -328,12 +329,12 @@ type msgReader struct { func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.flate = h.rsv1 + mr.limitReader.reset(readerFunc(mr.read)) + if mr.flate { mr.ensureFlate() - mr.flateTail.Reset(deflateMessageTail) } - mr.limitReader.reset() mr.setFrame(h) } @@ -423,13 +424,13 @@ func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { c: c, } lr.limit.Store(limit) - lr.r = r - lr.reset() + lr.reset(r) return lr } -func (lr *limitReader) reset() { +func (lr *limitReader) reset(r io.Reader) { lr.n = lr.limit.Load() + lr.r = r } func (lr *limitReader) Read(p []byte) (int, error) { diff --git a/write.go b/write.go index 4a756fa..3454348 100644 --- a/write.go +++ b/write.go @@ -76,6 +76,7 @@ func (mw *msgWriter) ensureFlate() { w: writerFunc(mw.write), } } + mw.trimWriter.reset() mw.flateWriter = getFlateWriter(mw.trimWriter) mw.flate = true diff --git a/ws_js_test.go b/ws_js_test.go index 6e87480..9f725a5 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -24,7 +24,7 @@ func TestEcho(t *testing.T) { assertSubprotocol(t, c, "echo") assert.Equalf(t, &http.Response{}, resp, "http.Response") - assertJSONEcho(t, ctx, c, 1024) + echoJSON(t, ctx, c, 1024) assertEcho(t, ctx, c, websocket.MessageBinary, 1024) err = c.Close(websocket.StatusNormalClosure, "") -- GitLab