From 679ddb825d5cd5ce4cc7136734fff5effe3a2910 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Thu, 29 Aug 2019 15:37:26 -0500
Subject: [PATCH] Drastically improve non autobahn test coverage

Also simplified and refactored the Conn tests.

More changes soon.
---
 accept_test.go     |  33 ++
 ci/test.sh         |  31 +-
 dial_test.go       |   9 +-
 export_test.go     |  12 +-
 header_test.go     |  31 ++
 netconn.go         |   4 +-
 statuscode.go      |   2 +-
 statuscode_test.go | 108 ++++++-
 websocket.go       |  56 ++--
 websocket_test.go  | 781 ++++++++++++++++++++++++++++++---------------
 10 files changed, 761 insertions(+), 306 deletions(-)

diff --git a/accept_test.go b/accept_test.go
index 6f5c3fb..8634066 100644
--- a/accept_test.go
+++ b/accept_test.go
@@ -6,6 +6,39 @@ import (
 	"testing"
 )
 
+func TestAccept(t *testing.T) {
+	t.Parallel()
+
+	t.Run("badClientHandshake", func(t *testing.T) {
+		t.Parallel()
+
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest("GET", "/", nil)
+
+		_, err := Accept(w, r, AcceptOptions{})
+		if err == nil {
+			t.Fatalf("unexpected error value: %v", err)
+		}
+
+	})
+
+	t.Run("requireHttpHijacker", func(t *testing.T) {
+		t.Parallel()
+
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest("GET", "/", nil)
+		r.Header.Set("Connection", "Upgrade")
+		r.Header.Set("Upgrade", "websocket")
+		r.Header.Set("Sec-WebSocket-Version", "13")
+		r.Header.Set("Sec-WebSocket-Key", "meow123")
+
+		_, err := Accept(w, r, AcceptOptions{})
+		if err == nil || !strings.Contains(err.Error(), "http.Hijacker") {
+			t.Fatalf("unexpected error value: %v", err)
+		}
+	})
+}
+
 func Test_verifyClientHandshake(t *testing.T) {
 	t.Parallel()
 
diff --git a/ci/test.sh b/ci/test.sh
index 875216f..1d4a8b0 100755
--- a/ci/test.sh
+++ b/ci/test.sh
@@ -4,19 +4,34 @@ set -euo pipefail
 cd "$(dirname "${0}")"
 cd "$(git rev-parse --show-toplevel)"
 
-mkdir -p ci/out/websocket
-testFlags=(
+argv=(
+  go run gotest.tools/gotestsum
+  # https://circleci.com/docs/2.0/collect-test-data/
+  "--junitfile=ci/out/websocket/testReport.xml"
+  "--format=short-verbose"
+  --
   -race
   "-vet=off"
-  #  "-bench=."
+  "-bench=."
+)
+# Interactive usage probably does not want to enable benchmarks, race detection
+# turn off vet or use gotestsum by default.
+if [[ $# -gt 0 ]]; then
+  argv=(go test "$@")
+fi
+
+# We always want coverage.
+argv+=(
   "-coverprofile=ci/out/coverage.prof"
   "-coverpkg=./..."
 )
-# https://circleci.com/docs/2.0/collect-test-data/
-go run gotest.tools/gotestsum \
-  --junitfile ci/out/websocket/testReport.xml \
-  --format=short-verbose \
-  -- "${testFlags[@]}"
+
+mkdir -p ci/out/websocket
+"${argv[@]}"
+
+# Removes coverage of generated files.
+grep -v _string.go < ci/out/coverage.prof > ci/out/coverage2.prof
+mv ci/out/coverage2.prof ci/out/coverage.prof
 
 go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
 if [[ ${CI:-} ]]; then
diff --git a/dial_test.go b/dial_test.go
index 6400c22..4607493 100644
--- a/dial_test.go
+++ b/dial_test.go
@@ -33,6 +33,10 @@ func TestBadDials(t *testing.T) {
 				},
 			},
 		},
+		{
+			name: "badTLS",
+			url:  "wss://totallyfake.nhooyr.io",
+		},
 	}
 
 	for _, tc := range testCases {
@@ -40,7 +44,10 @@ func TestBadDials(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Parallel()
 
-			_, _, err := Dial(context.Background(), tc.url, tc.opts)
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+			defer cancel()
+
+			_, _, err := Dial(ctx, tc.url, tc.opts)
 			if err == nil {
 				t.Fatalf("expected non nil error: %+v", err)
 			}
diff --git a/export_test.go b/export_test.go
index 22ad76f..ab766f1 100644
--- a/export_test.go
+++ b/export_test.go
@@ -1,3 +1,13 @@
 package websocket
 
-var Compute = handleSecWebSocketKey
+import (
+	"context"
+)
+
+type Addr = websocketAddr
+
+type Header = header
+
+func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
+	return c.writeFrame(ctx, fin, opcode, p)
+}
diff --git a/header_test.go b/header_test.go
index 4457c35..45d0535 100644
--- a/header_test.go
+++ b/header_test.go
@@ -2,6 +2,7 @@ package websocket
 
 import (
 	"bytes"
+	"io"
 	"math/rand"
 	"strconv"
 	"testing"
@@ -21,6 +22,36 @@ func randBool() bool {
 func TestHeader(t *testing.T) {
 	t.Parallel()
 
+	t.Run("eof", func(t *testing.T) {
+		t.Parallel()
+
+		testCases := []struct {
+			name  string
+			bytes []byte
+		}{
+			{
+				"start",
+				[]byte{0xff},
+			},
+			{
+				"middle",
+				[]byte{0xff, 0xff, 0xff},
+			},
+		}
+		for _, tc := range testCases {
+			tc := tc
+			t.Run(tc.name, func(t *testing.T) {
+				t.Parallel()
+
+				b := bytes.NewBuffer(tc.bytes)
+				_, err := readHeader(nil, b)
+				if io.ErrUnexpectedEOF != err {
+					t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err)
+				}
+			})
+		}
+	})
+
 	t.Run("writeNegativeLength", func(t *testing.T) {
 		t.Parallel()
 
diff --git a/netconn.go b/netconn.go
index d28eeb8..a6f902d 100644
--- a/netconn.go
+++ b/netconn.go
@@ -101,8 +101,8 @@ func (c *netConn) Read(p []byte) (int, error) {
 			return 0, err
 		}
 		if typ != c.msgType {
-			c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType))
-			return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ)
+			c.c.Close(StatusUnsupportedData, fmt.Sprintf("unexpected frame type read (expected %v): %v", c.msgType, typ))
+			return 0, c.c.closeErr
 		}
 		c.reader = r
 	}
diff --git a/statuscode.go b/statuscode.go
index 42ae40c..498437d 100644
--- a/statuscode.go
+++ b/statuscode.go
@@ -35,7 +35,7 @@ const (
 	StatusTryAgainLater
 	StatusBadGateway
 	// statusTLSHandshake is unexported because we just return
-	// handshake error in dial. We do not return a conn
+	// the handshake error in dial. We do not return a conn
 	// so there is nothing to use this on. At least until WASM.
 	statusTLSHandshake
 )
diff --git a/statuscode_test.go b/statuscode_test.go
index 38ee4c3..b963786 100644
--- a/statuscode_test.go
+++ b/statuscode_test.go
@@ -4,14 +4,13 @@ import (
 	"math"
 	"strings"
 	"testing"
+
+	"github.com/google/go-cmp/cmp"
 )
 
 func TestCloseError(t *testing.T) {
 	t.Parallel()
 
-	// Other parts of close error are tested by websocket_test.go right now
-	// with the autobahn tests.
-
 	testCases := []struct {
 		name    string
 		ce      CloseError
@@ -50,7 +49,108 @@ func TestCloseError(t *testing.T) {
 
 			_, err := tc.ce.bytes()
 			if (err == nil) != tc.success {
-				t.Fatalf("unexpected error value: %v", err)
+				t.Fatalf("unexpected error value: %+v", err)
+			}
+		})
+	}
+}
+
+func Test_parseClosePayload(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name    string
+		p       []byte
+		success bool
+		ce      CloseError
+	}{
+		{
+			name:    "normal",
+			p:       append([]byte{0x3, 0xE8}, []byte("hello")...),
+			success: true,
+			ce: CloseError{
+				Code:   StatusNormalClosure,
+				Reason: "hello",
+			},
+		},
+		{
+			name:    "nothing",
+			success: true,
+			ce: CloseError{
+				Code: StatusNoStatusRcvd,
+			},
+		},
+		{
+			name:    "oneByte",
+			p:       []byte{0},
+			success: false,
+		},
+		{
+			name:    "badStatusCode",
+			p:       []byte{0x17, 0x70},
+			success: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			ce, err := parseClosePayload(tc.p)
+			if (err == nil) != tc.success {
+				t.Fatalf("unexpected expected error value: %+v", err)
+			}
+
+			if tc.success && tc.ce != ce {
+				t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce))
+			}
+		})
+	}
+}
+
+func Test_validWireCloseCode(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name  string
+		code  StatusCode
+		valid bool
+	}{
+		{
+			name:  "normal",
+			code:  StatusNormalClosure,
+			valid: true,
+		},
+		{
+			name:  "noStatus",
+			code:  StatusNoStatusRcvd,
+			valid: false,
+		},
+		{
+			name:  "3000",
+			code:  3000,
+			valid: true,
+		},
+		{
+			name:  "4999",
+			code:  4999,
+			valid: true,
+		},
+		{
+			name:  "unknown",
+			code:  5000,
+			valid: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			if valid := validWireCloseCode(tc.code); tc.valid != valid {
+				t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid)
 			}
 		})
 	}
diff --git a/websocket.go b/websocket.go
index 393ea54..833c120 100644
--- a/websocket.go
+++ b/websocket.go
@@ -7,8 +7,8 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
 	"math/rand"
-	"os"
 	"runtime"
 	"strconv"
 	"sync"
@@ -210,9 +210,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
 		}
 
 		if h.rsv1 || h.rsv2 || h.rsv3 {
-			err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
-			c.Close(StatusProtocolError, err.Error())
-			return header{}, err
+			c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
+			return header{}, c.closeErr
 		}
 
 		if h.opcode.controlOp() {
@@ -227,9 +226,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
 		case opBinary, opText, opContinuation:
 			return h, nil
 		default:
-			err := xerrors.Errorf("received unknown opcode %v", h.opcode)
-			c.Close(StatusProtocolError, err.Error())
-			return header{}, err
+			c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode))
+			return header{}, c.closeErr
 		}
 	}
 }
@@ -273,15 +271,13 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
 
 func (c *Conn) handleControl(ctx context.Context, h header) error {
 	if h.payloadLength > maxControlFramePayload {
-		err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength)
-		c.Close(StatusProtocolError, err.Error())
-		return err
+		c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength))
+		return c.closeErr
 	}
 
 	if !h.fin {
-		err := xerrors.Errorf("received fragmented control frame")
-		c.Close(StatusProtocolError, err.Error())
-		return err
+		c.Close(StatusProtocolError, "received fragmented control frame")
+		return c.closeErr
 	}
 
 	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
@@ -311,8 +307,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 	case opClose:
 		ce, err := parseClosePayload(b)
 		if err != nil {
-			c.Close(StatusProtocolError, "received invalid close payload")
-			return xerrors.Errorf("received invalid close payload: %w", err)
+			err = xerrors.Errorf("received invalid close payload: %w", err)
+			c.Close(StatusProtocolError, err.Error())
+			return c.closeErr
 		}
 		// This ensures the closeErr of the Conn is always the received CloseError
 		// in case the echo close frame write fails.
@@ -376,9 +373,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 
 	if c.activeReader != nil && !c.activeReader.eof() {
 		if h.opcode != opContinuation {
-			err := xerrors.Errorf("received new data message without finishing the previous message")
-			c.Close(StatusProtocolError, err.Error())
-			return 0, nil, err
+			c.Close(StatusProtocolError, "received new data message without finishing the previous message")
+			return 0, nil, c.closeErr
 		}
 
 		if !h.fin || h.payloadLength > 0 {
@@ -392,9 +388,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 			return 0, nil, err
 		}
 	} else if h.opcode == opContinuation {
-		err := xerrors.Errorf("received continuation frame not after data or text frame")
-		c.Close(StatusProtocolError, err.Error())
-		return 0, nil, err
+		c.Close(StatusProtocolError, "received continuation frame not after data or text frame")
+		return 0, nil, c.closeErr
 	}
 
 	c.readerMsgCtx = ctx
@@ -460,9 +455,8 @@ func (r *messageReader) read(p []byte) (int, error) {
 	}
 
 	if r.c.readMsgLeft <= 0 {
-		err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit)
-		r.c.Close(StatusMessageTooBig, err.Error())
-		return 0, err
+		r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit))
+		return 0, r.c.closeErr
 	}
 
 	if int64(len(p)) > r.c.readMsgLeft {
@@ -476,9 +470,8 @@ func (r *messageReader) read(p []byte) (int, error) {
 		}
 
 		if h.opcode != opContinuation {
-			err := xerrors.Errorf("received new data message without finishing the previous message")
-			r.c.Close(StatusProtocolError, err.Error())
-			return 0, err
+			r.c.Close(StatusProtocolError, "received new data message without finishing the previous message")
+			return 0, r.c.closeErr
 		}
 
 		r.c.readerMsgHeader = h
@@ -828,7 +821,7 @@ func (c *Conn) writePong(p []byte) error {
 func (c *Conn) Close(code StatusCode, reason string) error {
 	err := c.exportedClose(code, reason)
 	if err != nil {
-		return xerrors.Errorf("failed to close connection: %w", err)
+		return xerrors.Errorf("failed to close websocket connection: %w", err)
 	}
 	return nil
 }
@@ -844,7 +837,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
 	// Definitely worth seeing what popular browsers do later.
 	p, err := ce.bytes()
 	if err != nil {
-		fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err)
+		log.Printf("websocket: failed to marshal close frame: %+v", err)
 		ce = CloseError{
 			Code: StatusInternalError,
 		}
@@ -853,12 +846,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
 
 	// CloseErrors sent are made opaque to prevent applications from thinking
 	// they received a given status.
-	err = c.writeClose(p, xerrors.Errorf("sent close frame: %v", ce))
+	sentErr := xerrors.Errorf("sent close frame: %v", ce)
+	err = c.writeClose(p, sentErr)
 	if err != nil {
 		return err
 	}
 
-	if !xerrors.Is(c.closeErr, ce) {
+	if !xerrors.Is(c.closeErr, sentErr) {
 		return c.closeErr
 	}
 
diff --git a/websocket_test.go b/websocket_test.go
index 2ef25cd..b45f024 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -4,8 +4,11 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/ptypes/timestamp"
 	"io"
 	"io/ioutil"
+	"math/rand"
 	"net"
 	"net/http"
 	"net/http/cookiejar"
@@ -75,127 +78,6 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 		},
-		{
-			name: "closeError",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				err = wsjson.Write(r.Context(), c, "hello")
-				if err != nil {
-					return err
-				}
-
-				return nil
-			},
-			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
-					Subprotocols: []string{"meow"},
-				})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				var m string
-				err = wsjson.Read(ctx, c, &m)
-				if err != nil {
-					return err
-				}
-
-				if m != "hello" {
-					return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m)
-				}
-
-				_, _, err = c.Reader(ctx)
-				var cerr websocket.CloseError
-				if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError {
-					return xerrors.Errorf("unexpected error: %+v", err)
-				}
-
-				return nil
-			},
-		},
-		{
-			name: "netConn",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				nc := websocket.NetConn(c, websocket.MessageBinary)
-				defer nc.Close()
-
-				nc.SetWriteDeadline(time.Time{})
-				time.Sleep(1)
-				nc.SetWriteDeadline(time.Now().Add(time.Second * 15))
-
-				for i := 0; i < 3; i++ {
-					_, err = nc.Write([]byte("hello"))
-					if err != nil {
-						return err
-					}
-				}
-
-				return nil
-			},
-			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
-					Subprotocols: []string{"meow"},
-				})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				nc := websocket.NetConn(c, websocket.MessageBinary)
-				defer nc.Close()
-
-				nc.SetReadDeadline(time.Time{})
-				time.Sleep(1)
-				nc.SetReadDeadline(time.Now().Add(time.Second * 15))
-
-				read := func() error {
-					p := make([]byte, len("hello"))
-					// We do not use io.ReadFull here as it masks EOFs.
-					// See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024
-					_, err = nc.Read(p)
-					if err != nil {
-						return err
-					}
-
-					if string(p) != "hello" {
-						return xerrors.Errorf("unexpected payload %q received", string(p))
-					}
-					return nil
-				}
-
-				for i := 0; i < 3; i++ {
-					err = read()
-					if err != nil {
-						return err
-					}
-				}
-
-				// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
-				err = read()
-				if err != io.EOF {
-					return err
-				}
-
-				err = read()
-				if err != io.EOF {
-					return err
-				}
-
-				return nil
-			},
-		},
 		{
 			name: "defaultSubprotocol",
 			server: func(w http.ResponseWriter, r *http.Request) error {
@@ -323,22 +205,240 @@ func TestHandshake(t *testing.T) {
 				if err != nil {
 					return err
 				}
-				defer c.Close(websocket.StatusInternalError, "")
+				defer c.Close(websocket.StatusInternalError, "")
+				return nil
+			},
+		},
+		{
+			name: "cookies",
+			server: func(w http.ResponseWriter, r *http.Request) error {
+				cookie, err := r.Cookie("mycookie")
+				if err != nil {
+					return xerrors.Errorf("request is missing mycookie: %w", err)
+				}
+				if cookie.Value != "myvalue" {
+					return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value)
+				}
+				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				if err != nil {
+					return err
+				}
+				c.Close(websocket.StatusInternalError, "")
+				return nil
+			},
+			client: func(ctx context.Context, u string) error {
+				jar, err := cookiejar.New(nil)
+				if err != nil {
+					return xerrors.Errorf("failed to create cookie jar: %w", err)
+				}
+				parsedURL, err := url.Parse(u)
+				if err != nil {
+					return xerrors.Errorf("failed to parse url: %w", err)
+				}
+				parsedURL.Scheme = "http"
+				jar.SetCookies(parsedURL, []*http.Cookie{
+					{
+						Name:  "mycookie",
+						Value: "myvalue",
+					},
+				})
+				hc := &http.Client{
+					Jar: jar,
+				}
+				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+					HTTPClient: hc,
+				})
+				if err != nil {
+					return err
+				}
+				c.Close(websocket.StatusInternalError, "")
+				return nil
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			s, closeFn := testServer(t, tc.server, false)
+			defer closeFn()
+
+			wsURL := strings.Replace(s.URL, "http", "ws", 1)
+
+			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+			defer cancel()
+
+			err := tc.client(ctx, wsURL)
+			if err != nil {
+				t.Fatalf("client failed: %+v", err)
+			}
+		})
+	}
+}
+
+func TestConn(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name   string
+		client func(ctx context.Context, c *websocket.Conn) error
+		server func(ctx context.Context, c *websocket.Conn) error
+	}{
+		{
+			name: "closeError",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				return wsjson.Write(ctx, c, "hello")
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				var m string
+				err := wsjson.Read(ctx, c, &m)
+				if err != nil {
+					return err
+				}
+
+				if m != "hello" {
+					return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m)
+				}
+
+				_, _, err = c.Reader(ctx)
+				var cerr websocket.CloseError
+				if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError {
+					return xerrors.Errorf("unexpected error: %+v", err)
+				}
+
+				return nil
+			},
+		},
+		{
+			name: "netConn",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				nc := websocket.NetConn(c, websocket.MessageBinary)
+				defer nc.Close()
+
+				nc.SetWriteDeadline(time.Time{})
+				time.Sleep(1)
+				nc.SetWriteDeadline(time.Now().Add(time.Second * 15))
+
+				if nc.LocalAddr() != (websocket.Addr{}) {
+					return xerrors.Errorf("net conn local address is not equal to websocket.Addr")
+				}
+				if nc.RemoteAddr() != (websocket.Addr{}) {
+					return xerrors.Errorf("net conn remote address is not equal to websocket.Addr")
+				}
+
+				for i := 0; i < 3; i++ {
+					_, err := nc.Write([]byte("hello"))
+					if err != nil {
+						return err
+					}
+				}
+
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				nc := websocket.NetConn(c, websocket.MessageBinary)
+				defer nc.Close()
+
+				nc.SetReadDeadline(time.Time{})
+				time.Sleep(1)
+				nc.SetReadDeadline(time.Now().Add(time.Second * 15))
+
+				read := func() error {
+					p := make([]byte, len("hello"))
+					// We do not use io.ReadFull here as it masks EOFs.
+					// See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024
+					_, err := nc.Read(p)
+					if err != nil {
+						return err
+					}
+
+					if string(p) != "hello" {
+						return xerrors.Errorf("unexpected payload %q received", string(p))
+					}
+					return nil
+				}
+
+				for i := 0; i < 3; i++ {
+					err := read()
+					if err != nil {
+						return err
+					}
+				}
+
+				// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
+				err := read()
+				if err != io.EOF {
+					return err
+				}
+
+				err = read()
+				if err != io.EOF {
+					return err
+				}
+
+				return nil
+			},
+		},
+		{
+			name: "netConn/badReadMsgType",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				nc := websocket.NetConn(c, websocket.MessageBinary)
+				defer nc.Close()
+
+				nc.SetDeadline(time.Now().Add(time.Second * 15))
+
+				_, err := nc.Read(make([]byte, 1))
+				if err == nil {
+					return xerrors.Errorf("expected error")
+				}
+
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				err := wsjson.Write(ctx, c, "meow")
+				if err != nil {
+					return err
+				}
+
+				_, _, err = c.Read(ctx)
+				cerr := &websocket.CloseError{}
+				if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusUnsupportedData {
+					return xerrors.Errorf("expected close error with code StatusUnsupportedData: %+v", err)
+				}
+
+				return nil
+			},
+		},
+		{
+			name: "netConn/badRead",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				nc := websocket.NetConn(c, websocket.MessageBinary)
+				defer nc.Close()
+
+				nc.SetDeadline(time.Now().Add(time.Second * 15))
+
+				_, err := nc.Read(make([]byte, 1))
+				cerr := &websocket.CloseError{}
+				if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusBadGateway {
+					return xerrors.Errorf("expected close error with code StatusBadGateway: %+v", err)
+				}
+
+				_, err = nc.Write([]byte{0xff})
+				if err == nil {
+					return xerrors.Errorf("expected writes to fail after reading a close frame: %v", err)
+				}
+
 				return nil
 			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				return c.Close(websocket.StatusBadGateway, "")
+			},
 		},
 		{
 			name: "jsonEcho",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
-				defer cancel()
-
+			server: func(ctx context.Context, c *websocket.Conn) error {
 				write := func() error {
 					v := map[string]interface{}{
 						"anmol": "wowow",
@@ -346,7 +446,7 @@ func TestHandshake(t *testing.T) {
 					err := wsjson.Write(ctx, c, v)
 					return err
 				}
-				err = write()
+				err := write()
 				if err != nil {
 					return err
 				}
@@ -358,13 +458,7 @@ func TestHandshake(t *testing.T) {
 				c.Close(websocket.StatusNormalClosure, "")
 				return nil
 			},
-			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
+			client: func(ctx context.Context, c *websocket.Conn) error {
 				read := func() error {
 					var v interface{}
 					err := wsjson.Read(ctx, c, &v)
@@ -380,7 +474,7 @@ func TestHandshake(t *testing.T) {
 					}
 					return nil
 				}
-				err = read()
+				err := read()
 				if err != nil {
 					return err
 				}
@@ -395,21 +489,12 @@ func TestHandshake(t *testing.T) {
 		},
 		{
 			name: "protobufEcho",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
-				defer cancel()
-
+			server: func(ctx context.Context, c *websocket.Conn) error {
 				write := func() error {
 					err := wspb.Write(ctx, c, ptypes.DurationProto(100))
 					return err
 				}
-				err = write()
+				err := write()
 				if err != nil {
 					return err
 				}
@@ -417,13 +502,7 @@ func TestHandshake(t *testing.T) {
 				c.Close(websocket.StatusNormalClosure, "")
 				return nil
 			},
-			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
+			client: func(ctx context.Context, c *websocket.Conn) error {
 				read := func() error {
 					var v duration.Duration
 					err := wspb.Read(ctx, c, &v)
@@ -441,7 +520,7 @@ func TestHandshake(t *testing.T) {
 					}
 					return nil
 				}
-				err = read()
+				err := read()
 				if err != nil {
 					return err
 				}
@@ -450,73 +529,21 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 		},
-		{
-			name: "cookies",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				cookie, err := r.Cookie("mycookie")
-				if err != nil {
-					return xerrors.Errorf("request is missing mycookie: %w", err)
-				}
-				if cookie.Value != "myvalue" {
-					return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value)
-				}
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusInternalError, "")
-				return nil
-			},
-			client: func(ctx context.Context, u string) error {
-				jar, err := cookiejar.New(nil)
-				if err != nil {
-					return xerrors.Errorf("failed to create cookie jar: %w", err)
-				}
-				parsedURL, err := url.Parse(u)
-				if err != nil {
-					return xerrors.Errorf("failed to parse url: %w", err)
-				}
-				parsedURL.Scheme = "http"
-				jar.SetCookies(parsedURL, []*http.Cookie{
-					{
-						Name:  "mycookie",
-						Value: "myvalue",
-					},
-				})
-				hc := &http.Client{
-					Jar: jar,
-				}
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
-					HTTPClient: hc,
-				})
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusInternalError, "")
-				return nil
-			},
-		},
 		{
 			name: "ping",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
+			server: func(ctx context.Context, c *websocket.Conn) error {
 				errc := make(chan error, 1)
 				go func() {
-					_, _, err2 := c.Read(r.Context())
+					_, _, err2 := c.Read(ctx)
 					errc <- err2
 				}()
 
-				err = c.Ping(r.Context())
+				err := c.Ping(ctx)
 				if err != nil {
 					return err
 				}
 
-				err = c.Write(r.Context(), websocket.MessageText, []byte("hi"))
+				err = c.Write(ctx, websocket.MessageText, []byte("hi"))
 				if err != nil {
 					return err
 				}
@@ -528,13 +555,7 @@ func TestHandshake(t *testing.T) {
 				}
 				return xerrors.Errorf("unexpected error: %w", err)
 			},
-			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
+			client: func(ctx context.Context, c *websocket.Conn) error {
 				// We read a message from the connection and then keep reading until
 				// the Ping completes.
 				done := make(chan struct{})
@@ -550,7 +571,7 @@ func TestHandshake(t *testing.T) {
 					c.Read(ctx)
 				}()
 
-				err = c.Ping(ctx)
+				err := c.Ping(ctx)
 				if err != nil {
 					return err
 				}
@@ -563,29 +584,17 @@ func TestHandshake(t *testing.T) {
 		},
 		{
 			name: "readLimit",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				_, _, err = c.Read(r.Context())
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				_, _, err := c.Read(ctx)
 				if err == nil {
 					return xerrors.Errorf("expected error but got nil")
 				}
 				return nil
 			},
-			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				go c.CloseRead(ctx)
 
-				go c.Reader(ctx)
-
-				err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769)))
+				err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769)))
 				if err != nil {
 					return err
 				}
@@ -600,20 +609,244 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 		},
-	}
+		{
+			name: "wsjson/binary",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				var v interface{}
+				err := wsjson.Read(ctx, c, &v)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				return wspb.Write(ctx, c, ptypes.DurationProto(100))
+			},
+		},
+		{
+			name: "wsjson/badRead",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				var v interface{}
+				err := wsjson.Read(ctx, c, &v)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				return c.Write(ctx, websocket.MessageText, []byte("notjson"))
+			},
+		},
+		{
+			name: "wsjson/badWrite",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				_, _, err := c.Read(ctx)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				err := wsjson.Write(ctx, c, fmt.Println)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+		},
+		{
+			name: "wspb/text",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				var v proto.Message
+				err := wspb.Read(ctx, c, v)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				return wsjson.Write(ctx, c, "hi")
+			},
+		},
+		{
+			name: "wspb/badRead",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				var v timestamp.Timestamp
+				err := wspb.Read(ctx, c, &v)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				return c.Write(ctx, websocket.MessageBinary, []byte("notpb"))
+			},
+		},
+		{
+			name: "wspb/badWrite",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				_, _, err := c.Read(ctx)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				err := wspb.Write(ctx, c, nil)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+		},
+		{
+			name: "wspb/badWrite",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				_, _, err := c.Read(ctx)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				err := wspb.Write(ctx, c, nil)
+				if err == nil {
+					return xerrors.Errorf("expected error: %v", err)
+				}
+				return nil
+			},
+		},
+		{
+			name: "badClose",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				return c.Close(9999, "")
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				_, _, err := c.Read(ctx)
+				cerr := &websocket.CloseError{}
+				if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusInternalError {
+					return xerrors.Errorf("expected close error with StatusInternalError: %+v", err)
+				}
+				return nil
+			},
+		},
+		{
+			name: "pingTimeout",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				ctx, cancel := context.WithTimeout(ctx, time.Second)
+				defer cancel()
+				err := c.Ping(ctx)
+				if err == nil {
+					return xerrors.Errorf("expected nil error: %+v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				time.Sleep(time.Second)
+				return nil
+			},
+		},
+		{
+			name: "writeTimeout",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				c.Writer(ctx, websocket.MessageBinary)
 
+				ctx, cancel := context.WithTimeout(ctx, time.Second)
+				defer cancel()
+				err := c.Write(ctx, websocket.MessageBinary, []byte("meow"))
+				if !xerrors.Is(err, context.DeadlineExceeded) {
+					return xerrors.Errorf("expected deadline exceeded error: %+v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				time.Sleep(time.Second)
+				return nil
+			},
+		},
+		{
+			name: "readTimeout",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				ctx, cancel := context.WithTimeout(ctx, time.Second)
+				defer cancel()
+				_, r, err := c.Reader(ctx)
+				if err != nil {
+					return err
+				}
+				<-ctx.Done()
+				_, err = r.Read(make([]byte, 1))
+				if !xerrors.Is(err, context.DeadlineExceeded){
+					return xerrors.Errorf("expected deadline exceeded error: %+v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				time.Sleep(time.Second)
+				return nil
+			},
+		},
+		{
+			name: "badOpCode",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				_, err := c.WriteFrame(ctx, true, 13, []byte("meow"))
+				if err != nil {
+					return err
+				}
+				_, _, err = c.Read(ctx)
+				cerr := &websocket.CloseError{}
+				if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError {
+					return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				_, _, err := c.Read(ctx)
+				if err == nil || strings.Contains(err.Error(), "opcode") {
+					return xerrors.Errorf("expected error that contains opcode: %+v", err)
+				}
+				return nil
+			},
+		},
+		{
+			name: "noRsv",
+			server: func(ctx context.Context, c *websocket.Conn) error {
+				_, err := c.WriteFrame(ctx, true, 99, []byte("meow"))
+				if err != nil {
+					return err
+				}
+				_, _, err = c.Read(ctx)
+				cerr := &websocket.CloseError{}
+				if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError {
+					return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err)
+				}
+				return nil
+			},
+			client: func(ctx context.Context, c *websocket.Conn) error {
+				_, _, err := c.Read(ctx)
+				if err == nil || !strings.Contains(err.Error(), "rsv") {
+					return xerrors.Errorf("expected error that contains rsv: %+v", err)
+				}
+				return nil
+			},
+		},
+	}
 	for _, tc := range testCases {
 		tc := tc
 		t.Run(tc.name, func(t *testing.T) {
 			t.Parallel()
 
-			s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
-				err := tc.server(w, r)
+			// Run random tests over TLS.
+			tls := rand.Intn(2) == 1
+
+			s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error {
+				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
 				if err != nil {
-					t.Errorf("server failed: %+v", err)
-					return
+					return err
 				}
-			})
+				defer c.Close(websocket.StatusInternalError, "")
+				tc.server(r.Context(), c)
+				return nil
+			}, tls)
 			defer closeFn()
 
 			wsURL := strings.Replace(s.URL, "http", "ws", 1)
@@ -621,7 +854,18 @@ func TestHandshake(t *testing.T) {
 			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 			defer cancel()
 
-			err := tc.client(ctx, wsURL)
+			opts := websocket.DialOptions{}
+			if tls {
+				opts.HTTPClient = s.Client()
+			}
+
+			c, _, err := websocket.Dial(ctx, wsURL, opts)
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer c.Close(websocket.StatusInternalError, "")
+
+			err = tc.client(ctx, c)
 			if err != nil {
 				t.Fatalf("client failed: %+v", err)
 			}
@@ -629,14 +873,31 @@ func TestHandshake(t *testing.T) {
 	}
 }
 
-func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) {
+func init() {
+	rand.Seed(time.Now().UnixNano())
+}
+
+func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) {
 	var conns int64
-	s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		atomic.AddInt64(&conns, 1)
 		defer atomic.AddInt64(&conns, -1)
 
-		fn.ServeHTTP(w, r)
-	}))
+		ctx, cancel := context.WithTimeout(r.Context(), time.Second*30)
+		defer cancel()
+
+		r = r.WithContext(ctx)
+
+		err := fn(w, r)
+		if err != nil {
+			tb.Errorf("server failed: %+v", err)
+		}
+	})
+	if tls {
+		s = httptest.NewTLSServer(h)
+	} else {
+		s = httptest.NewServer(h)
+	}
 	return s, func() {
 		s.Close()
 
@@ -654,7 +915,9 @@ func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn
 // https://github.com/crossbario/autobahn-python/tree/master/wstest
 func TestAutobahnServer(t *testing.T) {
 	t.Parallel()
-	t.Skip()
+	if os.Getenv("AUTOBAHN") == "" {
+		t.Skip("Set $AUTOBAHN to run the autobahn test suite.")
+	}
 
 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		c, err := websocket.Accept(w, r, websocket.AcceptOptions{
@@ -795,7 +1058,9 @@ func unusedListenAddr() (string, error) {
 // https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py
 func TestAutobahnClient(t *testing.T) {
 	t.Parallel()
-	t.Skip()
+	if os.Getenv("AUTOBAHN") == "" {
+		t.Skip("Set $AUTOBAHN to run the autobahn test suite.")
+	}
 
 	serverAddr, err := unusedListenAddr()
 	if err != nil {
@@ -941,18 +1206,18 @@ func checkWSTestIndex(t *testing.T, path string) {
 }
 
 func benchConn(b *testing.B, echo, stream bool, size int) {
-	s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error {
 		c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
 		if err != nil {
-			b.Logf("server handshake failed: %+v", err)
-			return
+			return err
 		}
 		if echo {
 			echoLoop(r.Context(), c)
 		} else {
 			discardLoop(r.Context(), c)
 		}
-	}))
+		return nil
+	}, false)
 	defer closeFn()
 
 	wsURL := strings.Replace(s.URL, "http", "ws", 1)
-- 
GitLab