From 12d7f1acc28e859be846e7b6ed15066c1259df2b Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sat, 31 Aug 2019 23:48:56 -0500 Subject: [PATCH] Translate the remaining useful Autobahn python tests --- ci/test.sh | 3 +- export_test.go | 34 +- websocket_bench_test.go | 5 +- websocket_test.go | 729 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 748 insertions(+), 23 deletions(-) diff --git a/ci/test.sh b/ci/test.sh index 3c476d9..c8b8ec1 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -12,14 +12,13 @@ argv=( -- "-vet=off" ) -# Interactive usage does not want to turn off vet or use gotestsum by default. +# Interactive usage does not want to turn off vet or use gotestsum. if [[ $# -gt 0 ]]; then argv=(go test "$@") fi # We always want coverage and race detection. argv+=( - -race "-coverprofile=ci/out/coverage.prof" "-coverpkg=./..." ) diff --git a/export_test.go b/export_test.go index fb3cf81..811bf80 100644 --- a/export_test.go +++ b/export_test.go @@ -15,6 +15,7 @@ type ( const ( OpClose = OpCode(opClose) OpBinary = OpCode(opBinary) + OpText = OpCode(opText) OpPing = OpCode(opPing) OpPong = OpCode(opPong) OpContinuation = OpCode(opContinuation) @@ -40,17 +41,38 @@ func (c *Conn) WriteFrame(ctx context.Context, fin bool, opc OpCode, p []byte) ( return c.writeFrame(ctx, fin, opcode(opc), p) } -func (c *Conn) WriteHeader(ctx context.Context, fin bool, opc OpCode, lenp int64) error { +// header represents a WebSocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Header struct { + Fin bool + Rsv1 bool + Rsv2 bool + Rsv3 bool + OpCode OpCode + + PayloadLength int64 +} + +func (c *Conn) WriteHeader(ctx context.Context, h Header) error { headerBytes := writeHeader(c.writeHeaderBuf, header{ - fin: fin, - opcode: opcode(opc), - payloadLength: lenp, + fin: h.Fin, + rsv1: h.Rsv1, + rsv2: h.Rsv2, + rsv3: h.Rsv3, + opcode: opcode(h.OpCode), + payloadLength: h.PayloadLength, masked: c.client, }) _, err := c.bw.Write(headerBytes) if err != nil { return xerrors.Errorf("failed to write header: %w", err) } + if h.Fin { + err = c.Flush() + if err != nil { + return err + } + } return nil } @@ -96,3 +118,7 @@ func (c *Conn) WriteClose(ctx context.Context, code StatusCode, reason string) ( } return b, nil } + +func ParseClosePayload(p []byte) (CloseError, error) { + return parseClosePayload(p) +} diff --git a/websocket_bench_test.go b/websocket_bench_test.go index 4ad8646..6a54fab 100644 --- a/websocket_bench_test.go +++ b/websocket_bench_test.go @@ -5,13 +5,13 @@ import ( "io" "io/ioutil" "net/http" - "nhooyr.io/websocket" "strconv" "strings" "testing" "time" -) + "nhooyr.io/websocket" +) func BenchmarkConn(b *testing.B) { sizes := []int{ @@ -116,7 +116,6 @@ func benchConn(b *testing.B, echo, stream bool, size int) { c.Close(websocket.StatusNormalClosure, "") } - func discardLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") diff --git a/websocket_test.go b/websocket_test.go index 732fc94..3482cbd 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -3,6 +3,7 @@ package websocket_test import ( "bytes" "context" + "encoding/binary" "encoding/json" "fmt" "io" @@ -919,7 +920,7 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e atomic.AddInt64(&conns, 1) defer atomic.AddInt64(&conns, -1) - ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + ctx, cancel := context.WithTimeout(r.Context(), time.Minute) defer cancel() r = r.WithContext(ctx) @@ -953,8 +954,6 @@ func TestAutobahn(t *testing.T) { run := func(t *testing.T, name string, fn func(ctx context.Context, c *websocket.Conn) error) { run2 := func(t *testing.T, testingClient bool) { - t.Parallel() - // Run random tests over TLS. tls := rand.Intn(2) == 1 @@ -985,7 +984,7 @@ func TestAutobahn(t *testing.T) { wsURL := strings.Replace(s.URL, "http", "ws", 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() opts := &websocket.DialOptions{ @@ -1017,9 +1016,11 @@ func TestAutobahn(t *testing.T) { t.Parallel() t.Run("server", func(t *testing.T) { + t.Parallel() run2(t, false) }) t.Run("client", func(t *testing.T) { + t.Parallel() run2(t, true) }) }) @@ -1043,8 +1044,7 @@ func TestAutobahn(t *testing.T) { for i, l := range lengths { l := l run(t, fmt.Sprintf("%v/%v", typ, l), func(ctx context.Context, c *websocket.Conn) error { - p := make([]byte, l) - rand.Read(p) + p := randBytes(l) if i == len(lengths)-1 { w, err := c.Writer(ctx, typ) if err != nil { @@ -1119,7 +1119,7 @@ func TestAutobahn(t *testing.T) { return assertCloseStatus(err, websocket.StatusProtocolError) }) run(t, "streamPingPayload", func(ctx context.Context, c *websocket.Conn) error { - err := streamPing(ctx, c, 125) + err := assertStreamPing(ctx, c, 125) if err != nil { return err } @@ -1189,7 +1189,7 @@ func TestAutobahn(t *testing.T) { }) run(t, "tenStreamedPings", func(ctx context.Context, c *websocket.Conn) error { for i := 0; i < 10; i++ { - err := streamPing(ctx, c, 125) + err := assertStreamPing(ctx, c, 125) if err != nil { return err } @@ -1200,15 +1200,659 @@ func TestAutobahn(t *testing.T) { }) // Section 3. + // We skip the per octet sending as it will add too much complexity. t.Run("reserved", func(t *testing.T) { t.Parallel() - run(t, "rsv1", func(ctx context.Context, c *websocket.Conn) error { - c.WriteFrame() + var testCases = []struct { + name string + header websocket.Header + }{ + { + name: "rsv1", + header: websocket.Header{ + Fin: true, + Rsv1: true, + OpCode: websocket.OpClose, + PayloadLength: 0, + }, + }, + { + name: "rsv2", + header: websocket.Header{ + Fin: true, + Rsv2: true, + OpCode: websocket.OpPong, + PayloadLength: 0, + }, + }, + { + name: "rsv3", + header: websocket.Header{ + Fin: true, + Rsv3: true, + OpCode: websocket.OpBinary, + PayloadLength: 0, + }, + }, + { + name: "rsvAll", + header: websocket.Header{ + Fin: true, + Rsv1: true, + Rsv2: true, + Rsv3: true, + OpCode: websocket.OpText, + PayloadLength: 0, + }, + }, + } + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + err := assertEcho(ctx, c, websocket.MessageText, 4096) + if err != nil { + return err + } + err = c.WriteHeader(ctx, tc.header) + if err != nil { + return err + } + err = c.Flush() + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } + }) + + // Section 4. + t.Run("opcodes", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + opcode websocket.OpCode + payload bool + echo bool + ping bool + }{ + // Section 1. + { + name: "3", + opcode: 3, + }, + { + name: "4", + opcode: 4, + payload: true, + }, + { + name: "5", + opcode: 5, + echo: true, + ping: true, + }, + { + name: "6", + opcode: 6, + payload: true, + echo: true, + ping: true, + }, + { + name: "7", + opcode: 7, + payload: true, + echo: true, + ping: true, + }, + + // Section 2. + { + name: "11", + opcode: 11, + }, + { + name: "12", + opcode: 12, + payload: true, + }, + { + name: "13", + opcode: 13, + payload: true, + echo: true, + ping: true, + }, + { + name: "14", + opcode: 14, + payload: true, + echo: true, + ping: true, + }, + { + name: "15", + opcode: 15, + payload: true, + echo: true, + ping: true, + }, + } + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + if tc.echo { + err := assertEcho(ctx, c, websocket.MessageText, 4096) + if err != nil { + return err + } + } + + p := []byte(nil) + if tc.payload { + p = randBytes(rand.Intn(4096) + 1) + } + _, err := c.WriteFrame(ctx, true, tc.opcode, p) + if err != nil { + return err + } + if tc.ping { + _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) + if err != nil { + return err + } + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } + }) + + // Section 5. + t.Run("fragmentation", func(t *testing.T) { + t.Parallel() + + // 5.1 to 5.8 + testCases := []struct { + name string + opcode websocket.OpCode + success bool + pingInBetween bool + }{ + { + name: "ping", + opcode: websocket.OpPing, + success: false, + }, + { + name: "pong", + opcode: websocket.OpPong, + success: false, + }, + { + name: "text", + opcode: websocket.OpText, + success: true, + }, + { + name: "textPing", + opcode: websocket.OpText, + success: true, + pingInBetween: true, + }, + } + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(16) + _, err := c.WriteFrame(ctx, false, tc.opcode, p1) + if err != nil { + return err + } + err = c.BW().Flush() + if err != nil { + return err + } + if !tc.success { + _, _, err = c.Read(ctx) + return assertCloseStatus(err, websocket.StatusProtocolError) + } + + if tc.pingInBetween { + _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) + if err != nil { + return err + } + } + + p2 := randBytes(16) + _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p2) + if err != nil { + return err + } + + err = assertReadFrame(ctx, c, tc.opcode, p1) + if err != nil { + return err + } + + if tc.pingInBetween { + err = assertReadFrame(ctx, c, websocket.OpPong, p1) + if err != nil { + return err + } + } + + return assertReadFrame(ctx, c, websocket.OpContinuation, p2) + }) + } + + t.Run("unexpectedContinuation", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + fin bool + textFirst bool + }{ + { + name: "fin", + fin: true, + }, + { + name: "noFin", + fin: false, + }, + { + name: "echoFirst", + fin: false, + textFirst: true, + }, + // The rest of the tests in this section get complicated and do not inspire much confidence. + } + + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + if tc.textFirst { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + return err + } + p1 := randBytes(32) + _, err = w.Write(p1) + if err != nil { + return err + } + p2 := randBytes(32) + _, err = w.Write(p2) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) + if err != nil { + return err + } + } + + _, err := c.WriteFrame(ctx, tc.fin, websocket.OpContinuation, randBytes(32)) + if err != nil { + return err + } + err = c.BW().Flush() + if err != nil { + return err + } + + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } + + run(t, "doubleText", func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(32) + _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, true, websocket.OpText, randBytes(32)) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + + run(t, "5.19", func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(32) + p2 := randBytes(32) + p3 := randBytes(32) + p4 := randBytes(32) + p5 := randBytes(32) + + _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p2) + if err != nil { + return err + } + + _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) + if err != nil { + return err + } + + time.Sleep(time.Second) + + _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p3) + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p4) + if err != nil { + return err + } + + _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) + if err != nil { + return err + } + + _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p5) + if err != nil { + return err + } + + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpPong, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p3) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p4) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpPong, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p5) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) + if err != nil { + return err + } + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) + }) + }) + + // Section 7 + t.Run("closeHandling", func(t *testing.T) { + t.Parallel() + + // 1.1 - 1.4 is useless. + run(t, "1.5", func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(32) + _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) + if err != nil { + return err + } + err = c.Flush() + if err != nil { + return err + } + _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) + }) + + run(t, "1.6", func(ctx context.Context, c *websocket.Conn) error { + // 262144 bytes. + p1 := randBytes(1 << 18) + err := c.Write(ctx, websocket.MessageText, p1) + if err != nil { + return err + } + _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") + if err != nil { + return err + } + err = assertReadMessage(ctx, c, websocket.MessageText, p1) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) + }) + + run(t, "emptyClose", func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OpClose, nil) + if err != nil { + return err + } + return assertReadFrame(ctx, c, websocket.OpClose, []byte{}) + }) + + run(t, "badClose", func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{1}) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + + run(t, "noReason", func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) + + run(t, "simpleReason", func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, randString(16)) + }) + + run(t, "maxReason", func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, randString(123)) + }) + + run(t, "tooBigReason", func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OpClose, + append([]byte{0x03, 0xE8}, randString(124)...), + ) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + + t.Run("validCloses", func(t *testing.T) { + t.Parallel() + + codes := [...]websocket.StatusCode{ + 1000, + 1001, + 1002, + 1003, + 1007, + 1008, + 1009, + 1010, + 1011, + 3000, + 3999, + 4000, + 4999, + } + for _, code := range codes { + run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, code, randString(32)) + }) + } + }) + + t.Run("invalidCloseCodes", func(t *testing.T) { + t.Parallel() + + codes := []websocket.StatusCode{ + 0, + 999, + 1004, + 1005, + 1006, + 1016, + 1100, + 2000, + 2999, + 5000, + 65535, + } + for _, code := range codes { + run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { + p := make([]byte, 2) + binary.BigEndian.PutUint16(p, uint16(code)) + p = append(p, randBytes(32)...) + _, err := c.WriteFrame(ctx, true, websocket.OpClose, p) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } }) }) -} + // Section 9. + t.Run("limits", func(t *testing.T) { + t.Parallel() + + t.Run("unfragmentedEcho", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 1 << 16, // 65536 + 1 << 18, // 262144 + // Anything higher is completely unnecessary. + } + + for _, l := range lengths { + l := l + run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { + return assertEcho(ctx, c, websocket.MessageBinary, l) + }) + } + }) + + t.Run("fragmentedEcho", func(t *testing.T) { + t.Parallel() + + fragments := []int{ + 64, + 256, + 1 << 10, + 1 << 12, + 1 << 14, + 1 << 16, + 1 << 18, + } + + for _, l := range fragments { + fragmentLength := l + run(t, strconv.Itoa(fragmentLength), func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + return err + } + b := randBytes(1 << 18) + for i := 0; i < len(b); { + j := i + fragmentLength + if j > len(b) { + j = len(b) + } + + _, err = w.Write(b[i:j]) + if err != nil { + return err + } + + i = j + } + err = w.Close() + if err != nil { + return err + } + + err = assertReadMessage(ctx, c, websocket.MessageText, b) + if err != nil { + return err + } + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) + } + }) + + t.Run("latencyEcho", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 0, + 16, + 64, + } + + for _, l := range lengths { + l := l + run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { + for i := 0; i < 1000; i++ { + err := assertEcho(ctx, c, websocket.MessageBinary, l) + if err != nil { + return err + } + } + return nil + }) + } + }) + }) +} func echoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") @@ -1269,6 +1913,31 @@ func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) (er return assertEqualf(exp, act, "unexpected JSON") } +func randBytes(n int) []byte { + return make([]byte, n) +} + +func randString(n int) string { + return string(randBytes(n)) +} + +func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) error { + p := randBytes(n) + err := c.Write(ctx, typ, p) + if err != nil { + return err + } + typ2, p2, err := c.Read(ctx) + if err != nil { + return err + } + err = assertEqualf(typ, typ2, "unexpected data type") + if err != nil { + return err + } + return assertEqualf(p, p2, "unexpected payload") +} + func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { expType := reflect.TypeOf(exp) actv := reflect.New(expType.Elem()) @@ -1320,13 +1989,29 @@ func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.Op if err != nil { return err } - err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", p) + err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP) if err != nil { return err } return assertEqualf(p, actP, "unexpected frame %v payload", opcode) } +func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error { + actOpcode, actP, err := c.ReadFrame(ctx) + if err != nil { + return err + } + err = assertEqualf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP) + if err != nil { + return err + } + ce, err := websocket.ParseClosePayload(actP) + if err != nil { + return xerrors.Errorf("failed to parse close frame payload: %w", err) + } + return assertEqualf(ce.Code, code, "unexpected frame close frame code with payload %q", actP) +} + func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket.StatusCode, reason string) error { p, err := c.WriteClose(ctx, code, reason) if err != nil { @@ -1335,8 +2020,12 @@ func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket return assertReadFrame(ctx, c, websocket.OpClose, p) } -func streamPing(ctx context.Context, c *websocket.Conn, l int) error { - err := c.WriteHeader(ctx, true, websocket.OpPing, int64(l)) +func assertStreamPing(ctx context.Context, c *websocket.Conn, l int) error { + err := c.WriteHeader(ctx, websocket.Header{ + Fin: true, + OpCode: websocket.OpPing, + PayloadLength: int64(l), + }) if err != nil { return err } @@ -1352,3 +2041,15 @@ func streamPing(ctx context.Context, c *websocket.Conn, l int) error { } return assertReadFrame(ctx, c, websocket.OpPong, bytes.Repeat([]byte{0xFE}, l)) } + +func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, p []byte) error { + actTyp, actP, err := c.Read(ctx) + if err != nil { + return err + } + err = assertEqualf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP) + if err != nil { + return err + } + return assertEqualf(p, actP, "unexpected frame %v payload", actTyp) +} -- GitLab