From 12d7f1acc28e859be846e7b6ed15066c1259df2b Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sat, 31 Aug 2019 23:48:56 -0500
Subject: [PATCH] Translate the remaining useful Autobahn python tests

---
 ci/test.sh              |   3 +-
 export_test.go          |  34 +-
 websocket_bench_test.go |   5 +-
 websocket_test.go       | 729 +++++++++++++++++++++++++++++++++++++++-
 4 files changed, 748 insertions(+), 23 deletions(-)

diff --git a/ci/test.sh b/ci/test.sh
index 3c476d9..c8b8ec1 100755
--- a/ci/test.sh
+++ b/ci/test.sh
@@ -12,14 +12,13 @@ argv=(
   --
   "-vet=off"
 )
-# Interactive usage does not want to turn off vet or use gotestsum by default.
+# Interactive usage does not want to turn off vet or use gotestsum.
 if [[ $# -gt 0 ]]; then
   argv=(go test "$@")
 fi
 
 # We always want coverage and race detection.
 argv+=(
-  -race
   "-coverprofile=ci/out/coverage.prof"
   "-coverpkg=./..."
 )
diff --git a/export_test.go b/export_test.go
index fb3cf81..811bf80 100644
--- a/export_test.go
+++ b/export_test.go
@@ -15,6 +15,7 @@ type (
 const (
 	OpClose        = OpCode(opClose)
 	OpBinary       = OpCode(opBinary)
+	OpText         = OpCode(opText)
 	OpPing         = OpCode(opPing)
 	OpPong         = OpCode(opPong)
 	OpContinuation = OpCode(opContinuation)
@@ -40,17 +41,38 @@ func (c *Conn) WriteFrame(ctx context.Context, fin bool, opc OpCode, p []byte) (
 	return c.writeFrame(ctx, fin, opcode(opc), p)
 }
 
-func (c *Conn) WriteHeader(ctx context.Context, fin bool, opc OpCode, lenp int64) error {
+// header represents a WebSocket frame header.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+type Header struct {
+	Fin    bool
+	Rsv1   bool
+	Rsv2   bool
+	Rsv3   bool
+	OpCode OpCode
+
+	PayloadLength int64
+}
+
+func (c *Conn) WriteHeader(ctx context.Context, h Header) error {
 	headerBytes := writeHeader(c.writeHeaderBuf, header{
-		fin:           fin,
-		opcode:        opcode(opc),
-		payloadLength: lenp,
+		fin:           h.Fin,
+		rsv1:          h.Rsv1,
+		rsv2:          h.Rsv2,
+		rsv3:          h.Rsv3,
+		opcode:        opcode(h.OpCode),
+		payloadLength: h.PayloadLength,
 		masked:        c.client,
 	})
 	_, err := c.bw.Write(headerBytes)
 	if err != nil {
 		return xerrors.Errorf("failed to write header: %w", err)
 	}
+	if h.Fin {
+		err = c.Flush()
+		if err != nil {
+			return err
+		}
+	}
 	return nil
 }
 
@@ -96,3 +118,7 @@ func (c *Conn) WriteClose(ctx context.Context, code StatusCode, reason string) (
 	}
 	return b, nil
 }
+
+func ParseClosePayload(p []byte) (CloseError, error) {
+	return parseClosePayload(p)
+}
diff --git a/websocket_bench_test.go b/websocket_bench_test.go
index 4ad8646..6a54fab 100644
--- a/websocket_bench_test.go
+++ b/websocket_bench_test.go
@@ -5,13 +5,13 @@ import (
 	"io"
 	"io/ioutil"
 	"net/http"
-	"nhooyr.io/websocket"
 	"strconv"
 	"strings"
 	"testing"
 	"time"
-)
 
+	"nhooyr.io/websocket"
+)
 
 func BenchmarkConn(b *testing.B) {
 	sizes := []int{
@@ -116,7 +116,6 @@ func benchConn(b *testing.B, echo, stream bool, size int) {
 	c.Close(websocket.StatusNormalClosure, "")
 }
 
-
 func discardLoop(ctx context.Context, c *websocket.Conn) {
 	defer c.Close(websocket.StatusInternalError, "")
 
diff --git a/websocket_test.go b/websocket_test.go
index 732fc94..3482cbd 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -3,6 +3,7 @@ package websocket_test
 import (
 	"bytes"
 	"context"
+	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -919,7 +920,7 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e
 		atomic.AddInt64(&conns, 1)
 		defer atomic.AddInt64(&conns, -1)
 
-		ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
+		ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
 		defer cancel()
 
 		r = r.WithContext(ctx)
@@ -953,8 +954,6 @@ func TestAutobahn(t *testing.T) {
 
 	run := func(t *testing.T, name string, fn func(ctx context.Context, c *websocket.Conn) error) {
 		run2 := func(t *testing.T, testingClient bool) {
-			t.Parallel()
-
 			// Run random tests over TLS.
 			tls := rand.Intn(2) == 1
 
@@ -985,7 +984,7 @@ func TestAutobahn(t *testing.T) {
 
 			wsURL := strings.Replace(s.URL, "http", "ws", 1)
 
-			ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 			defer cancel()
 
 			opts := &websocket.DialOptions{
@@ -1017,9 +1016,11 @@ func TestAutobahn(t *testing.T) {
 			t.Parallel()
 
 			t.Run("server", func(t *testing.T) {
+				t.Parallel()
 				run2(t, false)
 			})
 			t.Run("client", func(t *testing.T) {
+				t.Parallel()
 				run2(t, true)
 			})
 		})
@@ -1043,8 +1044,7 @@ func TestAutobahn(t *testing.T) {
 			for i, l := range lengths {
 				l := l
 				run(t, fmt.Sprintf("%v/%v", typ, l), func(ctx context.Context, c *websocket.Conn) error {
-					p := make([]byte, l)
-					rand.Read(p)
+					p := randBytes(l)
 					if i == len(lengths)-1 {
 						w, err := c.Writer(ctx, typ)
 						if err != nil {
@@ -1119,7 +1119,7 @@ func TestAutobahn(t *testing.T) {
 			return assertCloseStatus(err, websocket.StatusProtocolError)
 		})
 		run(t, "streamPingPayload", func(ctx context.Context, c *websocket.Conn) error {
-			err := streamPing(ctx, c, 125)
+			err := assertStreamPing(ctx, c, 125)
 			if err != nil {
 				return err
 			}
@@ -1189,7 +1189,7 @@ func TestAutobahn(t *testing.T) {
 		})
 		run(t, "tenStreamedPings", func(ctx context.Context, c *websocket.Conn) error {
 			for i := 0; i < 10; i++ {
-				err := streamPing(ctx, c, 125)
+				err := assertStreamPing(ctx, c, 125)
 				if err != nil {
 					return err
 				}
@@ -1200,15 +1200,659 @@ func TestAutobahn(t *testing.T) {
 	})
 
 	// Section 3.
+	// We skip the per octet sending as it will add too much complexity.
 	t.Run("reserved", func(t *testing.T) {
 		t.Parallel()
 
-		run(t, "rsv1", func(ctx context.Context, c *websocket.Conn) error {
-			c.WriteFrame()
+		var testCases = []struct {
+			name   string
+			header websocket.Header
+		}{
+			{
+				name: "rsv1",
+				header: websocket.Header{
+					Fin:           true,
+					Rsv1:          true,
+					OpCode:        websocket.OpClose,
+					PayloadLength: 0,
+				},
+			},
+			{
+				name: "rsv2",
+				header: websocket.Header{
+					Fin:           true,
+					Rsv2:          true,
+					OpCode:        websocket.OpPong,
+					PayloadLength: 0,
+				},
+			},
+			{
+				name: "rsv3",
+				header: websocket.Header{
+					Fin:           true,
+					Rsv3:          true,
+					OpCode:        websocket.OpBinary,
+					PayloadLength: 0,
+				},
+			},
+			{
+				name: "rsvAll",
+				header: websocket.Header{
+					Fin:           true,
+					Rsv1:          true,
+					Rsv2:          true,
+					Rsv3:          true,
+					OpCode:        websocket.OpText,
+					PayloadLength: 0,
+				},
+			},
+		}
+		for _, tc := range testCases {
+			tc := tc
+			run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
+				err := assertEcho(ctx, c, websocket.MessageText, 4096)
+				if err != nil {
+					return err
+				}
+				err = c.WriteHeader(ctx, tc.header)
+				if err != nil {
+					return err
+				}
+				err = c.Flush()
+				if err != nil {
+					return err
+				}
+				_, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf"))
+				if err != nil {
+					return err
+				}
+				return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
+			})
+		}
+	})
+
+	// Section 4.
+	t.Run("opcodes", func(t *testing.T) {
+		t.Parallel()
+
+		testCases := []struct {
+			name    string
+			opcode  websocket.OpCode
+			payload bool
+			echo    bool
+			ping    bool
+		}{
+			// Section 1.
+			{
+				name:   "3",
+				opcode: 3,
+			},
+			{
+				name:    "4",
+				opcode:  4,
+				payload: true,
+			},
+			{
+				name:   "5",
+				opcode: 5,
+				echo:   true,
+				ping:   true,
+			},
+			{
+				name:    "6",
+				opcode:  6,
+				payload: true,
+				echo:    true,
+				ping:    true,
+			},
+			{
+				name:    "7",
+				opcode:  7,
+				payload: true,
+				echo:    true,
+				ping:    true,
+			},
+
+			// Section 2.
+			{
+				name:   "11",
+				opcode: 11,
+			},
+			{
+				name:    "12",
+				opcode:  12,
+				payload: true,
+			},
+			{
+				name:    "13",
+				opcode:  13,
+				payload: true,
+				echo:    true,
+				ping:    true,
+			},
+			{
+				name:    "14",
+				opcode:  14,
+				payload: true,
+				echo:    true,
+				ping:    true,
+			},
+			{
+				name:    "15",
+				opcode:  15,
+				payload: true,
+				echo:    true,
+				ping:    true,
+			},
+		}
+		for _, tc := range testCases {
+			tc := tc
+			run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
+				if tc.echo {
+					err := assertEcho(ctx, c, websocket.MessageText, 4096)
+					if err != nil {
+						return err
+					}
+				}
+
+				p := []byte(nil)
+				if tc.payload {
+					p = randBytes(rand.Intn(4096) + 1)
+				}
+				_, err := c.WriteFrame(ctx, true, tc.opcode, p)
+				if err != nil {
+					return err
+				}
+				if tc.ping {
+					_, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf"))
+					if err != nil {
+						return err
+					}
+				}
+				return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
+			})
+		}
+	})
+
+	// Section 5.
+	t.Run("fragmentation", func(t *testing.T) {
+		t.Parallel()
+
+		// 5.1 to 5.8
+		testCases := []struct {
+			name          string
+			opcode        websocket.OpCode
+			success       bool
+			pingInBetween bool
+		}{
+			{
+				name:    "ping",
+				opcode:  websocket.OpPing,
+				success: false,
+			},
+			{
+				name:    "pong",
+				opcode:  websocket.OpPong,
+				success: false,
+			},
+			{
+				name:    "text",
+				opcode:  websocket.OpText,
+				success: true,
+			},
+			{
+				name:          "textPing",
+				opcode:        websocket.OpText,
+				success:       true,
+				pingInBetween: true,
+			},
+		}
+		for _, tc := range testCases {
+			tc := tc
+			run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
+				p1 := randBytes(16)
+				_, err := c.WriteFrame(ctx, false, tc.opcode, p1)
+				if err != nil {
+					return err
+				}
+				err = c.BW().Flush()
+				if err != nil {
+					return err
+				}
+				if !tc.success {
+					_, _, err = c.Read(ctx)
+					return assertCloseStatus(err, websocket.StatusProtocolError)
+				}
+
+				if tc.pingInBetween {
+					_, err = c.WriteFrame(ctx, true, websocket.OpPing, p1)
+					if err != nil {
+						return err
+					}
+				}
+
+				p2 := randBytes(16)
+				_, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p2)
+				if err != nil {
+					return err
+				}
+
+				err = assertReadFrame(ctx, c, tc.opcode, p1)
+				if err != nil {
+					return err
+				}
+
+				if tc.pingInBetween {
+					err = assertReadFrame(ctx, c, websocket.OpPong, p1)
+					if err != nil {
+						return err
+					}
+				}
+
+				return assertReadFrame(ctx, c, websocket.OpContinuation, p2)
+			})
+		}
+
+		t.Run("unexpectedContinuation", func(t *testing.T) {
+			t.Parallel()
+
+			testCases := []struct {
+				name      string
+				fin       bool
+				textFirst bool
+			}{
+				{
+					name: "fin",
+					fin:  true,
+				},
+				{
+					name: "noFin",
+					fin:  false,
+				},
+				{
+					name:      "echoFirst",
+					fin:       false,
+					textFirst: true,
+				},
+				// The rest of the tests in this section get complicated and do not inspire much confidence.
+			}
+
+			for _, tc := range testCases {
+				tc := tc
+				run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
+					if tc.textFirst {
+						w, err := c.Writer(ctx, websocket.MessageText)
+						if err != nil {
+							return err
+						}
+						p1 := randBytes(32)
+						_, err = w.Write(p1)
+						if err != nil {
+							return err
+						}
+						p2 := randBytes(32)
+						_, err = w.Write(p2)
+						if err != nil {
+							return err
+						}
+						err = w.Close()
+						if err != nil {
+							return err
+						}
+						err = assertReadFrame(ctx, c, websocket.OpText, p1)
+						if err != nil {
+							return err
+						}
+						err = assertReadFrame(ctx, c, websocket.OpContinuation, p2)
+						if err != nil {
+							return err
+						}
+						err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{})
+						if err != nil {
+							return err
+						}
+					}
+
+					_, err := c.WriteFrame(ctx, tc.fin, websocket.OpContinuation, randBytes(32))
+					if err != nil {
+						return err
+					}
+					err = c.BW().Flush()
+					if err != nil {
+						return err
+					}
+
+					return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
+				})
+			}
+
+			run(t, "doubleText", func(ctx context.Context, c *websocket.Conn) error {
+				p1 := randBytes(32)
+				_, err := c.WriteFrame(ctx, false, websocket.OpText, p1)
+				if err != nil {
+					return err
+				}
+				_, err = c.WriteFrame(ctx, true, websocket.OpText, randBytes(32))
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpText, p1)
+				if err != nil {
+					return err
+				}
+				return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
+			})
+
+			run(t, "5.19", func(ctx context.Context, c *websocket.Conn) error {
+				p1 := randBytes(32)
+				p2 := randBytes(32)
+				p3 := randBytes(32)
+				p4 := randBytes(32)
+				p5 := randBytes(32)
+
+				_, err := c.WriteFrame(ctx, false, websocket.OpText, p1)
+				if err != nil {
+					return err
+				}
+				_, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p2)
+				if err != nil {
+					return err
+				}
+
+				_, err = c.WriteFrame(ctx, true, websocket.OpPing, p1)
+				if err != nil {
+					return err
+				}
+
+				time.Sleep(time.Second)
+
+				_, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p3)
+				if err != nil {
+					return err
+				}
+				_, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p4)
+				if err != nil {
+					return err
+				}
+
+				_, err = c.WriteFrame(ctx, true, websocket.OpPing, p1)
+				if err != nil {
+					return err
+				}
+
+				_, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p5)
+				if err != nil {
+					return err
+				}
+
+				err = assertReadFrame(ctx, c, websocket.OpText, p1)
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpContinuation, p2)
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpPong, p1)
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpContinuation, p3)
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpContinuation, p4)
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpPong, p1)
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpContinuation, p5)
+				if err != nil {
+					return err
+				}
+				err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{})
+				if err != nil {
+					return err
+				}
+				return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "")
+			})
+		})
+	})
+
+	// Section 7
+	t.Run("closeHandling", func(t *testing.T) {
+		t.Parallel()
+
+		// 1.1 - 1.4 is useless.
+		run(t, "1.5", func(ctx context.Context, c *websocket.Conn) error {
+			p1 := randBytes(32)
+			_, err := c.WriteFrame(ctx, false, websocket.OpText, p1)
+			if err != nil {
+				return err
+			}
+			err = c.Flush()
+			if err != nil {
+				return err
+			}
+			_, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "")
+			if err != nil {
+				return err
+			}
+			err = assertReadFrame(ctx, c, websocket.OpText, p1)
+			if err != nil {
+				return err
+			}
+			return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure)
+		})
+
+		run(t, "1.6", func(ctx context.Context, c *websocket.Conn) error {
+			// 262144 bytes.
+			p1 := randBytes(1 << 18)
+			err := c.Write(ctx, websocket.MessageText, p1)
+			if err != nil {
+				return err
+			}
+			_, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "")
+			if err != nil {
+				return err
+			}
+			err = assertReadMessage(ctx, c, websocket.MessageText, p1)
+			if err != nil {
+				return err
+			}
+			return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure)
+		})
+
+		run(t, "emptyClose", func(ctx context.Context, c *websocket.Conn) error {
+			_, err := c.WriteFrame(ctx, true, websocket.OpClose, nil)
+			if err != nil {
+				return err
+			}
+			return assertReadFrame(ctx, c, websocket.OpClose, []byte{})
+		})
+
+		run(t, "badClose", func(ctx context.Context, c *websocket.Conn) error {
+			_, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{1})
+			if err != nil {
+				return err
+			}
+			return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
+		})
+
+		run(t, "noReason", func(ctx context.Context, c *websocket.Conn) error {
+			return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "")
+		})
+
+		run(t, "simpleReason", func(ctx context.Context, c *websocket.Conn) error {
+			return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, randString(16))
+		})
+
+		run(t, "maxReason", func(ctx context.Context, c *websocket.Conn) error {
+			return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, randString(123))
+		})
+
+		run(t, "tooBigReason", func(ctx context.Context, c *websocket.Conn) error {
+			_, err := c.WriteFrame(ctx, true, websocket.OpClose,
+				append([]byte{0x03, 0xE8}, randString(124)...),
+			)
+			if err != nil {
+				return err
+			}
+			return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
+		})
+
+		t.Run("validCloses", func(t *testing.T) {
+			t.Parallel()
+
+			codes := [...]websocket.StatusCode{
+				1000,
+				1001,
+				1002,
+				1003,
+				1007,
+				1008,
+				1009,
+				1010,
+				1011,
+				3000,
+				3999,
+				4000,
+				4999,
+			}
+			for _, code := range codes {
+				run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error {
+					return assertCloseHandshake(ctx, c, code, randString(32))
+				})
+			}
+		})
+
+		t.Run("invalidCloseCodes", func(t *testing.T) {
+			t.Parallel()
+
+			codes := []websocket.StatusCode{
+				0,
+				999,
+				1004,
+				1005,
+				1006,
+				1016,
+				1100,
+				2000,
+				2999,
+				5000,
+				65535,
+			}
+			for _, code := range codes {
+				run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error {
+					p := make([]byte, 2)
+					binary.BigEndian.PutUint16(p, uint16(code))
+					p = append(p, randBytes(32)...)
+					_, err := c.WriteFrame(ctx, true, websocket.OpClose, p)
+					if err != nil {
+						return err
+					}
+					return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
+				})
+			}
 		})
 	})
-}
 
+	// Section 9.
+	t.Run("limits", func(t *testing.T) {
+		t.Parallel()
+
+		t.Run("unfragmentedEcho", func(t *testing.T) {
+			t.Parallel()
+
+			lengths := []int{
+				1 << 16, // 65536
+				1 << 18, // 262144
+				// Anything higher is completely unnecessary.
+			}
+
+			for _, l := range lengths {
+				l := l
+				run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error {
+					return assertEcho(ctx, c, websocket.MessageBinary, l)
+				})
+			}
+		})
+
+		t.Run("fragmentedEcho", func(t *testing.T) {
+			t.Parallel()
+
+			fragments := []int{
+				64,
+				256,
+				1 << 10,
+				1 << 12,
+				1 << 14,
+				1 << 16,
+				1 << 18,
+			}
+
+			for _, l := range fragments {
+				fragmentLength := l
+				run(t, strconv.Itoa(fragmentLength), func(ctx context.Context, c *websocket.Conn) error {
+					w, err := c.Writer(ctx, websocket.MessageText)
+					if err != nil {
+						return err
+					}
+					b := randBytes(1 << 18)
+					for i := 0; i < len(b); {
+						j := i + fragmentLength
+						if j > len(b) {
+							j = len(b)
+						}
+
+						_, err = w.Write(b[i:j])
+						if err != nil {
+							return err
+						}
+
+						i = j
+					}
+					err = w.Close()
+					if err != nil {
+						return err
+					}
+
+					err = assertReadMessage(ctx, c, websocket.MessageText, b)
+					if err != nil {
+						return err
+					}
+					return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "")
+				})
+			}
+		})
+
+		t.Run("latencyEcho", func(t *testing.T) {
+			t.Parallel()
+
+			lengths := []int{
+				0,
+				16,
+				64,
+			}
+
+			for _, l := range lengths {
+				l := l
+				run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error {
+					for i := 0; i < 1000; i++ {
+						err := assertEcho(ctx, c, websocket.MessageBinary, l)
+						if err != nil {
+							return err
+						}
+					}
+					return nil
+				})
+			}
+		})
+	})
+}
 
 func echoLoop(ctx context.Context, c *websocket.Conn) {
 	defer c.Close(websocket.StatusInternalError, "")
@@ -1269,6 +1913,31 @@ func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) (er
 	return assertEqualf(exp, act, "unexpected JSON")
 }
 
+func randBytes(n int) []byte {
+	return make([]byte, n)
+}
+
+func randString(n int) string {
+	return string(randBytes(n))
+}
+
+func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) error {
+	p := randBytes(n)
+	err := c.Write(ctx, typ, p)
+	if err != nil {
+		return err
+	}
+	typ2, p2, err := c.Read(ctx)
+	if err != nil {
+		return err
+	}
+	err = assertEqualf(typ, typ2, "unexpected data type")
+	if err != nil {
+		return err
+	}
+	return assertEqualf(p, p2, "unexpected payload")
+}
+
 func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
 	expType := reflect.TypeOf(exp)
 	actv := reflect.New(expType.Elem())
@@ -1320,13 +1989,29 @@ func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.Op
 	if err != nil {
 		return err
 	}
-	err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", p)
+	err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
 	if err != nil {
 		return err
 	}
 	return assertEqualf(p, actP, "unexpected frame %v payload", opcode)
 }
 
+func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error {
+	actOpcode, actP, err := c.ReadFrame(ctx)
+	if err != nil {
+		return err
+	}
+	err = assertEqualf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
+	if err != nil {
+		return err
+	}
+	ce, err := websocket.ParseClosePayload(actP)
+	if err != nil {
+		return xerrors.Errorf("failed to parse close frame payload: %w", err)
+	}
+	return assertEqualf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
+}
+
 func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket.StatusCode, reason string) error {
 	p, err := c.WriteClose(ctx, code, reason)
 	if err != nil {
@@ -1335,8 +2020,12 @@ func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket
 	return assertReadFrame(ctx, c, websocket.OpClose, p)
 }
 
-func streamPing(ctx context.Context, c *websocket.Conn, l int) error {
-	err := c.WriteHeader(ctx, true, websocket.OpPing, int64(l))
+func assertStreamPing(ctx context.Context, c *websocket.Conn, l int) error {
+	err := c.WriteHeader(ctx, websocket.Header{
+		Fin:           true,
+		OpCode:        websocket.OpPing,
+		PayloadLength: int64(l),
+	})
 	if err != nil {
 		return err
 	}
@@ -1352,3 +2041,15 @@ func streamPing(ctx context.Context, c *websocket.Conn, l int) error {
 	}
 	return assertReadFrame(ctx, c, websocket.OpPong, bytes.Repeat([]byte{0xFE}, l))
 }
+
+func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, p []byte) error {
+	actTyp, actP, err := c.Read(ctx)
+	if err != nil {
+		return err
+	}
+	err = assertEqualf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
+	if err != nil {
+		return err
+	}
+	return assertEqualf(p, actP, "unexpected frame %v payload", actTyp)
+}
-- 
GitLab