From d0a80496108cf7cdd4e20c24e4689cd5934b5b89 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 18 Nov 2019 22:52:18 -0500
Subject: [PATCH] Rewrite core

Too many improvements and changes to list.

Will include a detailed changelog for release.
---
 accept.go                                     |   63 +-
 assert_test.go                                |   14 +
 autobahn_test.go                              |  252 ++
 close.go                                      |  158 +-
 close_test.go                                 |    9 +-
 compress.go                                   |   86 +
 conn.go                                       | 1133 +-------
 conn_export_test.go                           |  129 -
 conn_test.go                                  | 2382 +----------------
 dial.go                                       |   78 +-
 dial_test.go                                  |    2 +-
 example_echo_test.go                          |    3 +-
 internal/wsframe/mask.go => frame.go          |  162 +-
 .../wsframe/mask_test.go => frame_test.go     |  108 +-
 internal/assert/assert.go                     |   40 +-
 internal/atomicint/atomicint.go               |   32 -
 internal/bufpool/buf.go                       |    6 +-
 internal/bufpool/bufio.go                     |   40 -
 internal/errd/errd.go                         |   11 +
 internal/wsecho/wsecho.go                     |   55 -
 internal/wsframe/frame.go                     |  194 --
 internal/wsframe/frame_stringer.go            |   91 -
 internal/wsframe/frame_test.go                |  157 --
 internal/wsgrace/wsgrace.go                   |   50 -
 js_test.go                                    |   50 -
 read.go                                       |  479 ++++
 reader.go                                     |   31 -
 write.go                                      |  348 +++
 writer.go                                     |    5 -
 ws_js.go                                      |   12 +-
 wsjson/wsjson.go                              |    2 +
 31 files changed, 1844 insertions(+), 4338 deletions(-)
 create mode 100644 autobahn_test.go
 delete mode 100644 conn_export_test.go
 rename internal/wsframe/mask.go => frame.go (57%)
 rename internal/wsframe/mask_test.go => frame_test.go (51%)
 delete mode 100644 internal/atomicint/atomicint.go
 delete mode 100644 internal/bufpool/bufio.go
 create mode 100644 internal/errd/errd.go
 delete mode 100644 internal/wsecho/wsecho.go
 delete mode 100644 internal/wsframe/frame.go
 delete mode 100644 internal/wsframe/frame_stringer.go
 delete mode 100644 internal/wsframe/frame_test.go
 delete mode 100644 internal/wsgrace/wsgrace.go
 delete mode 100644 js_test.go
 create mode 100644 read.go
 delete mode 100644 reader.go
 create mode 100644 write.go
 delete mode 100644 writer.go

diff --git a/accept.go b/accept.go
index 5ff2ea4..2028d4b 100644
--- a/accept.go
+++ b/accept.go
@@ -60,10 +60,15 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
 	return c, nil
 }
 
-func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
+func (opts *AcceptOptions) ensure() *AcceptOptions {
 	if opts == nil {
-		opts = &AcceptOptions{}
+		return &AcceptOptions{}
 	}
+	return opts
+}
+
+func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
+	opts = opts.ensure()
 
 	err := verifyClientRequest(w, r)
 	if err != nil {
@@ -114,31 +119,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
 	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
 	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
 
-	c := &Conn{
+	return newConn(connConfig{
 		subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
+		rwc:         netConn,
+		client:      false,
+		copts:       copts,
 		br:          brw.Reader,
 		bw:          brw.Writer,
-		closer:      netConn,
-		copts:       copts,
-	}
-	c.init()
-
-	return c, nil
-}
-
-func authenticateOrigin(r *http.Request) error {
-	origin := r.Header.Get("Origin")
-	if origin == "" {
-		return nil
-	}
-	u, err := url.Parse(origin)
-	if err != nil {
-		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
-	}
-	if !strings.EqualFold(u.Host, r.Host) {
-		return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
-	}
-	return nil
+	}), nil
 }
 
 func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
@@ -181,15 +169,37 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
 	return nil
 }
 
+func authenticateOrigin(r *http.Request) error {
+	origin := r.Header.Get("Origin")
+	if origin == "" {
+		return nil
+	}
+	u, err := url.Parse(origin)
+	if err != nil {
+		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
+	}
+	if !strings.EqualFold(u.Host, r.Host) {
+		return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
+	}
+	return nil
+}
+
 func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) {
 	key := r.Header.Get("Sec-WebSocket-Key")
 	w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
 }
 
 func selectSubprotocol(r *http.Request, subprotocols []string) string {
+	cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
+	if len(cps) == 0 {
+		return ""
+	}
+
 	for _, sp := range subprotocols {
-		if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) {
-			return sp
+		for _, cp := range cps {
+			if strings.EqualFold(sp, cp) {
+				return cp
+			}
 		}
 	}
 	return ""
@@ -266,7 +276,6 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
 	return copts, nil
 }
 
-
 func headerContainsToken(h http.Header, key, token string) bool {
 	token = strings.ToLower(token)
 
diff --git a/assert_test.go b/assert_test.go
index af30099..0cc9dfe 100644
--- a/assert_test.go
+++ b/assert_test.go
@@ -23,6 +23,8 @@ func randBytes(n int) []byte {
 }
 
 func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) {
+	t.Helper()
+
 	exp := randString(n)
 	err := wsjson.Write(ctx, c, exp)
 	assert.Success(t, err)
@@ -35,6 +37,8 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int)
 }
 
 func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) {
+	t.Helper()
+
 	var act interface{}
 	err := wsjson.Read(ctx, c, &act)
 	assert.Success(t, err)
@@ -56,6 +60,8 @@ func randString(n int) string {
 }
 
 func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) {
+	t.Helper()
+
 	p := randBytes(n)
 	err := c.Write(ctx, typ, p)
 	assert.Success(t, err)
@@ -68,5 +74,13 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc
 }
 
 func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) {
+	t.Helper()
+
 	assert.Equalf(t, exp, c.Subprotocol(), "unexpected subprotocol")
 }
+
+func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) {
+	t.Helper()
+
+	assert.Equalf(t, exp, websocket.CloseStatus(err), "unexpected status code")
+}
diff --git a/autobahn_test.go b/autobahn_test.go
new file mode 100644
index 0000000..27f8a1b
--- /dev/null
+++ b/autobahn_test.go
@@ -0,0 +1,252 @@
+package websocket_test
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"nhooyr.io/websocket"
+	"os"
+	"os/exec"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestAutobahn(t *testing.T) {
+	// This test contains the old autobahn test suite tests that use the
+	// python binary. The approach is clunky and slow so new tests
+	// have been written in pure Go in websocket_test.go.
+	// These have been kept for correctness purposes and are occasionally ran.
+	if os.Getenv("AUTOBAHN") == "" {
+		t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite")
+	}
+
+	t.Run("server", testServerAutobahnPython)
+	t.Run("client", testClientAutobahnPython)
+}
+
+// https://github.com/crossbario/autobahn-python/tree/master/wstest
+func testServerAutobahnPython(t *testing.T) {
+	t.Parallel()
+
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
+			Subprotocols: []string{"echo"},
+		})
+		if err != nil {
+			t.Logf("server handshake failed: %+v", err)
+			return
+		}
+		echoLoop(r.Context(), c)
+	}))
+	defer s.Close()
+
+	spec := map[string]interface{}{
+		"outdir": "ci/out/wstestServerReports",
+		"servers": []interface{}{
+			map[string]interface{}{
+				"agent": "main",
+				"url":   strings.Replace(s.URL, "http", "ws", 1),
+			},
+		},
+		"cases": []string{"*"},
+		// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
+		// more performance overhead. 7.5.1 is the same.
+		"exclude-cases": []string{"6.*", "7.5.1"},
+	}
+	specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json")
+	if err != nil {
+		t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err)
+	}
+	defer specFile.Close()
+
+	e := json.NewEncoder(specFile)
+	e.SetIndent("", "\t")
+	err = e.Encode(spec)
+	if err != nil {
+		t.Fatalf("failed to write spec: %v", err)
+	}
+
+	err = specFile.Close()
+	if err != nil {
+		t.Fatalf("failed to close file: %v", err)
+	}
+
+	ctx := context.Background()
+	ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
+	defer cancel()
+
+	args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()}
+	wstest := exec.CommandContext(ctx, "wstest", args...)
+	out, err := wstest.CombinedOutput()
+	if err != nil {
+		t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out)
+	}
+
+	checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
+}
+
+func unusedListenAddr() (string, error) {
+	l, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		return "", err
+	}
+	l.Close()
+	return l.Addr().String(), nil
+}
+
+// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py
+func testClientAutobahnPython(t *testing.T) {
+	t.Parallel()
+
+	if os.Getenv("AUTOBAHN_PYTHON") == "" {
+		t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite")
+	}
+
+	serverAddr, err := unusedListenAddr()
+	if err != nil {
+		t.Fatalf("failed to get unused listen addr for wstest: %v", err)
+	}
+
+	wsServerURL := "ws://" + serverAddr
+
+	spec := map[string]interface{}{
+		"url":    wsServerURL,
+		"outdir": "ci/out/wstestClientReports",
+		"cases":  []string{"*"},
+		// See TestAutobahnServer for the reasons why we exclude these.
+		"exclude-cases": []string{"6.*", "7.5.1"},
+	}
+	specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json")
+	if err != nil {
+		t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err)
+	}
+	defer specFile.Close()
+
+	e := json.NewEncoder(specFile)
+	e.SetIndent("", "\t")
+	err = e.Encode(spec)
+	if err != nil {
+		t.Fatalf("failed to write spec: %v", err)
+	}
+
+	err = specFile.Close()
+	if err != nil {
+		t.Fatalf("failed to close file: %v", err)
+	}
+
+	ctx := context.Background()
+	ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
+	defer cancel()
+
+	args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(),
+		// Disables some server that runs as part of fuzzingserver mode.
+		// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
+		"--webport=0",
+	}
+	wstest := exec.CommandContext(ctx, "wstest", args...)
+	err = wstest.Start()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err := wstest.Process.Kill()
+		if err != nil {
+			t.Error(err)
+		}
+	}()
+
+	// Let it come up.
+	time.Sleep(time.Second * 5)
+
+	var cases int
+	func() {
+		c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close(websocket.StatusInternalError, "")
+
+		_, r, err := c.Reader(ctx)
+		if err != nil {
+			t.Fatal(err)
+		}
+		b, err := ioutil.ReadAll(r)
+		if err != nil {
+			t.Fatal(err)
+		}
+		cases, err = strconv.Atoi(string(b))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		c.Close(websocket.StatusNormalClosure, "")
+	}()
+
+	for i := 1; i <= cases; i++ {
+		func() {
+			ctx, cancel := context.WithTimeout(ctx, time.Second*45)
+			defer cancel()
+
+			c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			echoLoop(ctx, c)
+		}()
+	}
+
+	c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	c.Close(websocket.StatusNormalClosure, "")
+
+	checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
+}
+
+func checkWSTestIndex(t *testing.T, path string) {
+	wstestOut, err := ioutil.ReadFile(path)
+	if err != nil {
+		t.Fatalf("failed to read index.json: %v", err)
+	}
+
+	var indexJSON map[string]map[string]struct {
+		Behavior      string `json:"behavior"`
+		BehaviorClose string `json:"behaviorClose"`
+	}
+	err = json.Unmarshal(wstestOut, &indexJSON)
+	if err != nil {
+		t.Fatalf("failed to unmarshal index.json: %v", err)
+	}
+
+	var failed bool
+	for _, tests := range indexJSON {
+		for test, result := range tests {
+			switch result.Behavior {
+			case "OK", "NON-STRICT", "INFORMATIONAL":
+			default:
+				failed = true
+				t.Errorf("test %v failed", test)
+			}
+			switch result.BehaviorClose {
+			case "OK", "INFORMATIONAL":
+			default:
+				failed = true
+				t.Errorf("bad close behaviour for test %v", test)
+			}
+		}
+	}
+
+	if failed {
+		path = strings.Replace(path, ".json", ".html", 1)
+		if os.Getenv("CI") == "" {
+			t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path)
+		}
+	}
+}
diff --git a/close.go b/close.go
index 4f48f1b..b1bc50e 100644
--- a/close.go
+++ b/close.go
@@ -5,7 +5,9 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
-	"nhooyr.io/websocket/internal/wsframe"
+	"log"
+	"nhooyr.io/websocket/internal/bufpool"
+	"time"
 )
 
 // StatusCode represents a WebSocket status code.
@@ -74,6 +76,87 @@ func CloseStatus(err error) StatusCode {
 	return -1
 }
 
+// Close closes the WebSocket connection with the given status code and reason.
+//
+// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
+// the peer to send a close frame.
+// Thus, it implements the full WebSocket close handshake.
+// All data messages received from the peer during the close handshake
+// will be discarded.
+//
+// The connection can only be closed once. Additional calls to Close
+// are no-ops.
+//
+// The maximum length of reason must be 125 bytes otherwise an internal
+// error will be sent to the peer. For this reason, you should avoid
+// sending a dynamic reason.
+//
+// Close will unblock all goroutines interacting with the connection once
+// complete.
+func (c *Conn) Close(code StatusCode, reason string) error {
+	err := c.closeHandshake(code, reason)
+	if err != nil {
+		return fmt.Errorf("failed to close websocket: %w", err)
+	}
+	return nil
+}
+
+func (c *Conn) closeHandshake(code StatusCode, reason string) error {
+	err := c.cw.sendClose(code, reason)
+	if err != nil {
+		return err
+	}
+
+	return c.cr.waitClose()
+}
+
+func (cw *connWriter) error(code StatusCode, err error) {
+	cw.c.setCloseErr(err)
+	cw.sendClose(code, err.Error())
+	cw.c.close(nil)
+}
+
+func (cw *connWriter) sendClose(code StatusCode, reason string) error {
+	ce := CloseError{
+		Code:   code,
+		Reason: reason,
+	}
+
+	cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
+
+	var p []byte
+	if ce.Code != StatusNoStatusRcvd {
+		p = ce.bytes()
+	}
+
+	return cw.control(context.Background(), opClose, p)
+}
+
+func (cr *connReader) waitClose() error {
+	defer cr.c.close(nil)
+
+	return nil
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+
+	err := cr.mu.Lock(ctx)
+	if err != nil {
+		return err
+	}
+	defer cr.mu.Unlock()
+
+	b := bufpool.Get()
+	buf := b.Bytes()
+	buf = buf[:cap(buf)]
+	defer bufpool.Put(b)
+
+	for {
+		// TODO
+		return nil
+	}
+}
+
 func parseClosePayload(p []byte) (CloseError, error) {
 	if len(p) == 0 {
 		return CloseError{
@@ -81,14 +164,13 @@ func parseClosePayload(p []byte) (CloseError, error) {
 		}, nil
 	}
 
-	code, reason, err := wsframe.ParseClosePayload(p)
-	if err != nil {
-		return CloseError{}, err
+	if len(p) < 2 {
+		return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
 	}
 
 	ce := CloseError{
-		Code:   StatusCode(code),
-		Reason: reason,
+		Code:   StatusCode(binary.BigEndian.Uint16(p)),
+		Reason: string(p[2:]),
 	}
 
 	if !validWireCloseCode(ce.Code) {
@@ -116,11 +198,25 @@ func validWireCloseCode(code StatusCode) bool {
 	return false
 }
 
-func (ce CloseError) bytes() ([]byte, error) {
-	// TODO move check into frame write
-	if len(ce.Reason) > wsframe.MaxControlFramePayload-2 {
-		return nil, fmt.Errorf("reason string max is %v but got %q with length %v", wsframe.MaxControlFramePayload-2, ce.Reason, len(ce.Reason))
+func (ce CloseError) bytes() []byte {
+	p, err := ce.bytesErr()
+	if err != nil {
+		log.Printf("websocket: failed to marshal close frame: %+v", err)
+		ce = CloseError{
+			Code: StatusInternalError,
+		}
+		p, _ = ce.bytesErr()
 	}
+	return p
+}
+
+const maxCloseReason = maxControlPayload - 2
+
+func (ce CloseError) bytesErr() ([]byte, error) {
+	if len(ce.Reason) > maxCloseReason {
+		return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
+	}
+
 	if !validWireCloseCode(ce.Code) {
 		return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
 	}
@@ -131,44 +227,16 @@ func (ce CloseError) bytes() ([]byte, error) {
 	return buf, nil
 }
 
-// CloseRead will start a goroutine to read from the connection until it is closed or a data message
-// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
-// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
-// After calling this method, you cannot read any data messages from the connection.
-// The returned context will be cancelled when the connection is closed.
-//
-// Use this when you do not want to read data messages from the connection anymore but will
-// want to write messages to it.
-func (c *Conn) CloseRead(ctx context.Context) context.Context {
-	c.isReadClosed.Store(1)
-
-	ctx, cancel := context.WithCancel(ctx)
-	go func() {
-		defer cancel()
-		// We use the unexported reader method so that we don't get the read closed error.
-		c.reader(ctx, true)
-		// Either the connection is already closed since there was a read error
-		// or the context was cancelled or a message was read and we should close
-		// the connection.
-		c.Close(StatusPolicyViolation, "unexpected data message")
-	}()
-	return ctx
-}
-
-// SetReadLimit sets the max number of bytes to read for a single message.
-// It applies to the Reader and Read methods.
-//
-// By default, the connection has a message read limit of 32768 bytes.
-//
-// When the limit is hit, the connection will be closed with StatusMessageTooBig.
-func (c *Conn) SetReadLimit(n int64) {
-	c.msgReadLimit.Store(n)
+func (c *Conn) setCloseErr(err error) {
+	c.closeMu.Lock()
+	c.setCloseErrNoLock(err)
+	c.closeMu.Unlock()
 }
 
-func (c *Conn) setCloseErr(err error) {
-	c.closeErrOnce.Do(func() {
+func (c *Conn) setCloseErrNoLock(err error) {
+	if c.closeErr == nil {
 		c.closeErr = fmt.Errorf("websocket closed: %w", err)
-	})
+	}
 }
 
 func (c *Conn) isClosed() bool {
diff --git a/close_test.go b/close_test.go
index 78096d7..ee10cd3 100644
--- a/close_test.go
+++ b/close_test.go
@@ -5,7 +5,6 @@ import (
 	"io"
 	"math"
 	"nhooyr.io/websocket/internal/assert"
-	"nhooyr.io/websocket/internal/wsframe"
 	"strings"
 	"testing"
 )
@@ -22,7 +21,7 @@ func TestCloseError(t *testing.T) {
 			name: "normal",
 			ce: CloseError{
 				Code:   StatusNormalClosure,
-				Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2),
+				Reason: strings.Repeat("x", maxCloseReason),
 			},
 			success: true,
 		},
@@ -30,7 +29,7 @@ func TestCloseError(t *testing.T) {
 			name: "bigReason",
 			ce: CloseError{
 				Code:   StatusNormalClosure,
-				Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-1),
+				Reason: strings.Repeat("x", maxCloseReason+1),
 			},
 			success: false,
 		},
@@ -38,7 +37,7 @@ func TestCloseError(t *testing.T) {
 			name: "bigCode",
 			ce: CloseError{
 				Code:   math.MaxUint16,
-				Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2),
+				Reason: strings.Repeat("x", maxCloseReason),
 			},
 			success: false,
 		},
@@ -49,7 +48,7 @@ func TestCloseError(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Parallel()
 
-			_, err := tc.ce.bytes()
+			_, err := tc.ce.bytesErr()
 			if (err == nil) != tc.success {
 				t.Fatalf("unexpected error value: %+v", err)
 			}
diff --git a/compress.go b/compress.go
index 5b5fdce..9e07543 100644
--- a/compress.go
+++ b/compress.go
@@ -3,7 +3,10 @@
 package websocket
 
 import (
+	"compress/flate"
+	"io"
 	"net/http"
+	"sync"
 )
 
 // CompressionMode controls the modes available RFC 7692's deflate extension.
@@ -76,3 +79,86 @@ func (copts *compressionOptions) setHeader(h http.Header) {
 // we need to add them back otherwise flate.Reader keeps
 // trying to return more bytes.
 const deflateMessageTail = "\x00\x00\xff\xff"
+
+func (c *Conn) writeNoContextTakeOver() bool {
+	return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover
+}
+
+func (c *Conn) readNoContextTakeOver() bool {
+	return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover
+}
+
+type trimLastFourBytesWriter struct {
+	w    io.Writer
+	tail []byte
+}
+
+func (tw *trimLastFourBytesWriter) reset() {
+	tw.tail = tw.tail[:0]
+}
+
+func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
+	extra := len(tw.tail) + len(p) - 4
+
+	if extra <= 0 {
+		tw.tail = append(tw.tail, p...)
+		return len(p), nil
+	}
+
+	// Now we need to write as many extra bytes as we can from the previous tail.
+	if extra > len(tw.tail) {
+		extra = len(tw.tail)
+	}
+	if extra > 0 {
+		_, err := tw.w.Write(tw.tail[:extra])
+		if err != nil {
+			return 0, err
+		}
+		tw.tail = tw.tail[extra:]
+	}
+
+	// If p is less than or equal to 4 bytes,
+	// all of it is is part of the tail.
+	if len(p) <= 4 {
+		tw.tail = append(tw.tail, p...)
+		return len(p), nil
+	}
+
+	// Otherwise, only the last 4 bytes are.
+	tw.tail = append(tw.tail, p[len(p)-4:]...)
+
+	p = p[:len(p)-4]
+	n, err := tw.w.Write(p)
+	return n + 4, err
+}
+
+var flateReaderPool sync.Pool
+
+func getFlateReader(r io.Reader) io.Reader {
+	fr, ok := flateReaderPool.Get().(io.Reader)
+	if !ok {
+		return flate.NewReader(r)
+	}
+	fr.(flate.Resetter).Reset(r, nil)
+	return fr
+}
+
+func putFlateReader(fr io.Reader) {
+	flateReaderPool.Put(fr)
+}
+
+var flateWriterPool sync.Pool
+
+func getFlateWriter(w io.Writer) *flate.Writer {
+	fw, ok := flateWriterPool.Get().(*flate.Writer)
+	if !ok {
+		fw, _ = flate.NewWriter(w, flate.BestSpeed)
+		return fw
+	}
+	fw.Reset(w)
+	return fw
+}
+
+func putFlateWriter(w *flate.Writer) {
+	flateWriterPool.Put(w)
+}
diff --git a/conn.go b/conn.go
index 791d9b4..e3f2417 100644
--- a/conn.go
+++ b/conn.go
@@ -4,25 +4,14 @@ package websocket
 
 import (
 	"bufio"
-	"compress/flate"
 	"context"
-	"crypto/rand"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
-	"log"
-	"nhooyr.io/websocket/internal/atomicint"
-	"nhooyr.io/websocket/internal/wsframe"
 	"runtime"
 	"strconv"
-	"strings"
 	"sync"
 	"sync/atomic"
-	"time"
-
-	"nhooyr.io/websocket/internal/bufpool"
 )
 
 // MessageType represents the type of a WebSocket message.
@@ -51,91 +40,54 @@ const (
 // This applies to the Read methods in the wsjson/wspb subpackages as well.
 type Conn struct {
 	subprotocol string
-	fw          *flate.Writer
-	bw          *bufio.Writer
-	// writeBuf is used for masking, its the buffer in bufio.Writer.
-	// Only used by the client for masking the bytes in the buffer.
-	writeBuf []byte
-	closer   io.Closer
-	client   bool
-	copts    *compressionOptions
-
-	closeOnce     sync.Once
-	closeErrOnce  sync.Once
-	closeErr      error
-	closed        chan struct{}
-	closing       *atomicint.Int64
-	closeReceived error
+	rwc         io.ReadWriteCloser
+	client      bool
+	copts       *compressionOptions
 
-	// messageWriter state.
-	// writeMsgLock is acquired to write a data message.
-	writeMsgLock chan struct{}
-	// writeFrameLock is acquired to write a single frame.
-	// Effectively meaning whoever holds it gets to write to bw.
-	writeFrameLock chan struct{}
-	writeHeaderBuf []byte
-	writeHeader    *header
-	// read limit for a message in bytes.
-	msgReadLimit *atomicint.Int64
+	cr connReader
+	cw connWriter
 
-	// Used to ensure a previous writer is not used after being closed.
-	activeWriter atomic.Value
-	// messageWriter state.
-	writeMsgOpcode opcode
-	writeMsgCtx    context.Context
+	closed chan struct{}
 
-	setReadTimeout  chan context.Context
-	setWriteTimeout chan context.Context
+	closeMu           sync.Mutex
+	closeErr          error
+	closeHandshakeErr error
 
-	pingCounter   *atomicint.Int64
+	pingCounter   int32
 	activePingsMu sync.Mutex
 	activePings   map[string]chan<- struct{}
-
-	logf func(format string, v ...interface{})
 }
 
-func (c *Conn) init() {
-	c.closed = make(chan struct{})
-	c.closing = &atomicint.Int64{}
-
-	c.msgReadLimit = &atomicint.Int64{}
-	c.msgReadLimit.Store(32768)
+type connConfig struct {
+	subprotocol string
+	rwc         io.ReadWriteCloser
+	client      bool
+	copts       *compressionOptions
 
-	c.writeMsgLock = make(chan struct{}, 1)
-	c.writeFrameLock = make(chan struct{}, 1)
+	bw *bufio.Writer
+	br *bufio.Reader
+}
 
-	c.readFrameLock = make(chan struct{}, 1)
-	c.readLock = make(chan struct{}, 1)
-	c.payloadReader = framePayloadReader{c}
+func newConn(cfg connConfig) *Conn {
+	c := &Conn{}
+	c.subprotocol = cfg.subprotocol
+	c.rwc = cfg.rwc
+	c.client = cfg.client
+	c.copts = cfg.copts
 
-	c.setReadTimeout = make(chan context.Context)
-	c.setWriteTimeout = make(chan context.Context)
+	c.cr.init(c, cfg.br)
+	c.cw.init(c, cfg.bw)
 
-	c.pingCounter = &atomicint.Int64{}
+	c.closed = make(chan struct{})
 	c.activePings = make(map[string]chan<- struct{})
 
-	c.writeHeaderBuf = makeWriteHeaderBuf()
-	c.writeHeader = &header{}
-	c.readHeaderBuf = makeReadHeaderBuf()
-	c.isReadClosed = &atomicint.Int64{}
-	c.controlPayloadBuf = make([]byte, maxControlFramePayload)
-
 	runtime.SetFinalizer(c, func(c *Conn) {
 		c.close(errors.New("connection garbage collected"))
 	})
 
-	c.logf = log.Printf
-
-	if c.copts != nil {
-		if !c.readNoContextTakeOver() {
-			c.fr = getFlateReader(c.payloadReader)
-		}
-		if !c.writeNoContextTakeOver() {
-			c.fw = getFlateWriter(c.bw)
-		}
-	}
-
 	go c.timeoutLoop()
+
+	return c
 }
 
 // Subprotocol returns the negotiated subprotocol.
@@ -145,38 +97,25 @@ func (c *Conn) Subprotocol() string {
 }
 
 func (c *Conn) close(err error) {
-	c.closeOnce.Do(func() {
-		runtime.SetFinalizer(c, nil)
+	c.closeMu.Lock()
+	defer c.closeMu.Unlock()
 
-		c.setCloseErr(err)
-		close(c.closed)
-
-		// Have to close after c.closed is closed to ensure any goroutine that wakes up
-		// from the connection being closed also sees that c.closed is closed and returns
-		// closeErr.
-		c.closer.Close()
+	if c.isClosed() {
+		return
+	}
+	close(c.closed)
+	runtime.SetFinalizer(c, nil)
+	c.setCloseErrNoLock(err)
 
-		// By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer
-		// and we can safely return them.
-		// Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent
-		// a deadlock.
-		// As of now, this is in writeFrame, readFramePayload and readHeader.
-		c.readFrameLock <- struct{}{}
-		if c.client {
-			returnBufioReader(c.br)
-		}
-		if c.fr != nil {
-			putFlateReader(c.fr)
-		}
+	// Have to close after c.closed is closed to ensure any goroutine that wakes up
+	// from the connection being closed also sees that c.closed is closed and returns
+	// closeErr.
+	c.rwc.Close()
 
-		c.writeFrameLock <- struct{}{}
-		if c.client {
-			returnBufioWriter(c.bw)
-		}
-		if c.fw != nil {
-			putFlateWriter(c.fw)
-		}
-	})
+	go func() {
+		c.cr.close()
+		c.cw.close()
+	}()
 }
 
 func (c *Conn) timeoutLoop() {
@@ -188,20 +127,13 @@ func (c *Conn) timeoutLoop() {
 		case <-c.closed:
 			return
 
-		case writeCtx = <-c.setWriteTimeout:
-		case readCtx = <-c.setReadTimeout:
+		case writeCtx = <-c.cw.timeout:
+		case readCtx = <-c.cr.timeout:
 
 		case <-readCtx.Done():
 			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
-			// Guaranteed to eventually close the connection since we can only ever send
-			// one close frame.
-			go func() {
-				c.exportedClose(StatusPolicyViolation, "read timed out", true)
-				// Ensure the connection closes, i.e if we already sent a close frame and timed out
-				// to read the peer's close frame.
-				c.close(nil)
-			}()
-			readCtx = context.Background()
+			c.cw.error(StatusPolicyViolation, errors.New("timed out"))
+			return
 		case <-writeCtx.Done():
 			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
 			return
@@ -209,843 +141,8 @@ func (c *Conn) timeoutLoop() {
 	}
 }
 
-func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
-	select {
-	case <-ctx.Done():
-		var err error
-		switch lock {
-		case c.writeFrameLock, c.writeMsgLock:
-			err = fmt.Errorf("could not acquire write lock: %v", ctx.Err())
-		case c.readFrameLock, c.readLock:
-			err = fmt.Errorf("could not acquire read lock: %v", ctx.Err())
-		default:
-			panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err()))
-		}
-		c.close(err)
-		return ctx.Err()
-	case <-c.closed:
-		return c.closeErr
-	case lock <- struct{}{}:
-		return nil
-	}
-}
-
-func (c *Conn) releaseLock(lock chan struct{}) {
-	// Allow multiple releases.
-	select {
-	case <-lock:
-	default:
-	}
-}
-
-func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
-	for {
-		h, err := c.readFrameHeader(ctx)
-		if err != nil {
-			return header{}, err
-		}
-
-		if (h.rsv1 && (c.copts == nil || h.opcode.controlOp() || h.opcode == opContinuation)) || h.rsv2 || h.rsv3 {
-			err := fmt.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
-			c.exportedClose(StatusProtocolError, err.Error(), false)
-			return header{}, err
-		}
-
-		if h.opcode.controlOp() {
-			err = c.handleControl(ctx, h)
-			if err != nil {
-				// Pass through CloseErrors when receiving a close frame.
-				if h.opcode == opClose && CloseStatus(err) != -1 {
-					return header{}, err
-				}
-				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
-			}
-			continue
-		}
-
-		switch h.opcode {
-		case opBinary, opText, opContinuation:
-			return h, nil
-		default:
-			err := fmt.Errorf("received unknown opcode %v", h.opcode)
-			c.exportedClose(StatusProtocolError, err.Error(), false)
-			return header{}, err
-		}
-	}
-}
-
-func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
-	wrap := func(err error) error {
-		return fmt.Errorf("failed to read frame header: %w", err)
-	}
-	defer func() {
-		if err != nil {
-			err = wrap(err)
-		}
-	}()
-
-	err = c.acquireLock(ctx, c.readFrameLock)
-	if err != nil {
-		return header{}, err
-	}
-	defer c.releaseLock(c.readFrameLock)
-
-	select {
-	case <-c.closed:
-		return header{}, c.closeErr
-	case c.setReadTimeout <- ctx:
-	}
-
-	h, err := readHeader(c.readHeaderBuf, c.br)
-	if err != nil {
-		select {
-		case <-c.closed:
-			return header{}, c.closeErr
-		case <-ctx.Done():
-			err = ctx.Err()
-		default:
-		}
-		c.releaseLock(c.readFrameLock)
-		c.close(wrap(err))
-		return header{}, err
-	}
-
-	select {
-	case <-c.closed:
-		return header{}, c.closeErr
-	case c.setReadTimeout <- context.Background():
-	}
-
-	return h, nil
-}
-
-func (c *Conn) handleControl(ctx context.Context, h header) error {
-	if h.payloadLength > maxControlFramePayload {
-		err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength)
-		c.exportedClose(StatusProtocolError, err.Error(), false)
-		return err
-	}
-
-	if !h.fin {
-		err := errors.New("received fragmented control frame")
-		c.exportedClose(StatusProtocolError, err.Error(), false)
-		return err
-	}
-
-	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
-	defer cancel()
-
-	b := c.controlPayloadBuf[:h.payloadLength]
-	_, err := c.readFramePayload(ctx, b)
-	if err != nil {
-		return err
-	}
-
-	if h.masked {
-		mask(h.maskKey, b)
-	}
-
-	switch h.opcode {
-	case opPing:
-		return c.writeControl(ctx, opPong, b)
-	case opPong:
-		c.activePingsMu.Lock()
-		pong, ok := c.activePings[string(b)]
-		c.activePingsMu.Unlock()
-		if ok {
-			close(pong)
-		}
-		return nil
-	case opClose:
-		ce, err := parseClosePayload(b)
-		if err != nil {
-			err = fmt.Errorf("received invalid close payload: %w", err)
-			c.exportedClose(StatusProtocolError, err.Error(), false)
-			c.closeReceived = err
-			return err
-		}
-
-		err = fmt.Errorf("received close: %w", ce)
-		c.closeReceived = err
-		c.writeClose(b, err, false)
-
-		if ctx.Err() != nil {
-			// The above close probably has been returned by the peer in response
-			// to our read timing out so we have to return the read timed out error instead.
-			return fmt.Errorf("read timed out: %w", ctx.Err())
-		}
-
-		return err
-	default:
-		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
-	}
-}
-
-// Reader waits until there is a WebSocket data message to read
-// from the connection.
-// It returns the type of the message and a reader to read it.
-// The passed context will also bound the reader.
-// Ensure you read to EOF otherwise the connection will hang.
-//
-// All returned errors will cause the connection
-// to be closed so you do not need to write your own error message.
-// This applies to the Read methods in the wsjson/wspb subpackages as well.
-//
-// You must read from the connection for control frames to be handled.
-// Thus if you expect messages to take a long time to be responded to,
-// you should handle such messages async to reading from the connection
-// to ensure control frames are promptly handled.
-//
-// If you do not expect any data messages from the peer, call CloseRead.
-//
-// Only one Reader may be open at a time.
-//
-// If you need a separate timeout on the Reader call and then the message
-// Read, use time.AfterFunc to cancel the context passed in early.
-// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
-// Most users should not need this.
-func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
-	if c.isReadClosed.Load() == 1 {
-		return 0, nil, errors.New("websocket connection read closed")
-	}
-
-	typ, r, err := c.reader(ctx, true)
-	if err != nil {
-		return 0, nil, fmt.Errorf("failed to get reader: %w", err)
-	}
-	return typ, r, nil
-}
-
-func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, error) {
-	if lock {
-		err := c.acquireLock(ctx, c.readLock)
-		if err != nil {
-			return 0, nil, err
-		}
-		defer c.releaseLock(c.readLock)
-	}
-
-	if c.activeReader != nil && !c.readerFrameEOF {
-		// The only way we know for sure the previous reader is not yet complete is
-		// if there is an active frame not yet fully read.
-		// Otherwise, a user may have read the last byte but not the EOF if the EOF
-		// is in the next frame so we check for that below.
-		return 0, nil, errors.New("previous message not read to completion")
-	}
-
-	h, err := c.readTillMsg(ctx)
-	if err != nil {
-		return 0, nil, err
-	}
-
-	if c.activeReader != nil && !c.activeReader.eof() {
-		if h.opcode != opContinuation {
-			err := errors.New("received new data message without finishing the previous message")
-			c.exportedClose(StatusProtocolError, err.Error(), false)
-			return 0, nil, err
-		}
-
-		if !h.fin || h.payloadLength > 0 {
-			return 0, nil, fmt.Errorf("previous message not read to completion")
-		}
-
-		c.activeReader = nil
-
-		h, err = c.readTillMsg(ctx)
-		if err != nil {
-			return 0, nil, err
-		}
-	} else if h.opcode == opContinuation {
-		err := errors.New("received continuation frame not after data or text frame")
-		c.exportedClose(StatusProtocolError, err.Error(), false)
-		return 0, nil, err
-	}
-
-	c.readerMsgCtx = ctx
-	c.readerMsgHeader = h
-
-	c.readerPayloadCompressed = h.rsv1
-
-	if c.readerPayloadCompressed {
-		c.readerCompressTail.Reset(deflateMessageTail)
-	}
-
-	c.readerFrameEOF = false
-	c.readerMaskKey = h.maskKey
-	c.readMsgLeft = c.msgReadLimit.Load()
-
-	r := &messageReader{
-		c: c,
-	}
-	c.activeReader = r
-	if c.readerPayloadCompressed && c.readNoContextTakeOver() {
-		c.fr = getFlateReader(c.payloadReader)
-	}
-	return MessageType(h.opcode), r, nil
-}
-
-type framePayloadReader struct {
-	c *Conn
-}
-
-func (r framePayloadReader) Read(p []byte) (int, error) {
-	if r.c.readerFrameEOF {
-		if r.c.readerPayloadCompressed && r.c.readerMsgHeader.fin {
-			n, _ := r.c.readerCompressTail.Read(p)
-			return n, nil
-		}
-
-		h, err := r.c.readTillMsg(r.c.readerMsgCtx)
-		if err != nil {
-			return 0, err
-		}
-
-		if h.opcode != opContinuation {
-			err := errors.New("received new data message without finishing the previous message")
-			r.c.exportedClose(StatusProtocolError, err.Error(), false)
-			return 0, err
-		}
-
-		r.c.readerMsgHeader = h
-		r.c.readerFrameEOF = false
-		r.c.readerMaskKey = h.maskKey
-	}
-
-	h := r.c.readerMsgHeader
-	if int64(len(p)) > h.payloadLength {
-		p = p[:h.payloadLength]
-	}
-
-	n, err := r.c.readFramePayload(r.c.readerMsgCtx, p)
-
-	h.payloadLength -= int64(n)
-	if h.masked {
-		r.c.readerMaskKey = mask(r.c.readerMaskKey, p)
-	}
-	r.c.readerMsgHeader = h
-
-	if err != nil {
-		return n, err
-	}
-
-	if h.payloadLength == 0 {
-		r.c.readerFrameEOF = true
-
-		if h.fin && !r.c.readerPayloadCompressed {
-			return n, io.EOF
-		}
-	}
-
-	return n, nil
-}
-
-// messageReader enables reading a data frame from the WebSocket connection.
-type messageReader struct {
-	c *Conn
-}
-
-func (r *messageReader) eof() bool {
-	return r.c.activeReader != r
-}
-
-// Read reads as many bytes as possible into p.
-func (r *messageReader) Read(p []byte) (int, error) {
-	return r.exportedRead(p, true)
-}
-
-func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
-	n, err := r.read(p, lock)
-	if err != nil {
-		// Have to return io.EOF directly for now, we cannot wrap as errors.Is
-		// isn't used widely yet.
-		if errors.Is(err, io.EOF) {
-			return n, io.EOF
-		}
-		return n, fmt.Errorf("failed to read: %w", err)
-	}
-	return n, nil
-}
-
-func (r *messageReader) readUnlocked(p []byte) (int, error) {
-	return r.exportedRead(p, false)
-}
-
-func (r *messageReader) read(p []byte, lock bool) (int, error) {
-	if lock {
-		// If we cannot acquire the read lock, then
-		// there is either a concurrent read or the close handshake
-		// is proceeding.
-		select {
-		case r.c.readLock <- struct{}{}:
-			defer r.c.releaseLock(r.c.readLock)
-		default:
-			if r.c.closing.Load() == 1 {
-				<-r.c.closed
-				return 0, r.c.closeErr
-			}
-			return 0, errors.New("concurrent read detected")
-		}
-	}
-
-	if r.eof() {
-		return 0, errors.New("cannot use EOFed reader")
-	}
-
-	if r.c.readMsgLeft <= 0 {
-		err := fmt.Errorf("read limited at %v bytes", r.c.msgReadLimit)
-		r.c.exportedClose(StatusMessageTooBig, err.Error(), false)
-		return 0, err
-	}
-
-	if int64(len(p)) > r.c.readMsgLeft {
-		p = p[:r.c.readMsgLeft]
-	}
-
-	pr := io.Reader(r.c.payloadReader)
-	if r.c.readerPayloadCompressed {
-		pr = r.c.fr
-	}
-
-	n, err := pr.Read(p)
-
-	r.c.readMsgLeft -= int64(n)
-
-	if r.c.readerFrameEOF && r.c.readerMsgHeader.fin {
-		if r.c.readerPayloadCompressed && r.c.readNoContextTakeOver() {
-			putFlateReader(r.c.fr)
-			r.c.fr = nil
-		}
-		r.c.activeReader = nil
-		if err == nil {
-			err = io.EOF
-		}
-	}
-
-	return n, err
-}
-
-func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
-	wrap := func(err error) error {
-		return fmt.Errorf("failed to read frame payload: %w", err)
-	}
-	defer func() {
-		if err != nil {
-			err = wrap(err)
-		}
-	}()
-
-	err = c.acquireLock(ctx, c.readFrameLock)
-	if err != nil {
-		return 0, err
-	}
-	defer c.releaseLock(c.readFrameLock)
-
-	select {
-	case <-c.closed:
-		return 0, c.closeErr
-	case c.setReadTimeout <- ctx:
-	}
-
-	n, err := io.ReadFull(c.br, p)
-	if err != nil {
-		select {
-		case <-c.closed:
-			return n, c.closeErr
-		case <-ctx.Done():
-			err = ctx.Err()
-		default:
-		}
-		c.releaseLock(c.readFrameLock)
-		c.close(wrap(err))
-		return n, err
-	}
-
-	select {
-	case <-c.closed:
-		return n, c.closeErr
-	case c.setReadTimeout <- context.Background():
-	}
-
-	return n, err
-}
-
-// Read is a convenience method to read a single message from the connection.
-//
-// See the Reader method if you want to be able to reuse buffers or want to stream a message.
-// The docs on Reader apply to this method as well.
-func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
-	typ, r, err := c.Reader(ctx)
-	if err != nil {
-		return 0, nil, err
-	}
-
-	b, err := ioutil.ReadAll(r)
-	return typ, b, err
-}
-
-// Writer returns a writer bounded by the context that will write
-// a WebSocket message of type dataType to the connection.
-//
-// You must close the writer once you have written the entire message.
-//
-// Only one writer can be open at a time, multiple calls will block until the previous writer
-// is closed.
-func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
-	wc, err := c.writer(ctx, typ)
-	if err != nil {
-		return nil, fmt.Errorf("failed to get writer: %w", err)
-	}
-	return wc, nil
-}
-
-func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
-	err := c.acquireLock(ctx, c.writeMsgLock)
-	if err != nil {
-		return nil, err
-	}
-	c.writeMsgCtx = ctx
-	c.writeMsgOpcode = opcode(typ)
-	w := &messageWriter{
-		c: c,
-	}
-	c.activeWriter.Store(w)
-	return w, nil
-}
-
-// Write is a convenience method to write a message to the connection.
-//
-// See the Writer method if you want to stream a message.
-func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
-	_, err := c.write(ctx, typ, p)
-	if err != nil {
-		return fmt.Errorf("failed to write msg: %w", err)
-	}
-	return nil
-}
-
-func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
-	err := c.acquireLock(ctx, c.writeMsgLock)
-	if err != nil {
-		return 0, err
-	}
-	defer c.releaseLock(c.writeMsgLock)
-
-	n, err := c.writeFrame(ctx, true, opcode(typ), p)
-	return n, err
-}
-
-// messageWriter enables writing to a WebSocket connection.
-type messageWriter struct {
-	c *Conn
-}
-
-func (w *messageWriter) closed() bool {
-	return w != w.c.activeWriter.Load()
-}
-
-// Write writes the given bytes to the WebSocket connection.
-func (w *messageWriter) Write(p []byte) (int, error) {
-	n, err := w.write(p)
-	if err != nil {
-		return n, fmt.Errorf("failed to write: %w", err)
-	}
-	return n, nil
-}
-
-func (w *messageWriter) write(p []byte) (int, error) {
-	if w.closed() {
-		return 0, fmt.Errorf("cannot use closed writer")
-	}
-	n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p)
-	if err != nil {
-		return n, fmt.Errorf("failed to write data frame: %w", err)
-	}
-	w.c.writeMsgOpcode = opContinuation
-	return n, nil
-}
-
-// Close flushes the frame to the connection.
-// This must be called for every messageWriter.
-func (w *messageWriter) Close() error {
-	err := w.close()
-	if err != nil {
-		return fmt.Errorf("failed to close writer: %w", err)
-	}
-	return nil
-}
-
-func (w *messageWriter) close() error {
-	if w.closed() {
-		return fmt.Errorf("cannot use closed writer")
-	}
-	w.c.activeWriter.Store((*messageWriter)(nil))
-
-	_, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil)
-	if err != nil {
-		return fmt.Errorf("failed to write fin frame: %w", err)
-	}
-
-	w.c.releaseLock(w.c.writeMsgLock)
-	return nil
-}
-
-func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
-	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
-	defer cancel()
-
-	_, err := c.writeFrame(ctx, true, opcode, p)
-	if err != nil {
-		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
-	}
-	return nil
-}
-
-// writeFrame handles all writes to the connection.
-func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
-	err := c.acquireLock(ctx, c.writeFrameLock)
-	if err != nil {
-		return 0, err
-	}
-	defer c.releaseLock(c.writeFrameLock)
-
-	select {
-	case <-c.closed:
-		return 0, c.closeErr
-	case c.setWriteTimeout <- ctx:
-	}
-
-	c.writeHeader.fin = fin
-	c.writeHeader.opcode = opcode
-	c.writeHeader.masked = c.client
-	c.writeHeader.payloadLength = int64(len(p))
-
-	if c.client {
-		err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey)
-		if err != nil {
-			return 0, fmt.Errorf("failed to generate masking key: %w", err)
-		}
-	}
-
-	n, err := c.realWriteFrame(ctx, *c.writeHeader, p)
-	if err != nil {
-		return n, err
-	}
-
-	// We already finished writing, no need to potentially brick the connection if
-	// the context expires.
-	select {
-	case <-c.closed:
-		return n, c.closeErr
-	case c.setWriteTimeout <- context.Background():
-	}
-
-	return n, nil
-}
-
-func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) {
-	defer func() {
-		if err != nil {
-			select {
-			case <-c.closed:
-				err = c.closeErr
-			case <-ctx.Done():
-				err = ctx.Err()
-			default:
-			}
-
-			err = fmt.Errorf("failed to write %v frame: %w", h.opcode, err)
-			// We need to release the lock first before closing the connection to ensure
-			// the lock can be acquired inside close to ensure no one can access c.bw.
-			c.releaseLock(c.writeFrameLock)
-			c.close(err)
-		}
-	}()
-
-	headerBytes := writeHeader(c.writeHeaderBuf, h)
-	_, err = c.bw.Write(headerBytes)
-	if err != nil {
-		return 0, err
-	}
-
-	if c.client {
-		maskKey := h.maskKey
-		for len(p) > 0 {
-			if c.bw.Available() == 0 {
-				err = c.bw.Flush()
-				if err != nil {
-					return n, err
-				}
-			}
-
-			// Start of next write in the buffer.
-			i := c.bw.Buffered()
-
-			p2 := p
-			if len(p) > c.bw.Available() {
-				p2 = p[:c.bw.Available()]
-			}
-
-			n2, err := c.bw.Write(p2)
-			if err != nil {
-				return n, err
-			}
-
-			maskKey = mask(maskKey, c.writeBuf[i:i+n2])
-
-			p = p[n2:]
-			n += n2
-		}
-	} else {
-		n, err = c.bw.Write(p)
-		if err != nil {
-			return n, err
-		}
-	}
-
-	if h.fin {
-		err = c.bw.Flush()
-		if err != nil {
-			return n, err
-		}
-	}
-
-	return n, nil
-}
-
-// Close closes the WebSocket connection with the given status code and reason.
-//
-// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
-// the peer to send a close frame.
-// Thus, it implements the full WebSocket close handshake.
-// All data messages received from the peer during the close handshake
-// will be discarded.
-//
-// The connection can only be closed once. Additional calls to Close
-// are no-ops.
-//
-// The maximum length of reason must be 125 bytes otherwise an internal
-// error will be sent to the peer. For this reason, you should avoid
-// sending a dynamic reason.
-//
-// Close will unblock all goroutines interacting with the connection once
-// complete.
-func (c *Conn) Close(code StatusCode, reason string) error {
-	err := c.exportedClose(code, reason, true)
-	var ec errClosing
-	if errors.As(err, &ec) {
-		<-c.closed
-		// We wait until the connection closes.
-		// We use writeClose and not exportedClose to avoid a second failed to marshal close frame error.
-		err = c.writeClose(nil, ec.ce, true)
-	}
-	if err != nil {
-		return fmt.Errorf("failed to close websocket connection: %w", err)
-	}
-	return nil
-}
-
-func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) error {
-	ce := CloseError{
-		Code:   code,
-		Reason: reason,
-	}
-
-	// This function also will not wait for a close frame from the peer like the RFC
-	// wants because that makes no sense and I don't think anyone actually follows that.
-	// Definitely worth seeing what popular browsers do later.
-	p, err := ce.bytes()
-	if err != nil {
-		c.logf("websocket: failed to marshal close frame: %+v", err)
-		ce = CloseError{
-			Code: StatusInternalError,
-		}
-		p, _ = ce.bytes()
-	}
-
-	return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake)
-}
-
-type errClosing struct {
-	ce error
-}
-
-func (e errClosing) Error() string {
-	return "already closing connection"
-}
-
-func (c *Conn) writeClose(p []byte, ce error, handshake bool) error {
-	if c.isClosed() {
-		return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
-	}
-
-	if !c.closing.CAS(0, 1) {
-		// Normally, we would want to wait until the connection is closed,
-		// at least for when a user calls into Close, so we handle that case in
-		// the exported Close function.
-		//
-		// But for internal library usage, we always want to return early, e.g.
-		// if we are performing a close handshake and the peer sends their close frame,
-		// we do not want to block here waiting for c.closed to close because it won't,
-		// at least not until we return since the gorouine that will close it is this one.
-		return errClosing{
-			ce: ce,
-		}
-	}
-
-	// No matter what happens next, close error should be set.
-	c.setCloseErr(ce)
-	defer c.close(nil)
-
-	err := c.writeControl(context.Background(), opClose, p)
-	if err != nil {
-		return err
-	}
-
-	if handshake {
-		err = c.waitClose()
-		if CloseStatus(err) == -1 {
-			// waitClose exited not due to receiving a close frame.
-			return fmt.Errorf("failed to wait for peer close frame: %w", err)
-		}
-	}
-
-	return nil
-}
-
-func (c *Conn) waitClose() error {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	defer cancel()
-
-	err := c.acquireLock(ctx, c.readLock)
-	if err != nil {
-		return err
-	}
-	defer c.releaseLock(c.readLock)
-
-	if c.closeReceived != nil {
-		// goroutine reading just received the close.
-		return c.closeReceived
-	}
-
-	b := bufpool.Get()
-	buf := b.Bytes()
-	buf = buf[:cap(buf)]
-	defer bufpool.Put(b)
-
-	for {
-		if c.activeReader == nil || c.readerFrameEOF {
-			_, _, err := c.reader(ctx, false)
-			if err != nil {
-				return fmt.Errorf("failed to get reader: %w", err)
-			}
-		}
-
-		r := readerFunc(c.activeReader.readUnlocked)
-		_, err = io.CopyBuffer(ioutil.Discard, r, buf)
-		if err != nil {
-			return err
-		}
-	}
+func (c *Conn) deflateNegotiated() bool {
+	return c.copts != nil
 }
 
 // Ping sends a ping to the peer and waits for a pong.
@@ -1056,9 +153,9 @@ func (c *Conn) waitClose() error {
 //
 // TCP Keepalives should suffice for most use cases.
 func (c *Conn) Ping(ctx context.Context) error {
-	p := c.pingCounter.Increment(1)
+	p := atomic.AddInt32(&c.pingCounter, 1)
 
-	err := c.ping(ctx, strconv.FormatInt(p, 10))
+	err := c.ping(ctx, strconv.Itoa(int(p)))
 	if err != nil {
 		return fmt.Errorf("failed to ping: %w", err)
 	}
@@ -1078,7 +175,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
 		c.activePingsMu.Unlock()
 	}()
 
-	err := c.writeControl(ctx, opPing, []byte(p))
+	err := c.cw.control(ctx, opPing, []byte(p))
 	if err != nil {
 		return err
 	}
@@ -1095,109 +192,37 @@ func (c *Conn) ping(ctx context.Context, p string) error {
 	}
 }
 
-type readerFunc func(p []byte) (int, error)
-
-func (f readerFunc) Read(p []byte) (int, error) {
-	return f(p)
-}
-
-type writerFunc func(p []byte) (int, error)
-
-func (f writerFunc) Write(p []byte) (int, error) {
-	return f(p)
-}
-
-// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
-// and stores it in c.writeBuf.
-func (c *Conn) extractBufioWriterBuf(w io.Writer) {
-	c.bw.Reset(writerFunc(func(p2 []byte) (int, error) {
-		c.writeBuf = p2[:cap(p2)]
-		return len(p2), nil
-	}))
-
-	c.bw.WriteByte(0)
-	c.bw.Flush()
-
-	c.bw.Reset(w)
-}
-
-var flateWriterPool = &sync.Pool{
-	New: func() interface{} {
-		w, _ := flate.NewWriter(nil, flate.BestSpeed)
-		return w
-	},
-}
-
-func getFlateWriter(w io.Writer) *flate.Writer {
-	fw := flateWriterPool.Get().(*flate.Writer)
-	fw.Reset(w)
-	return fw
-}
-
-func putFlateWriter(w *flate.Writer) {
-	flateWriterPool.Put(w)
+type mu struct {
+	once sync.Once
+	ch chan struct{}
 }
 
-var flateReaderPool = &sync.Pool{
-	New: func() interface{} {
-		return flate.NewReader(nil)
-	},
-}
-
-func getFlateReader(r io.Reader) io.Reader {
-	fr := flateReaderPool.Get().(io.Reader)
-	fr.(flate.Resetter).Reset(r, nil)
-	return fr
-}
-
-func putFlateReader(fr io.Reader) {
-	flateReaderPool.Put(fr)
-}
-
-func (c *Conn) writeNoContextTakeOver() bool {
-	return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover
-}
-
-func (c *Conn) readNoContextTakeOver() bool {
-	return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover
-}
-
-type trimLastFourBytesWriter struct {
-	w    io.Writer
-	tail []byte
+func (m *mu) init() {
+	m.once.Do(func() {
+		m.ch = make(chan struct{}, 1)
+	})
 }
 
-func (w *trimLastFourBytesWriter) Write(p []byte) (int, error) {
-	extra := len(w.tail) + len(p) - 4
-
-	if extra <= 0 {
-		w.tail = append(w.tail, p...)
-		return len(p), nil
-	}
-
-	// Now we need to write as many extra bytes as we can from the previous tail.
-	if extra > len(w.tail) {
-		extra = len(w.tail)
-	}
-	if extra > 0 {
-		_, err := w.Write(w.tail[:extra])
-		if err != nil {
-			return 0, err
-		}
-		w.tail = w.tail[extra:]
+func (m *mu) Lock(ctx context.Context) error {
+	m.init()
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case m.ch <- struct{}{}:
+		return nil
 	}
+}
 
-	// If p is less than or equal to 4 bytes,
-	// all of it is is part of the tail.
-	if len(p) <= 4 {
-		w.tail = append(w.tail, p...)
-		return len(p), nil
+func (m *mu) TryLock() bool {
+	m.init()
+	select {
+	case m.ch <- struct{}{}:
+		return true
+	default:
+		return false
 	}
+}
 
-	// Otherwise, only the last 4 bytes are.
-	w.tail = append(w.tail, p[len(p)-4:]...)
-
-	p = p[:len(p)-4]
-	n, err := w.w.Write(p)
-	return n + 4, err
+func (m *mu) Unlock() {
+	<-m.ch
 }
diff --git a/conn_export_test.go b/conn_export_test.go
deleted file mode 100644
index d5f5aa2..0000000
--- a/conn_export_test.go
+++ /dev/null
@@ -1,129 +0,0 @@
-// +build !js
-
-package websocket
-
-import (
-	"bufio"
-	"context"
-	"fmt"
-)
-
-type (
-	Addr   = websocketAddr
-	OpCode int
-)
-
-const (
-	OpClose        = OpCode(opClose)
-	OpBinary       = OpCode(opBinary)
-	OpText         = OpCode(opText)
-	OpPing         = OpCode(opPing)
-	OpPong         = OpCode(opPong)
-	OpContinuation = OpCode(opContinuation)
-)
-
-func (c *Conn) SetLogf(fn func(format string, v ...interface{})) {
-	c.logf = fn
-}
-
-func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) {
-	h, err := c.readFrameHeader(ctx)
-	if err != nil {
-		return 0, nil, err
-	}
-	b := make([]byte, h.payloadLength)
-	_, err = c.readFramePayload(ctx, b)
-	if err != nil {
-		return 0, nil, err
-	}
-	if h.masked {
-		mask(h.maskKey, b)
-	}
-	return OpCode(h.opcode), b, nil
-}
-
-func (c *Conn) WriteFrame(ctx context.Context, fin bool, opc OpCode, p []byte) (int, error) {
-	return c.writeFrame(ctx, fin, opcode(opc), p)
-}
-
-// header represents a WebSocket frame header.
-// See https://tools.ietf.org/html/rfc6455#section-5.2
-type Header struct {
-	Fin    bool
-	Rsv1   bool
-	Rsv2   bool
-	Rsv3   bool
-	OpCode OpCode
-
-	PayloadLength int64
-}
-
-func (c *Conn) WriteHeader(ctx context.Context, h Header) error {
-	headerBytes := writeHeader(c.writeHeaderBuf, header{
-		fin:           h.Fin,
-		rsv1:          h.Rsv1,
-		rsv2:          h.Rsv2,
-		rsv3:          h.Rsv3,
-		opcode:        opcode(h.OpCode),
-		payloadLength: h.PayloadLength,
-		masked:        c.client,
-	})
-	_, err := c.bw.Write(headerBytes)
-	if err != nil {
-		return fmt.Errorf("failed to write header: %w", err)
-	}
-	if h.Fin {
-		err = c.Flush()
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func (c *Conn) PingWithPayload(ctx context.Context, p string) error {
-	return c.ping(ctx, p)
-}
-
-func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) {
-	return c.realWriteFrame(ctx, header{
-		fin:           true,
-		opcode:        opBinary,
-		payloadLength: 10,
-	}, make([]byte, 5))
-}
-
-func (c *Conn) CloseUnderlyingConn() {
-	c.closer.Close()
-}
-
-func (c *Conn) Flush() error {
-	return c.bw.Flush()
-}
-
-func (c CloseError) Bytes() ([]byte, error) {
-	return c.bytes()
-}
-
-func (c *Conn) BW() *bufio.Writer {
-	return c.bw
-}
-
-func (c *Conn) WriteClose(ctx context.Context, code StatusCode, reason string) ([]byte, error) {
-	b, err := CloseError{
-		Code:   code,
-		Reason: reason,
-	}.Bytes()
-	if err != nil {
-		return nil, err
-	}
-	_, err = c.WriteFrame(ctx, true, OpClose, b)
-	if err != nil {
-		return nil, err
-	}
-	return b, nil
-}
-
-func ParseClosePayload(p []byte) (CloseError, error) {
-	return parseClosePayload(p)
-}
diff --git a/conn_test.go b/conn_test.go
index d03a721..992c886 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -3,969 +3,28 @@
 package websocket_test
 
 import (
-	"bytes"
 	"context"
-	"encoding/binary"
-	"encoding/json"
-	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
-	"math/rand"
-	"net"
 	"net/http"
-	"net/http/cookiejar"
 	"net/http/httptest"
-	"net/url"
-	"os"
-	"os/exec"
-	"reflect"
-	"strconv"
+	"nhooyr.io/websocket/internal/assert"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
-	"github.com/golang/protobuf/proto"
-	"github.com/golang/protobuf/ptypes"
-	"github.com/golang/protobuf/ptypes/timestamp"
-	"go.uber.org/multierr"
-
 	"nhooyr.io/websocket"
-	"nhooyr.io/websocket/internal/assert"
-	"nhooyr.io/websocket/internal/wsecho"
-	"nhooyr.io/websocket/internal/wsgrace"
-	"nhooyr.io/websocket/wsjson"
-	"nhooyr.io/websocket/wspb"
 )
 
-func init() {
-	rand.Seed(time.Now().UnixNano())
-}
-
-func TestHandshake(t *testing.T) {
-	t.Parallel()
-
-	testCases := []struct {
-		name   string
-		client func(ctx context.Context, url string) error
-		server func(w http.ResponseWriter, r *http.Request) error
-	}{
-		{
-			name: "badOrigin",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, nil)
-				if err == nil {
-					c.Close(websocket.StatusInternalError, "")
-					return errors.New("expected error regarding bad origin")
-				}
-				return assertErrorContains(err, "not authorized")
-			},
-			client: func(ctx context.Context, u string) error {
-				h := http.Header{}
-				h.Set("Origin", "http://unauthorized.com")
-				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
-					HTTPHeader: h,
-				})
-				if err == nil {
-					c.Close(websocket.StatusInternalError, "")
-					return errors.New("expected handshake failure")
-				}
-				return assertErrorContains(err, "403")
-			},
-		},
-		{
-			name: "acceptSecureOrigin",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, nil)
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return nil
-			},
-			client: func(ctx context.Context, u string) error {
-				h := http.Header{}
-				h.Set("Origin", u)
-				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
-					HTTPHeader: h,
-				})
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return nil
-			},
-		},
-		{
-			name: "acceptInsecureOrigin",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-					InsecureSkipVerify: true,
-				})
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return nil
-			},
-			client: func(ctx context.Context, u string) error {
-				h := http.Header{}
-				h.Set("Origin", "https://example.com")
-				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
-					HTTPHeader: h,
-				})
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return nil
-			},
-		},
-		{
-			name: "cookies",
-			server: func(w http.ResponseWriter, r *http.Request) error {
-				cookie, err := r.Cookie("mycookie")
-				if err != nil {
-					return fmt.Errorf("request is missing mycookie: %w", err)
-				}
-				err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value")
-				if err != nil {
-					return err
-				}
-				c, err := websocket.Accept(w, r, nil)
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return nil
-			},
-			client: func(ctx context.Context, u string) error {
-				jar, err := cookiejar.New(nil)
-				if err != nil {
-					return fmt.Errorf("failed to create cookie jar: %w", err)
-				}
-				parsedURL, err := url.Parse(u)
-				if err != nil {
-					return fmt.Errorf("failed to parse url: %w", err)
-				}
-				parsedURL.Scheme = "http"
-				jar.SetCookies(parsedURL, []*http.Cookie{
-					{
-						Name:  "mycookie",
-						Value: "myvalue",
-					},
-				})
-				hc := &http.Client{
-					Jar: jar,
-				}
-				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
-					HTTPClient: hc,
-				})
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return nil
-			},
-		},
-	}
-
-	for _, tc := range testCases {
-		tc := tc
-		t.Run(tc.name, func(t *testing.T) {
-			t.Parallel()
-
-			s, closeFn := testServer(t, tc.server, false)
-			defer closeFn()
-
-			wsURL := strings.Replace(s.URL, "http", "ws", 1)
-
-			ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-			defer cancel()
-
-			err := tc.client(ctx, wsURL)
-			if err != nil {
-				t.Fatalf("client failed: %+v", err)
-			}
-		})
-	}
-}
-
-func TestConn(t *testing.T) {
-	t.Parallel()
-
-	testCases := []struct {
-		name string
-
-		acceptOpts *websocket.AcceptOptions
-		server     func(ctx context.Context, c *websocket.Conn) error
-
-		dialOpts *websocket.DialOptions
-		response func(resp *http.Response) error
-		client   func(ctx context.Context, c *websocket.Conn) error
-	}{
-		{
-			name: "handshake",
-			acceptOpts: &websocket.AcceptOptions{
-				Subprotocols: []string{"myproto"},
-			},
-			dialOpts: &websocket.DialOptions{
-				Subprotocols: []string{"myproto"},
-			},
-			response: func(resp *http.Response) error {
-				headers := map[string]string{
-					"Connection":             "Upgrade",
-					"Upgrade":                "websocket",
-					"Sec-WebSocket-Protocol": "myproto",
-				}
-				for h, exp := range headers {
-					value := resp.Header.Get(h)
-					err := assert.Equalf(exp, value, "unexpected value for header %v", h)
-					if err != nil {
-						return err
-					}
-				}
-				return nil
-			},
-		},
-		{
-			name: "handshake/defaultSubprotocol",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				return assertSubprotocol(c, "")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return assertSubprotocol(c, "")
-			},
-		},
-		{
-			name: "handshake/subprotocolPriority",
-			acceptOpts: &websocket.AcceptOptions{
-				Subprotocols: []string{"echo", "lar"},
-			},
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				return assertSubprotocol(c, "echo")
-			},
-			dialOpts: &websocket.DialOptions{
-				Subprotocols: []string{"poof", "echo"},
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return assertSubprotocol(c, "echo")
-			},
-		},
-		{
-			name: "closeError",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				return wsjson.Write(ctx, c, "hello")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				err := assertJSONRead(ctx, c, "hello")
-				if err != nil {
-					return err
-				}
-
-				_, _, err = c.Reader(ctx)
-				return assertCloseStatus(err, websocket.StatusInternalError)
-			},
-		},
-		{
-			name: "netConn",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
-				defer nc.Close()
-
-				nc.SetWriteDeadline(time.Time{})
-				time.Sleep(1)
-				nc.SetWriteDeadline(time.Now().Add(time.Second * 15))
-
-				err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr")
-				if err != nil {
-					return err
-				}
-				err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr")
-				if err != nil {
-					return err
-				}
-
-				for i := 0; i < 3; i++ {
-					_, err := nc.Write([]byte("hello"))
-					if err != nil {
-						return err
-					}
-				}
-
-				return nil
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
-
-				nc.SetReadDeadline(time.Time{})
-				time.Sleep(1)
-				nc.SetReadDeadline(time.Now().Add(time.Second * 15))
-
-				for i := 0; i < 3; i++ {
-					err := assertNetConnRead(nc, "hello")
-					if err != nil {
-						return err
-					}
-				}
-
-				// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
-				err2 := assertNetConnRead(nc, "hello")
-				err := assert.Equalf(io.EOF, err2, "unexpected error")
-				if err != nil {
-					return err
-				}
-
-				err2 = assertNetConnRead(nc, "hello")
-				return assert.Equalf(io.EOF, err2, "unexpected error")
-			},
-		},
-		{
-			name: "netConn/badReadMsgType",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
-
-				nc.SetDeadline(time.Now().Add(time.Second * 15))
-
-				_, err := nc.Read(make([]byte, 1))
-				return assertErrorContains(err, "unexpected frame type")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				err := wsjson.Write(ctx, c, "meow")
-				if err != nil {
-					return err
-				}
-
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusUnsupportedData)
-			},
-		},
-		{
-			name: "netConn/badRead",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
-				defer nc.Close()
-
-				nc.SetDeadline(time.Now().Add(time.Second * 15))
-
-				_, err2 := nc.Read(make([]byte, 1))
-				err := assertCloseStatus(err2, websocket.StatusBadGateway)
-				if err != nil {
-					return err
-				}
-
-				_, err2 = nc.Write([]byte{0xff})
-				return assertErrorContains(err2, "websocket closed")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return c.Close(websocket.StatusBadGateway, "")
-			},
-		},
-		{
-			name: "wsjson/echo",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				return wsjson.Write(ctx, c, "meow")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return assertJSONRead(ctx, c, "meow")
-			},
-		},
-		{
-			name: "protobuf/echo",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				return wspb.Write(ctx, c, ptypes.DurationProto(100))
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return assertProtobufRead(ctx, c, ptypes.DurationProto(100))
-			},
-		},
-		{
-			name: "ping",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				ctx = c.CloseRead(ctx)
-
-				err := c.Ping(ctx)
-				if err != nil {
-					return err
-				}
-
-				err = wsjson.Write(ctx, c, "hi")
-				if err != nil {
-					return err
-				}
-
-				<-ctx.Done()
-				err = c.Ping(context.Background())
-				return assertCloseStatus(err, websocket.StatusNormalClosure)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				// We read a message from the connection and then keep reading until
-				// the Ping completes.
-				pingErrc := make(chan error, 1)
-				go func() {
-					pingErrc <- c.Ping(ctx)
-				}()
-
-				// Once this completes successfully, that means they sent their ping and we responded to it.
-				err := assertJSONRead(ctx, c, "hi")
-				if err != nil {
-					return err
-				}
-
-				// Now we need to ensure we're reading for their pong from our ping.
-				// Need new var to not race with above goroutine.
-				ctx2 := c.CloseRead(ctx)
-
-				// Now we wait for our pong.
-				select {
-				case err = <-pingErrc:
-					return err
-				case <-ctx2.Done():
-					return fmt.Errorf("failed to wait for pong: %w", ctx2.Err())
-				}
-			},
-		},
-		{
-			name: "readLimit",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err2 := c.Read(ctx)
-				return assertErrorContains(err2, "read limited at 32768 bytes")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769)))
-				if err != nil {
-					return err
-				}
-
-				_, _, err2 := c.Read(ctx)
-				return assertCloseStatus(err2, websocket.StatusMessageTooBig)
-			},
-		},
-		{
-			name: "wsjson/binary",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				var v interface{}
-				err2 := wsjson.Read(ctx, c, &v)
-				return assertErrorContains(err2, "unexpected frame type")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return wspb.Write(ctx, c, ptypes.DurationProto(100))
-			},
-		},
-		{
-			name: "wsjson/badRead",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				var v interface{}
-				err2 := wsjson.Read(ctx, c, &v)
-				return assertErrorContains(err2, "failed to unmarshal json")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return c.Write(ctx, websocket.MessageText, []byte("notjson"))
-			},
-		},
-		{
-			name: "wsjson/badWrite",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err2 := c.Read(ctx)
-				return assertCloseStatus(err2, websocket.StatusNormalClosure)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				err := wsjson.Write(ctx, c, fmt.Println)
-				return assertErrorContains(err, "failed to encode json")
-			},
-		},
-		{
-			name: "wspb/text",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				var v proto.Message
-				err := wspb.Read(ctx, c, v)
-				return assertErrorContains(err, "unexpected frame type")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return wsjson.Write(ctx, c, "hi")
-			},
-		},
-		{
-			name: "wspb/badRead",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				var v timestamp.Timestamp
-				err := wspb.Read(ctx, c, &v)
-				return assertErrorContains(err, "failed to unmarshal protobuf")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return c.Write(ctx, websocket.MessageBinary, []byte("notpb"))
-			},
-		},
-		{
-			name: "wspb/badWrite",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusNormalClosure)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				err := wspb.Write(ctx, c, nil)
-				return assertErrorIs(proto.ErrNil, err)
-			},
-		},
-		{
-			name: "badClose",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				return c.Close(9999, "")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusInternalError)
-			},
-		},
-		{
-			name: "pingTimeout",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				ctx, cancel := context.WithTimeout(ctx, time.Second)
-				defer cancel()
-				err := c.Ping(ctx)
-				return assertErrorIs(context.DeadlineExceeded, err)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				err1 := assertErrorContains(err, "connection reset")
-				err2 := assertErrorIs(io.EOF, err)
-				if err1 != nil || err2 != nil {
-					return nil
-				}
-				return multierr.Combine(err1, err2)
-			},
-		},
-		{
-			name: "writeTimeout",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				c.Writer(ctx, websocket.MessageBinary)
-
-				ctx, cancel := context.WithTimeout(ctx, time.Second)
-				defer cancel()
-				err := c.Write(ctx, websocket.MessageBinary, []byte("meow"))
-				return assertErrorIs(context.DeadlineExceeded, err)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertErrorIs(io.EOF, err)
-			},
-		},
-		{
-			name: "readTimeout",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				ctx, cancel := context.WithTimeout(ctx, time.Second)
-				defer cancel()
-				_, _, err := c.Read(ctx)
-				return assertErrorIs(context.DeadlineExceeded, err)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertErrorIs(websocket.CloseError{
-					Code:   websocket.StatusPolicyViolation,
-					Reason: "read timed out",
-				}, err)
-			},
-		},
-		{
-			name: "badOpCode",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, err := c.WriteFrame(ctx, true, 13, []byte("meow"))
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Read(ctx)
-				return assertErrorContains(err, "unknown opcode")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertErrorContains(err, "unknown opcode")
-			},
-		},
-		{
-			name: "noRsv",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, err := c.WriteFrame(ctx, true, 99, []byte("meow"))
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusProtocolError)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				if err == nil || !strings.Contains(err.Error(), "rsv") {
-					return fmt.Errorf("expected error that contains rsv: %+v", err)
-				}
-				return nil
-			},
-		},
-		{
-			name: "largeControlFrame",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				err := c.WriteHeader(ctx, websocket.Header{
-					Fin:           true,
-					OpCode:        websocket.OpClose,
-					PayloadLength: 4096,
-				})
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusProtocolError)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertErrorContains(err, "too big")
-			},
-		},
-		{
-			name: "fragmentedControlFrame",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, err := c.WriteFrame(ctx, false, websocket.OpPing, []byte(strings.Repeat("x", 32)))
-				if err != nil {
-					return err
-				}
-				err = c.Flush()
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusProtocolError)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertErrorContains(err, "fragmented")
-			},
-		},
-		{
-			name: "invalidClosePayload",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{0x17, 0x70})
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusProtocolError)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertErrorContains(err, "invalid status code")
-			},
-		},
-		{
-			name: "doubleReader",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, r, err := c.Reader(ctx)
-				if err != nil {
-					return err
-				}
-				p := make([]byte, 10)
-				_, err = io.ReadFull(r, p)
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Reader(ctx)
-				return assertErrorContains(err, "previous message not read to completion")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 11)))
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusInternalError)
-			},
-		},
-		{
-			name: "doubleFragmentedReader",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, r, err := c.Reader(ctx)
-				if err != nil {
-					return err
-				}
-				p := make([]byte, 10)
-				_, err = io.ReadFull(r, p)
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Reader(ctx)
-				return assertErrorContains(err, "previous message not read to completion")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				w, err := c.Writer(ctx, websocket.MessageBinary)
-				if err != nil {
-					return err
-				}
-				_, err = w.Write([]byte(strings.Repeat("x", 10)))
-				if err != nil {
-					return fmt.Errorf("expected non nil error")
-				}
-				err = c.Flush()
-				if err != nil {
-					return fmt.Errorf("failed to flush: %w", err)
-				}
-				_, err = w.Write([]byte(strings.Repeat("x", 10)))
-				if err != nil {
-					return fmt.Errorf("expected non nil error")
-				}
-				err = c.Flush()
-				if err != nil {
-					return fmt.Errorf("failed to flush: %w", err)
-				}
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusInternalError)
-			},
-		},
-		{
-			name: "newMessageInFragmentedMessage",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, r, err := c.Reader(ctx)
-				if err != nil {
-					return err
-				}
-				p := make([]byte, 10)
-				_, err = io.ReadFull(r, p)
-				if err != nil {
-					return err
-				}
-				_, _, err = c.Reader(ctx)
-				return assertErrorContains(err, "received new data message without finishing")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				w, err := c.Writer(ctx, websocket.MessageBinary)
-				if err != nil {
-					return err
-				}
-				_, err = w.Write([]byte(strings.Repeat("x", 10)))
-				if err != nil {
-					return fmt.Errorf("expected non nil error")
-				}
-				err = c.Flush()
-				if err != nil {
-					return fmt.Errorf("failed to flush: %w", err)
-				}
-				_, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10)))
-				if err != nil {
-					return fmt.Errorf("expected non nil error")
-				}
-				_, _, err = c.Read(ctx)
-				return assertErrorContains(err, "received new data message without finishing")
-			},
-		},
-		{
-			name: "continuationFrameWithoutDataFrame",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Reader(ctx)
-				return assertErrorContains(err, "received continuation frame not after data")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, err := c.WriteFrame(ctx, false, websocket.OpContinuation, []byte(strings.Repeat("x", 10)))
-				return err
-			},
-		},
-		{
-			name: "readBeforeEOF",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, r, err := c.Reader(ctx)
-				if err != nil {
-					return err
-				}
-				var v interface{}
-				d := json.NewDecoder(r)
-				err = d.Decode(&v)
-				if err != nil {
-					return err
-				}
-				err = assert.Equalf("hi", v, "unexpected JSON")
-				if err != nil {
-					return err
-				}
-				_, b, err := c.Read(ctx)
-				if err != nil {
-					return err
-				}
-				return assert.Equalf("hi", string(b), "unexpected JSON")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				err := wsjson.Write(ctx, c, "hi")
-				if err != nil {
-					return err
-				}
-				return c.Write(ctx, websocket.MessageText, []byte("hi"))
-			},
-		},
-		{
-			name: "newMessageInFragmentedMessage2",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, r, err := c.Reader(ctx)
-				if err != nil {
-					return err
-				}
-				p := make([]byte, 11)
-				_, err = io.ReadFull(r, p)
-				return assertErrorContains(err, "received new data message without finishing")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				w, err := c.Writer(ctx, websocket.MessageBinary)
-				if err != nil {
-					return err
-				}
-				_, err = w.Write([]byte(strings.Repeat("x", 10)))
-				if err != nil {
-					return fmt.Errorf("expected non nil error")
-				}
-				err = c.Flush()
-				if err != nil {
-					return fmt.Errorf("failed to flush: %w", err)
-				}
-				_, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10)))
-				if err != nil {
-					return fmt.Errorf("expected non nil error")
-				}
-				_, _, err = c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusProtocolError)
-			},
-		},
-		{
-			name: "doubleRead",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, r, err := c.Reader(ctx)
-				if err != nil {
-					return err
-				}
-				_, err = ioutil.ReadAll(r)
-				if err != nil {
-					return err
-				}
-				_, err = r.Read(make([]byte, 1))
-				return assertErrorContains(err, "cannot use EOFed reader")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return c.Write(ctx, websocket.MessageBinary, []byte("hi"))
-			},
-		},
-		{
-			name: "eofInPayload",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertErrorContains(err, "failed to read frame payload")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				_, err := c.WriteHalfFrame(ctx)
-				if err != nil {
-					return err
-				}
-				c.CloseUnderlyingConn()
-				return nil
-			},
-		},
-		{
-			name: "closeHandshake",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				return c.Close(websocket.StatusNormalClosure, "")
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				return c.Close(websocket.StatusNormalClosure, "")
-			},
-		},
-		{
-			// Issue #164
-			name: "closeHandshake_concurrentRead",
-			server: func(ctx context.Context, c *websocket.Conn) error {
-				_, _, err := c.Read(ctx)
-				return assertCloseStatus(err, websocket.StatusNormalClosure)
-			},
-			client: func(ctx context.Context, c *websocket.Conn) error {
-				errc := make(chan error, 1)
-				go func() {
-					_, _, err := c.Read(ctx)
-					errc <- err
-				}()
-
-				err := c.Close(websocket.StatusNormalClosure, "")
-				if err != nil {
-					return err
-				}
-
-				err = <-errc
-				return assertCloseStatus(err, websocket.StatusNormalClosure)
-			},
-		},
-	}
-	for _, tc := range testCases {
-		tc := tc
-		t.Run(tc.name, func(t *testing.T) {
-			t.Parallel()
-
-			// Run random tests over TLS.
-			tls := rand.Intn(2) == 1
-
-			s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, tc.acceptOpts)
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-				c.SetLogf(t.Logf)
-				if tc.server == nil {
-					return nil
-				}
-				return tc.server(r.Context(), c)
-			}, tls)
-			defer closeFn()
-
-			wsURL := strings.Replace(s.URL, "http", "ws", 1)
-
-			ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-			defer cancel()
-
-			opts := tc.dialOpts
-			if tls {
-				if opts == nil {
-					opts = &websocket.DialOptions{}
-				}
-				opts.HTTPClient = s.Client()
-			}
-
-			c, resp, err := websocket.Dial(ctx, wsURL, opts)
-			if err != nil {
-				t.Fatal(err)
-			}
-			defer c.Close(websocket.StatusInternalError, "")
-			c.SetLogf(t.Logf)
-
-			if tc.response != nil {
-				err = tc.response(resp)
-				if err != nil {
-					t.Fatalf("response asserter failed: %+v", err)
-				}
-			}
-
-			if tc.client != nil {
-				err = tc.client(ctx, c)
-				if err != nil {
-					t.Fatalf("client failed: %+v", err)
-				}
-			}
-
-			c.Close(websocket.StatusNormalClosure, "")
-		})
-	}
-}
-
-func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) {
-	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		err := fn(w, r)
-		if err != nil {
-			tb.Errorf("server failed: %+v", err)
-		}
-	})
+func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) {
+	h := http.HandlerFunc(fn)
 	if tls {
 		s = httptest.NewTLSServer(h)
 	} else {
 		s = httptest.NewServer(h)
 	}
-	closeFn2 := wsgrace.Grace(s.Config)
+	closeFn2 := wsgrace(s.Config)
 	return s, func() {
 		err := closeFn2()
 		if err != nil {
@@ -974,1417 +33,112 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e
 	}
 }
 
-func TestAutobahn(t *testing.T) {
-	t.Parallel()
-
-	run := func(t *testing.T, name string, fn func(ctx context.Context, c *websocket.Conn) error) {
-		run2 := func(t *testing.T, testingClient bool) {
-			// Run random tests over TLS.
-			tls := rand.Intn(2) == 1
-
-			s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-					Subprotocols: []string{"echo"},
-				})
-				if err != nil {
-					return err
-				}
-				defer c.Close(websocket.StatusInternalError, "")
-
-				ctx := r.Context()
-				if testingClient {
-					err = wsecho.Loop(ctx, c)
-					if err != nil {
-						t.Logf("failed to wsecho: %+v", err)
-					}
-					return nil
-				}
-
-				c.SetReadLimit(1 << 30)
-				err = fn(ctx, c)
-				if err != nil {
-					return err
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return nil
-			}, tls)
-			defer closeFn()
-
-			wsURL := strings.Replace(s.URL, "http", "ws", 1)
-
-			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-			defer cancel()
-
-			opts := &websocket.DialOptions{
-				Subprotocols: []string{"echo"},
-			}
-			if tls {
-				opts.HTTPClient = s.Client()
-			}
-
-			c, _, err := websocket.Dial(ctx, wsURL, opts)
-			if err != nil {
-				t.Fatal(err)
-			}
-			defer c.Close(websocket.StatusInternalError, "")
-
-			if testingClient {
-				c.SetReadLimit(1 << 30)
-				err = fn(ctx, c)
-				if err != nil {
-					t.Fatalf("client failed: %+v", err)
-				}
-				c.Close(websocket.StatusNormalClosure, "")
-				return
-			}
-
-			err = wsecho.Loop(ctx, c)
-			if err != nil {
-				t.Logf("failed to wsecho: %+v", err)
-			}
-		}
-		t.Run(name, func(t *testing.T) {
-			t.Parallel()
+// grace wraps s.Handler to gracefully shutdown WebSocket connections.
+// The returned function must be used to close the server instead of s.Close.
+func wsgrace(s *http.Server) (closeFn func() error) {
+	h := s.Handler
+	var conns int64
+	s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt64(&conns, 1)
+		defer atomic.AddInt64(&conns, -1)
 
-			run2(t, true)
-		})
-	}
-
-	// Section 1.
-	t.Run("echo", func(t *testing.T) {
-		t.Parallel()
-
-		lengths := []int{
-			0,
-			125,
-			126,
-			127,
-			128,
-			65535,
-			65536,
-			65536,
-		}
-		run := func(typ websocket.MessageType) {
-			for i, l := range lengths {
-				l := l
-				run(t, fmt.Sprintf("%v/%v", typ, l), func(ctx context.Context, c *websocket.Conn) error {
-					p := randBytes(l)
-					if i == len(lengths)-1 {
-						w, err := c.Writer(ctx, typ)
-						if err != nil {
-							return err
-						}
-						for i := 0; i < l; {
-							j := i + 997
-							if j > l {
-								j = l
-							}
-							_, err = w.Write(p[i:j])
-							if err != nil {
-								return err
-							}
+		ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
+		defer cancel()
 
-							i = j
-						}
+		r = r.WithContext(ctx)
 
-						err = w.Close()
-						if err != nil {
-							return err
-						}
-					} else {
-						err := c.Write(ctx, typ, p)
-						if err != nil {
-							return err
-						}
-					}
-					actTyp, p2, err := c.Read(ctx)
-					if err != nil {
-						return err
-					}
-					err = assert.Equalf(typ, actTyp, "unexpected message type")
-					if err != nil {
-						return err
-					}
-					return assert.Equalf(p, p2, "unexpected message")
-				})
-			}
-		}
-
-		run(websocket.MessageText)
-		run(websocket.MessageBinary)
+		h.ServeHTTP(w, r)
 	})
 
-	// Section 2.
-	t.Run("pingPong", func(t *testing.T) {
-		t.Parallel()
-
-		run(t, "emptyPayload", func(ctx context.Context, c *websocket.Conn) error {
-			ctx = c.CloseRead(ctx)
-			return c.PingWithPayload(ctx, "")
-		})
-		run(t, "smallTextPayload", func(ctx context.Context, c *websocket.Conn) error {
-			ctx = c.CloseRead(ctx)
-			return c.PingWithPayload(ctx, "hi")
-		})
-		run(t, "smallBinaryPayload", func(ctx context.Context, c *websocket.Conn) error {
-			ctx = c.CloseRead(ctx)
-			p := bytes.Repeat([]byte{0xFE}, 16)
-			return c.PingWithPayload(ctx, string(p))
-		})
-		run(t, "largeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error {
-			ctx = c.CloseRead(ctx)
-			p := bytes.Repeat([]byte{0xFE}, 125)
-			return c.PingWithPayload(ctx, string(p))
-		})
-		run(t, "tooLargeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error {
-			c.CloseRead(ctx)
-			p := bytes.Repeat([]byte{0xFE}, 126)
-			err := c.PingWithPayload(ctx, string(p))
-			return assertCloseStatus(err, websocket.StatusProtocolError)
-		})
-		run(t, "streamPingPayload", func(ctx context.Context, c *websocket.Conn) error {
-			err := assertStreamPing(ctx, c, 125)
-			if err != nil {
-				return err
-			}
-			return c.Close(websocket.StatusNormalClosure, "")
-		})
-		t.Run("unsolicitedPong", func(t *testing.T) {
-			t.Parallel()
-
-			var testCases = []struct {
-				name        string
-				pongPayload string
-				ping        bool
-			}{
-				{
-					name:        "noPayload",
-					pongPayload: "",
-				},
-				{
-					name:        "payload",
-					pongPayload: "hi",
-				},
-				{
-					name:        "pongThenPing",
-					pongPayload: "hi",
-					ping:        true,
-				},
-			}
-			for _, tc := range testCases {
-				tc := tc
-				run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
-					_, err := c.WriteFrame(ctx, true, websocket.OpPong, []byte(tc.pongPayload))
-					if err != nil {
-						return err
-					}
-					if tc.ping {
-						_, err := c.WriteFrame(ctx, true, websocket.OpPing, []byte("meow"))
-						if err != nil {
-							return err
-						}
-						err = assertReadFrame(ctx, c, websocket.OpPong, []byte("meow"))
-						if err != nil {
-							return err
-						}
-					}
-					return c.Close(websocket.StatusNormalClosure, "")
-				})
-			}
-		})
-		run(t, "tenPings", func(ctx context.Context, c *websocket.Conn) error {
-			ctx = c.CloseRead(ctx)
-
-			for i := 0; i < 10; i++ {
-				err := c.Ping(ctx)
-				if err != nil {
-					return err
-				}
-			}
+	return func() error {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+		defer cancel()
 
-			_, err := c.WriteClose(ctx, websocket.StatusNormalClosure, "")
-			if err != nil {
-				return err
-			}
-			<-ctx.Done()
-
-			err = c.Ping(context.Background())
-			return assertCloseStatus(err, websocket.StatusNormalClosure)
-		})
-
-		run(t, "tenStreamedPings", func(ctx context.Context, c *websocket.Conn) error {
-			for i := 0; i < 10; i++ {
-				err := assertStreamPing(ctx, c, 125)
-				if err != nil {
-					return err
-				}
-			}
-
-			return c.Close(websocket.StatusNormalClosure, "")
-		})
-	})
-
-	// Section 3.
-	// We skip the per octet sending as it will add too much complexity.
-	t.Run("reserved", func(t *testing.T) {
-		t.Parallel()
-
-		var testCases = []struct {
-			name   string
-			header websocket.Header
-		}{
-			{
-				name: "rsv1",
-				header: websocket.Header{
-					Fin:           true,
-					Rsv1:          true,
-					OpCode:        websocket.OpClose,
-					PayloadLength: 0,
-				},
-			},
-			{
-				name: "rsv2",
-				header: websocket.Header{
-					Fin:           true,
-					Rsv2:          true,
-					OpCode:        websocket.OpPong,
-					PayloadLength: 0,
-				},
-			},
-			{
-				name: "rsv3",
-				header: websocket.Header{
-					Fin:           true,
-					Rsv3:          true,
-					OpCode:        websocket.OpBinary,
-					PayloadLength: 0,
-				},
-			},
-			{
-				name: "rsvAll",
-				header: websocket.Header{
-					Fin:           true,
-					Rsv1:          true,
-					Rsv2:          true,
-					Rsv3:          true,
-					OpCode:        websocket.OpText,
-					PayloadLength: 0,
-				},
-			},
-		}
-		for _, tc := range testCases {
-			tc := tc
-			run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
-				err := assertEcho(ctx, c, websocket.MessageText, 4096)
-				if err != nil {
-					return err
-				}
-				err = c.WriteHeader(ctx, tc.header)
-				if err != nil {
-					return err
-				}
-				err = c.Flush()
-				if err != nil {
-					return err
-				}
-				_, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf"))
-				if err != nil {
-					return err
-				}
-				return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
-			})
-		}
-	})
-
-	// Section 4.
-	t.Run("opcodes", func(t *testing.T) {
-		t.Parallel()
-
-		testCases := []struct {
-			name    string
-			opcode  websocket.OpCode
-			payload bool
-			echo    bool
-			ping    bool
-		}{
-			// Section 1.
-			{
-				name:   "3",
-				opcode: 3,
-			},
-			{
-				name:    "4",
-				opcode:  4,
-				payload: true,
-			},
-			{
-				name:   "5",
-				opcode: 5,
-				echo:   true,
-				ping:   true,
-			},
-			{
-				name:    "6",
-				opcode:  6,
-				payload: true,
-				echo:    true,
-				ping:    true,
-			},
-			{
-				name:    "7",
-				opcode:  7,
-				payload: true,
-				echo:    true,
-				ping:    true,
-			},
-
-			// Section 2.
-			{
-				name:   "11",
-				opcode: 11,
-			},
-			{
-				name:    "12",
-				opcode:  12,
-				payload: true,
-			},
-			{
-				name:    "13",
-				opcode:  13,
-				payload: true,
-				echo:    true,
-				ping:    true,
-			},
-			{
-				name:    "14",
-				opcode:  14,
-				payload: true,
-				echo:    true,
-				ping:    true,
-			},
-			{
-				name:    "15",
-				opcode:  15,
-				payload: true,
-				echo:    true,
-				ping:    true,
-			},
-		}
-		for _, tc := range testCases {
-			tc := tc
-			run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
-				if tc.echo {
-					err := assertEcho(ctx, c, websocket.MessageText, 4096)
-					if err != nil {
-						return err
-					}
-				}
-
-				p := []byte(nil)
-				if tc.payload {
-					p = randBytes(rand.Intn(4096) + 1)
-				}
-				_, err := c.WriteFrame(ctx, true, tc.opcode, p)
-				if err != nil {
-					return err
-				}
-				if tc.ping {
-					_, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf"))
-					if err != nil {
-						return err
-					}
-				}
-				return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
-			})
-		}
-	})
-
-	// Section 5.
-	t.Run("fragmentation", func(t *testing.T) {
-		t.Parallel()
-
-		// 5.1 to 5.8
-		testCases := []struct {
-			name          string
-			opcode        websocket.OpCode
-			success       bool
-			pingInBetween bool
-		}{
-			{
-				name:    "ping",
-				opcode:  websocket.OpPing,
-				success: false,
-			},
-			{
-				name:    "pong",
-				opcode:  websocket.OpPong,
-				success: false,
-			},
-			{
-				name:    "text",
-				opcode:  websocket.OpText,
-				success: true,
-			},
-			{
-				name:          "textPing",
-				opcode:        websocket.OpText,
-				success:       true,
-				pingInBetween: true,
-			},
-		}
-		for _, tc := range testCases {
-			tc := tc
-			run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
-				p1 := randBytes(16)
-				_, err := c.WriteFrame(ctx, false, tc.opcode, p1)
-				if err != nil {
-					return err
-				}
-				err = c.BW().Flush()
-				if err != nil {
-					return err
-				}
-				if !tc.success {
-					_, _, err = c.Read(ctx)
-					return assertCloseStatus(err, websocket.StatusProtocolError)
-				}
-
-				if tc.pingInBetween {
-					_, err = c.WriteFrame(ctx, true, websocket.OpPing, p1)
-					if err != nil {
-						return err
-					}
-				}
-
-				p2 := randBytes(16)
-				_, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p2)
-				if err != nil {
-					return err
-				}
-
-				err = assertReadFrame(ctx, c, tc.opcode, p1)
-				if err != nil {
-					return err
-				}
-
-				if tc.pingInBetween {
-					err = assertReadFrame(ctx, c, websocket.OpPong, p1)
-					if err != nil {
-						return err
-					}
-				}
-
-				return assertReadFrame(ctx, c, websocket.OpContinuation, p2)
-			})
+		err := s.Shutdown(ctx)
+		if err != nil {
+			return fmt.Errorf("server shutdown failed: %v", err)
 		}
 
-		t.Run("unexpectedContinuation", func(t *testing.T) {
-			t.Parallel()
-
-			testCases := []struct {
-				name      string
-				fin       bool
-				textFirst bool
-			}{
-				{
-					name: "fin",
-					fin:  true,
-				},
-				{
-					name: "noFin",
-					fin:  false,
-				},
-				{
-					name:      "echoFirst",
-					fin:       false,
-					textFirst: true,
-				},
-				// The rest of the tests in this section get complicated and do not inspire much confidence.
-			}
-
-			for _, tc := range testCases {
-				tc := tc
-				run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error {
-					if tc.textFirst {
-						w, err := c.Writer(ctx, websocket.MessageText)
-						if err != nil {
-							return err
-						}
-						p1 := randBytes(32)
-						_, err = w.Write(p1)
-						if err != nil {
-							return err
-						}
-						p2 := randBytes(32)
-						_, err = w.Write(p2)
-						if err != nil {
-							return err
-						}
-						err = w.Close()
-						if err != nil {
-							return err
-						}
-						err = assertReadFrame(ctx, c, websocket.OpText, p1)
-						if err != nil {
-							return err
-						}
-						err = assertReadFrame(ctx, c, websocket.OpContinuation, p2)
-						if err != nil {
-							return err
-						}
-						err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{})
-						if err != nil {
-							return err
-						}
-					}
-
-					_, err := c.WriteFrame(ctx, tc.fin, websocket.OpContinuation, randBytes(32))
-					if err != nil {
-						return err
-					}
-					err = c.BW().Flush()
-					if err != nil {
-						return err
-					}
-
-					return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
-				})
-			}
-
-			run(t, "doubleText", func(ctx context.Context, c *websocket.Conn) error {
-				p1 := randBytes(32)
-				_, err := c.WriteFrame(ctx, false, websocket.OpText, p1)
-				if err != nil {
-					return err
-				}
-				_, err = c.WriteFrame(ctx, true, websocket.OpText, randBytes(32))
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpText, p1)
-				if err != nil {
-					return err
-				}
-				return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
-			})
-
-			run(t, "5.19", func(ctx context.Context, c *websocket.Conn) error {
-				p1 := randBytes(32)
-				p2 := randBytes(32)
-				p3 := randBytes(32)
-				p4 := randBytes(32)
-				p5 := randBytes(32)
-
-				_, err := c.WriteFrame(ctx, false, websocket.OpText, p1)
-				if err != nil {
-					return err
-				}
-				_, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p2)
-				if err != nil {
-					return err
-				}
-
-				_, err = c.WriteFrame(ctx, true, websocket.OpPing, p1)
-				if err != nil {
-					return err
-				}
-
-				time.Sleep(time.Second)
-
-				_, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p3)
-				if err != nil {
-					return err
-				}
-				_, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p4)
-				if err != nil {
-					return err
-				}
-
-				_, err = c.WriteFrame(ctx, true, websocket.OpPing, p1)
-				if err != nil {
-					return err
-				}
-
-				_, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p5)
-				if err != nil {
-					return err
-				}
-
-				err = assertReadFrame(ctx, c, websocket.OpText, p1)
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpContinuation, p2)
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpPong, p1)
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpContinuation, p3)
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpContinuation, p4)
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpPong, p1)
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpContinuation, p5)
-				if err != nil {
-					return err
-				}
-				err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{})
-				if err != nil {
-					return err
-				}
-				return c.Close(websocket.StatusNormalClosure, "")
-			})
-		})
-	})
-
-	// Section 7
-	t.Run("closeHandling", func(t *testing.T) {
-		t.Parallel()
-
-		// 1.1 - 1.4 is useless.
-		run(t, "1.5", func(ctx context.Context, c *websocket.Conn) error {
-			p1 := randBytes(32)
-			_, err := c.WriteFrame(ctx, false, websocket.OpText, p1)
-			if err != nil {
-				return err
-			}
-			err = c.Flush()
-			if err != nil {
-				return err
-			}
-			_, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "")
-			if err != nil {
-				return err
-			}
-			err = assertReadFrame(ctx, c, websocket.OpText, p1)
-			if err != nil {
-				return err
-			}
-			return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure)
-		})
-
-		run(t, "1.6", func(ctx context.Context, c *websocket.Conn) error {
-			// 262144 bytes.
-			p1 := randBytes(1 << 18)
-			err := c.Write(ctx, websocket.MessageText, p1)
-			if err != nil {
-				return err
-			}
-			_, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "")
-			if err != nil {
-				return err
-			}
-			err = assertReadMessage(ctx, c, websocket.MessageText, p1)
-			if err != nil {
-				return err
-			}
-			return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure)
-		})
-
-		run(t, "emptyClose", func(ctx context.Context, c *websocket.Conn) error {
-			_, err := c.WriteFrame(ctx, true, websocket.OpClose, nil)
-			if err != nil {
-				return err
-			}
-			return assertReadFrame(ctx, c, websocket.OpClose, []byte{})
-		})
-
-		run(t, "badClose", func(ctx context.Context, c *websocket.Conn) error {
-			_, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{1})
-			if err != nil {
-				return err
-			}
-			return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
-		})
-
-		run(t, "noReason", func(ctx context.Context, c *websocket.Conn) error {
-			return c.Close(websocket.StatusNormalClosure, "")
-		})
-
-		run(t, "simpleReason", func(ctx context.Context, c *websocket.Conn) error {
-			return c.Close(websocket.StatusNormalClosure, randString(16))
-		})
-
-		run(t, "maxReason", func(ctx context.Context, c *websocket.Conn) error {
-			return c.Close(websocket.StatusNormalClosure, randString(123))
-		})
-
-		run(t, "tooBigReason", func(ctx context.Context, c *websocket.Conn) error {
-			_, err := c.WriteFrame(ctx, true, websocket.OpClose,
-				append([]byte{0x03, 0xE8}, randString(124)...),
-			)
-			if err != nil {
-				return err
-			}
-			return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
-		})
-
-		t.Run("validCloses", func(t *testing.T) {
-			t.Parallel()
-
-			codes := [...]websocket.StatusCode{
-				1000,
-				1001,
-				1002,
-				1003,
-				1007,
-				1008,
-				1009,
-				1010,
-				1011,
-				3000,
-				3999,
-				4000,
-				4999,
-			}
-			for _, code := range codes {
-				run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error {
-					return c.Close(code, randString(32))
-				})
-			}
-		})
-
-		t.Run("invalidCloseCodes", func(t *testing.T) {
-			t.Parallel()
-
-			codes := []websocket.StatusCode{
-				0,
-				999,
-				1004,
-				1005,
-				1006,
-				1016,
-				1100,
-				2000,
-				2999,
-				5000,
-				65535,
-			}
-			for _, code := range codes {
-				run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error {
-					p := make([]byte, 2)
-					binary.BigEndian.PutUint16(p, uint16(code))
-					p = append(p, randBytes(32)...)
-					_, err := c.WriteFrame(ctx, true, websocket.OpClose, p)
-					if err != nil {
-						return err
-					}
-					return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError)
-				})
-			}
-		})
-	})
-
-	// Section 9.
-	t.Run("limits", func(t *testing.T) {
-		t.Parallel()
-
-		t.Run("unfragmentedEcho", func(t *testing.T) {
-			t.Parallel()
-
-			lengths := []int{
-				1 << 16,
-				1 << 18,
-				// Anything higher is completely unnecessary.
-			}
-
-			for _, l := range lengths {
-				l := l
-				run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error {
-					return assertEcho(ctx, c, websocket.MessageBinary, l)
-				})
-			}
-		})
-
-		t.Run("fragmentedEcho", func(t *testing.T) {
-			t.Parallel()
-
-			fragments := []int{
-				64,
-				256,
-				1 << 10,
-				1 << 12,
-				1 << 14,
-				1 << 16,
-			}
-
-			for _, l := range fragments {
-				fragmentLength := l
-				run(t, strconv.Itoa(fragmentLength), func(ctx context.Context, c *websocket.Conn) error {
-					w, err := c.Writer(ctx, websocket.MessageText)
-					if err != nil {
-						return err
-					}
-					b := randBytes(1 << 16)
-					for i := 0; i < len(b); {
-						j := i + fragmentLength
-						if j > len(b) {
-							j = len(b)
-						}
-
-						_, err = w.Write(b[i:j])
-						if err != nil {
-							return err
-						}
-
-						i = j
-					}
-					err = w.Close()
-					if err != nil {
-						return err
-					}
-
-					err = assertReadMessage(ctx, c, websocket.MessageText, b)
-					if err != nil {
-						return err
-					}
-					return c.Close(websocket.StatusNormalClosure, "")
-				})
-			}
-		})
-
-		t.Run("latencyEcho", func(t *testing.T) {
-			t.Parallel()
-
-			lengths := []int{
-				0,
-				16,
-			}
-
-			for _, l := range lengths {
-				l := l
-				run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error {
-					for i := 0; i < 1000; i++ {
-						err := assertEcho(ctx, c, websocket.MessageBinary, l)
-						if err != nil {
-							return err
-						}
-					}
+		t := time.NewTicker(time.Millisecond * 10)
+		defer t.Stop()
+		for {
+			select {
+			case <-t.C:
+				if atomic.LoadInt64(&conns) == 0 {
 					return nil
-				})
-			}
-		})
-	})
-}
-
-func assertCloseStatus(err error, code websocket.StatusCode) error {
-	var cerr websocket.CloseError
-	if !errors.As(err, &cerr) {
-		return fmt.Errorf("no websocket close error in error chain: %+v", err)
-	}
-	return assert.Equalf(code, cerr.Code, "unexpected status code")
-}
-
-func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
-	expType := reflect.TypeOf(exp)
-	actv := reflect.New(expType.Elem())
-	act := actv.Interface().(proto.Message)
-	err := wspb.Read(ctx, c, act)
-	if err != nil {
-		return err
-	}
-
-	return assert.Equalf(exp, act, "unexpected protobuf")
-}
-
-func assertNetConnRead(r io.Reader, exp string) error {
-	act := make([]byte, len(exp))
-	_, err := r.Read(act)
-	if err != nil {
-		return err
-	}
-	return assert.Equalf(exp, string(act), "unexpected net conn read")
-}
-
-func assertErrorContains(err error, exp string) error {
-	if err == nil || !strings.Contains(err.Error(), exp) {
-		return fmt.Errorf("expected error that contains %q but got: %+v", exp, err)
-	}
-	return nil
-}
-
-func assertErrorIs(exp, act error) error {
-	if !errors.Is(act, exp) {
-		return fmt.Errorf("expected error %+v to be in %+v", exp, act)
-	}
-	return nil
-}
-
-func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.OpCode, p []byte) error {
-	actOpcode, actP, err := c.ReadFrame(ctx)
-	if err != nil {
-		return err
-	}
-	err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
-	if err != nil {
-		return err
-	}
-	return assert.Equalf(p, actP, "unexpected frame %v payload", opcode)
-}
-
-func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error {
-	actOpcode, actP, err := c.ReadFrame(ctx)
-	if err != nil {
-		return err
-	}
-	err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
-	if err != nil {
-		return err
-	}
-	ce, err := websocket.ParseClosePayload(actP)
-	if err != nil {
-		return fmt.Errorf("failed to parse close frame payload: %w", err)
-	}
-	return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
-}
-
-func assertStreamPing(ctx context.Context, c *websocket.Conn, l int) error {
-	err := c.WriteHeader(ctx, websocket.Header{
-		Fin:           true,
-		OpCode:        websocket.OpPing,
-		PayloadLength: int64(l),
-	})
-	if err != nil {
-		return err
-	}
-	for i := 0; i < l; i++ {
-		err = c.BW().WriteByte(0xFE)
-		if err != nil {
-			return fmt.Errorf("failed to write byte %d: %w", i, err)
-		}
-		if i%32 == 0 {
-			err = c.BW().Flush()
-			if err != nil {
-				return fmt.Errorf("failed to flush at byte %d: %w", i, err)
+				}
+			case <-ctx.Done():
+				return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err())
 			}
 		}
 	}
-	err = c.BW().Flush()
-	if err != nil {
-		return fmt.Errorf("failed to flush: %v", err)
-	}
-	return assertReadFrame(ctx, c, websocket.OpPong, bytes.Repeat([]byte{0xFE}, l))
-}
-
-func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, p []byte) error {
-	actTyp, actP, err := c.Read(ctx)
-	if err != nil {
-		return err
-	}
-	err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
-	if err != nil {
-		return err
-	}
-	return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp)
-}
-
-func BenchmarkConn(b *testing.B) {
-	sizes := []int{
-		2,
-		16,
-		32,
-		512,
-		4096,
-		16384,
-	}
-
-	b.Run("write", func(b *testing.B) {
-		for _, size := range sizes {
-			b.Run(strconv.Itoa(size), func(b *testing.B) {
-				b.Run("stream", func(b *testing.B) {
-					benchConn(b, false, true, size)
-				})
-				b.Run("buffer", func(b *testing.B) {
-					benchConn(b, false, false, size)
-				})
-			})
-		}
-	})
-
-	b.Run("echo", func(b *testing.B) {
-		for _, size := range sizes {
-			b.Run(strconv.Itoa(size), func(b *testing.B) {
-				benchConn(b, true, true, size)
-			})
-		}
-	})
 }
 
-func benchConn(b *testing.B, echo, stream bool, size int) {
-	s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error {
-		c, err := websocket.Accept(w, r, nil)
-		if err != nil {
-			return err
-		}
-		if echo {
-			wsecho.Loop(r.Context(), c)
-		} else {
-			discardLoop(r.Context(), c)
-		}
-		return nil
-	}, false)
-	defer closeFn()
-
-	wsURL := strings.Replace(s.URL, "http", "ws", 1)
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
-	defer cancel()
-
-	c, _, err := websocket.Dial(ctx, wsURL, nil)
-	if err != nil {
-		b.Fatal(err)
-	}
+// echoLoop echos every msg received from c until an error
+// occurs or the context expires.
+// The read limit is set to 1 << 30.
+func echoLoop(ctx context.Context, c *websocket.Conn) error {
 	defer c.Close(websocket.StatusInternalError, "")
 
-	msg := []byte(strings.Repeat("2", size))
-	readBuf := make([]byte, len(msg))
-	b.SetBytes(int64(len(msg)))
-	b.ReportAllocs()
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		if stream {
-			w, err := c.Writer(ctx, websocket.MessageText)
-			if err != nil {
-				b.Fatal(err)
-			}
-
-			_, err = w.Write(msg)
-			if err != nil {
-				b.Fatal(err)
-			}
-
-			err = w.Close()
-			if err != nil {
-				b.Fatal(err)
-			}
-		} else {
-			err = c.Write(ctx, websocket.MessageText, msg)
-			if err != nil {
-				b.Fatal(err)
-			}
-		}
-
-		if echo {
-			_, r, err := c.Reader(ctx)
-			if err != nil {
-				b.Fatal(err)
-			}
-
-			_, err = io.ReadFull(r, readBuf)
-			if err != nil {
-				b.Fatal(err)
-			}
-		}
-	}
-	b.StopTimer()
-
-	c.Close(websocket.StatusNormalClosure, "")
-}
-
-func discardLoop(ctx context.Context, c *websocket.Conn) {
-	defer c.Close(websocket.StatusInternalError, "")
+	c.SetReadLimit(1 << 30)
 
 	ctx, cancel := context.WithTimeout(ctx, time.Minute)
 	defer cancel()
 
-	b := make([]byte, 32768)
-	echo := func() error {
-		_, r, err := c.Reader(ctx)
+	b := make([]byte, 32<<10)
+	for {
+		typ, r, err := c.Reader(ctx)
 		if err != nil {
 			return err
 		}
 
-		_, err = io.CopyBuffer(ioutil.Discard, r, b)
+		w, err := c.Writer(ctx, typ)
 		if err != nil {
 			return err
 		}
-		return nil
-	}
 
-	for {
-		err := echo()
+		_, err = io.CopyBuffer(w, r, b)
 		if err != nil {
-			return
+			return err
 		}
-	}
-}
-
-func TestAutobahnPython(t *testing.T) {
-	// This test contains the old autobahn test suite tests that use the
-	// python binary. The approach is clunky and slow so new tests
-	// have been written in pure Go in websocket_test.go.
-	// These have been kept for correctness purposes and are occasionally ran.
-	if os.Getenv("AUTOBAHN_PYTHON") == "" {
-		t.Skip("Set $AUTOBAHN_PYTHON to run tests against the python autobahn test suite")
-	}
-
-	t.Run("server", testServerAutobahnPython)
-	t.Run("client", testClientAutobahnPython)
-}
-
-// https://github.com/crossbario/autobahn-python/tree/master/wstest
-func testServerAutobahnPython(t *testing.T) {
-	t.Parallel()
 
-	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-			Subprotocols: []string{"echo"},
-		})
+		err = w.Close()
 		if err != nil {
-			t.Logf("server handshake failed: %+v", err)
-			return
+			return err
 		}
-		wsecho.Loop(r.Context(), c)
-	}))
-	defer s.Close()
-
-	spec := map[string]interface{}{
-		"outdir": "ci/out/wstestServerReports",
-		"servers": []interface{}{
-			map[string]interface{}{
-				"agent": "main",
-				"url":   strings.Replace(s.URL, "http", "ws", 1),
-			},
-		},
-		"cases": []string{"*"},
-		// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
-		// more performance overhead. 7.5.1 is the same.
-		// 12.* and 13.* as we do not support compression.
-		"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
-	}
-	specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json")
-	if err != nil {
-		t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err)
-	}
-	defer specFile.Close()
-
-	e := json.NewEncoder(specFile)
-	e.SetIndent("", "\t")
-	err = e.Encode(spec)
-	if err != nil {
-		t.Fatalf("failed to write spec: %v", err)
-	}
-
-	err = specFile.Close()
-	if err != nil {
-		t.Fatalf("failed to close file: %v", err)
-	}
-
-	ctx := context.Background()
-	ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
-	defer cancel()
-
-	args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()}
-	wstest := exec.CommandContext(ctx, "wstest", args...)
-	out, err := wstest.CombinedOutput()
-	if err != nil {
-		t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out)
 	}
-
-	checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
 }
 
-func unusedListenAddr() (string, error) {
-	l, err := net.Listen("tcp", "localhost:0")
-	if err != nil {
-		return "", err
-	}
-	l.Close()
-	return l.Addr().String(), nil
-}
-
-// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py
-func testClientAutobahnPython(t *testing.T) {
+func TestConn(t *testing.T) {
 	t.Parallel()
 
-	if os.Getenv("AUTOBAHN_PYTHON") == "" {
-		t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite")
-	}
-
-	serverAddr, err := unusedListenAddr()
-	if err != nil {
-		t.Fatalf("failed to get unused listen addr for wstest: %v", err)
-	}
-
-	wsServerURL := "ws://" + serverAddr
-
-	spec := map[string]interface{}{
-		"url":    wsServerURL,
-		"outdir": "ci/out/wstestClientReports",
-		"cases":  []string{"*"},
-		// See TestAutobahnServer for the reasons why we exclude these.
-		"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
-	}
-	specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json")
-	if err != nil {
-		t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err)
-	}
-	defer specFile.Close()
-
-	e := json.NewEncoder(specFile)
-	e.SetIndent("", "\t")
-	err = e.Encode(spec)
-	if err != nil {
-		t.Fatalf("failed to write spec: %v", err)
-	}
-
-	err = specFile.Close()
-	if err != nil {
-		t.Fatalf("failed to close file: %v", err)
-	}
-
-	ctx := context.Background()
-	ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
-	defer cancel()
-
-	args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(),
-		// Disables some server that runs as part of fuzzingserver mode.
-		// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
-		"--webport=0",
-	}
-	wstest := exec.CommandContext(ctx, "wstest", args...)
-	err = wstest.Start()
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer func() {
-		err := wstest.Process.Kill()
-		if err != nil {
-			t.Error(err)
-		}
-	}()
-
-	// Let it come up.
-	time.Sleep(time.Second * 5)
-
-	var cases int
-	func() {
-		c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil)
-		if err != nil {
-			t.Fatal(err)
-		}
-		defer c.Close(websocket.StatusInternalError, "")
-
-		_, r, err := c.Reader(ctx)
-		if err != nil {
-			t.Fatal(err)
-		}
-		b, err := ioutil.ReadAll(r)
-		if err != nil {
-			t.Fatal(err)
-		}
-		cases, err = strconv.Atoi(string(b))
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		c.Close(websocket.StatusNormalClosure, "")
-	}()
-
-	for i := 1; i <= cases; i++ {
-		func() {
-			ctx, cancel := context.WithTimeout(ctx, time.Second*45)
-			defer cancel()
-
-			c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil)
-			if err != nil {
-				t.Fatal(err)
-			}
-			wsecho.Loop(ctx, c)
-		}()
-	}
-
-	c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil)
-	if err != nil {
-		t.Fatal(err)
-	}
-	c.Close(websocket.StatusNormalClosure, "")
+	t.Run("json", func(t *testing.T) {
+		s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
+			c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
+				Subprotocols:       []string{"echo"},
+				InsecureSkipVerify: true,
+			})
+			assert.Success(t, err)
+			defer c.Close(websocket.StatusInternalError, "")
 
-	checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
-}
+			err = echoLoop(r.Context(), c)
+			assertCloseStatus(t, websocket.StatusNormalClosure, err)
+		}, false)
+		defer closeFn()
 
-func checkWSTestIndex(t *testing.T, path string) {
-	wstestOut, err := ioutil.ReadFile(path)
-	if err != nil {
-		t.Fatalf("failed to read index.json: %v", err)
-	}
+		wsURL := strings.Replace(s.URL, "http", "ws", 1)
 
-	var indexJSON map[string]map[string]struct {
-		Behavior      string `json:"behavior"`
-		BehaviorClose string `json:"behaviorClose"`
-	}
-	err = json.Unmarshal(wstestOut, &indexJSON)
-	if err != nil {
-		t.Fatalf("failed to unmarshal index.json: %v", err)
-	}
+		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+		defer cancel()
 
-	var failed bool
-	for _, tests := range indexJSON {
-		for test, result := range tests {
-			switch result.Behavior {
-			case "OK", "NON-STRICT", "INFORMATIONAL":
-			default:
-				failed = true
-				t.Errorf("test %v failed", test)
-			}
-			switch result.BehaviorClose {
-			case "OK", "INFORMATIONAL":
-			default:
-				failed = true
-				t.Errorf("bad close behaviour for test %v", test)
-			}
-		}
-	}
-
-	if failed {
-		path = strings.Replace(path, ".json", ".html", 1)
-		if os.Getenv("CI") == "" {
-			t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path)
-		}
-	}
-}
-
-func TestWASM(t *testing.T) {
-	t.Parallel()
-
-	s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error {
-		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-			Subprotocols:       []string{"echo"},
-			InsecureSkipVerify: true,
-		})
-		if err != nil {
-			return err
-		}
-		defer c.Close(websocket.StatusInternalError, "")
-
-		err = wsecho.Loop(r.Context(), c)
-		if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
-			return err
+		opts := &websocket.DialOptions{
+			Subprotocols: []string{"echo"},
 		}
-		return nil
-	}, false)
-	defer closeFn()
+		opts.HTTPClient = s.Client()
 
-	wsURL := strings.Replace(s.URL, "http", "ws", 1)
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-	defer cancel()
+		c, _, err := websocket.Dial(ctx, wsURL, opts)
+		assert.Success(t, err)
 
-	cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...")
-	cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL))
-
-	b, err := cmd.CombinedOutput()
-	if err != nil {
-		t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
-	}
+		assertJSONEcho(t, ctx, c, 2)
+	})
 }
diff --git a/dial.go b/dial.go
index 1008868..8fa0f7a 100644
--- a/dial.go
+++ b/dial.go
@@ -1,17 +1,19 @@
 package websocket
 
 import (
+	"bufio"
 	"bytes"
 	"context"
 	"crypto/rand"
 	"encoding/base64"
+	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
 	"net/http"
 	"net/url"
-	"nhooyr.io/websocket/internal/bufpool"
 	"strings"
+	"sync"
 )
 
 // DialOptions represents the options available to pass to Dial.
@@ -50,7 +52,7 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
 	return c, r, nil
 }
 
-func (opts *DialOptions) fill() (*DialOptions, error) {
+func (opts *DialOptions) ensure() *DialOptions {
 	if opts == nil {
 		opts = &DialOptions{}
 	} else {
@@ -60,20 +62,18 @@ func (opts *DialOptions) fill() (*DialOptions, error) {
 	if opts.HTTPClient == nil {
 		opts.HTTPClient = http.DefaultClient
 	}
-	if opts.HTTPClient.Timeout > 0 {
-		return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
-	}
 	if opts.HTTPHeader == nil {
 		opts.HTTPHeader = http.Header{}
 	}
 
-	return opts, nil
+	return opts
 }
 
 func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) {
-	opts, err = opts.fill()
-	if err != nil {
-		return nil, nil, err
+	opts = opts.ensure()
+
+	if opts.HTTPClient.Timeout > 0 {
+		return nil, nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
 	}
 
 	parsedURL, err := url.Parse(u)
@@ -104,8 +104,10 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
 	if len(opts.Subprotocols) > 0 {
 		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
 	}
-	copts := opts.CompressionMode.opts()
-	copts.setHeader(req.Header)
+	if opts.CompressionMode != CompressionDisabled {
+		copts := opts.CompressionMode.opts()
+		copts.setHeader(req.Header)
+	}
 
 	resp, err := opts.HTTPClient.Do(req)
 	if err != nil {
@@ -121,7 +123,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
 		}
 	}()
 
-	copts, err = verifyServerResponse(req, resp, opts)
+	copts, err := verifyServerResponse(req, resp)
 	if err != nil {
 		return nil, resp, err
 	}
@@ -131,18 +133,14 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
 		return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc)
 	}
 
-	c := &Conn{
+	return newConn(connConfig{
 		subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
-		br:          bufpool.GetReader(rwc),
-		bw:          bufpool.GetWriter(rwc),
-		closer:      rwc,
+		rwc:         rwc,
 		client:      true,
 		copts:       copts,
-	}
-	c.extractBufioWriterBuf(rwc)
-	c.init()
-
-	return c, resp, nil
+		br:          getBufioReader(rwc),
+		bw:          getBufioWriter(rwc),
+	}), resp, nil
 }
 
 func secWebSocketKey() (string, error) {
@@ -154,7 +152,7 @@ func secWebSocketKey() (string, error) {
 	return base64.StdEncoding.EncodeToString(b), nil
 }
 
-func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*compressionOptions, error) {
+func verifyServerResponse(r *http.Request, resp *http.Response) (*compressionOptions, error) {
 	if resp.StatusCode != http.StatusSwitchingProtocols {
 		return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
 	}
@@ -178,7 +176,7 @@ func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOption
 		return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
 	}
 
-	copts, err := verifyServerExtensions(resp.Header, opts.CompressionMode)
+	copts, err := verifyServerExtensions(resp.Header)
 	if err != nil {
 		return nil, err
 	}
@@ -186,7 +184,7 @@ func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOption
 	return copts, nil
 }
 
-func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOptions, error) {
+func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
 	exts := websocketExtensions(h)
 	if len(exts) == 0 {
 		return nil, nil
@@ -201,7 +199,7 @@ func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOp
 		return nil, fmt.Errorf("unexpected extra extensions from server: %+v", exts[1:])
 	}
 
-	copts := mode.opts()
+	copts := &compressionOptions{}
 	for _, p := range ext.params {
 		switch p {
 		case "client_no_context_takeover":
@@ -217,3 +215,33 @@ func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOp
 
 	return copts, nil
 }
+
+var readerPool sync.Pool
+
+func getBufioReader(r io.Reader) *bufio.Reader {
+	br, ok := readerPool.Get().(*bufio.Reader)
+	if !ok {
+		return bufio.NewReader(r)
+	}
+	br.Reset(r)
+	return br
+}
+
+func putBufioReader(br *bufio.Reader) {
+	readerPool.Put(br)
+}
+
+var writerPool sync.Pool
+
+func getBufioWriter(w io.Writer) *bufio.Writer {
+	bw, ok := writerPool.Get().(*bufio.Writer)
+	if !ok {
+		return bufio.NewWriter(w)
+	}
+	bw.Reset(w)
+	return bw
+}
+
+func putBufioWriter(bw *bufio.Writer) {
+	writerPool.Put(bw)
+}
diff --git a/dial_test.go b/dial_test.go
index 391aa1c..5eeb904 100644
--- a/dial_test.go
+++ b/dial_test.go
@@ -140,7 +140,7 @@ func Test_verifyServerHandshake(t *testing.T) {
 				resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
 			}
 
-			_, err = verifyServerResponse(r, resp, &DialOptions{})
+			_, err = verifyServerResponse(r, resp)
 			if (err == nil) != tc.success {
 				t.Fatalf("unexpected error: %+v", err)
 			}
diff --git a/example_echo_test.go b/example_echo_test.go
index ecc9b97..16d003d 100644
--- a/example_echo_test.go
+++ b/example_echo_test.go
@@ -4,6 +4,7 @@ package websocket_test
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"log"
@@ -77,7 +78,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error {
 
 	if c.Subprotocol() != "echo" {
 		c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol")
-		return fmt.Errorf("client does not speak echo sub protocol")
+		return errors.New("client does not speak echo sub protocol")
 	}
 
 	l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10)
diff --git a/internal/wsframe/mask.go b/frame.go
similarity index 57%
rename from internal/wsframe/mask.go
rename to frame.go
index 2da4c11..0f10d55 100644
--- a/internal/wsframe/mask.go
+++ b/frame.go
@@ -1,11 +1,167 @@
-package wsframe
+package websocket
 
 import (
+	"bufio"
 	"encoding/binary"
+	"math"
 	"math/bits"
+	"nhooyr.io/websocket/internal/errd"
 )
 
-// Mask applies the WebSocket masking algorithm to p
+// opcode represents a WebSocket opcode.
+type opcode int
+
+// List at https://tools.ietf.org/html/rfc6455#section-11.8.
+const (
+	opContinuation opcode = iota
+	opText
+	opBinary
+	// 3 - 7 are reserved for further non-control frames.
+	_
+	_
+	_
+	_
+	_
+	opClose
+	opPing
+	opPong
+	// 11-16 are reserved for further control frames.
+)
+
+// header represents a WebSocket frame header.
+// See https://tools.ietf.org/html/rfc6455#section-5.2.
+type header struct {
+	fin    bool
+	rsv1   bool
+	rsv2   bool
+	rsv3   bool
+	opcode opcode
+
+	payloadLength int64
+
+	masked  bool
+	maskKey uint32
+}
+
+// readFrameHeader reads a header from the reader.
+// See https://tools.ietf.org/html/rfc6455#section-5.2.
+func readFrameHeader(r *bufio.Reader) (_ header, err error) {
+	defer errd.Wrap(&err, "failed to read frame header")
+
+	b, err := r.ReadByte()
+	if err != nil {
+		return header{}, err
+	}
+
+	var h header
+	h.fin = b&(1<<7) != 0
+	h.rsv1 = b&(1<<6) != 0
+	h.rsv2 = b&(1<<5) != 0
+	h.rsv3 = b&(1<<4) != 0
+
+	h.opcode = opcode(b & 0xf)
+
+	b, err = r.ReadByte()
+	if err != nil {
+		return header{}, err
+	}
+
+	h.masked = b&(1<<7) != 0
+
+	payloadLength := b &^ (1 << 7)
+	switch {
+	case payloadLength < 126:
+		h.payloadLength = int64(payloadLength)
+	case payloadLength == 126:
+		var pl uint16
+		err = binary.Read(r, binary.BigEndian, &pl)
+		h.payloadLength = int64(pl)
+	case payloadLength == 127:
+		err = binary.Read(r, binary.BigEndian, &h.payloadLength)
+	}
+	if err != nil {
+		return header{}, err
+	}
+
+	if h.masked {
+		err = binary.Read(r, binary.LittleEndian, &h.maskKey)
+		if err != nil {
+			return header{}, err
+		}
+	}
+
+	return h, nil
+}
+
+// maxControlPayload is the maximum length of a control frame payload.
+// See https://tools.ietf.org/html/rfc6455#section-5.5.
+const maxControlPayload = 125
+
+// writeFrameHeader writes the bytes of the header to w.
+// See https://tools.ietf.org/html/rfc6455#section-5.2
+func writeFrameHeader(h header, w *bufio.Writer) (err error) {
+	defer errd.Wrap(&err, "failed to write frame header")
+
+	var b byte
+	if h.fin {
+		b |= 1 << 7
+	}
+	if h.rsv1 {
+		b |= 1 << 6
+	}
+	if h.rsv2 {
+		b |= 1 << 5
+	}
+	if h.rsv3 {
+		b |= 1 << 4
+	}
+
+	b |= byte(h.opcode)
+
+	err = w.WriteByte(b)
+	if err != nil {
+		return err
+	}
+
+	lengthByte := byte(0)
+	if h.masked {
+		lengthByte |= 1 << 7
+	}
+
+	switch {
+	case h.payloadLength > math.MaxUint16:
+		lengthByte |= 127
+	case h.payloadLength > 125:
+		lengthByte |= 126
+	case h.payloadLength >= 0:
+		lengthByte |= byte(h.payloadLength)
+	}
+	err = w.WriteByte(lengthByte)
+	if err != nil {
+		return err
+	}
+
+	switch {
+	case h.payloadLength > math.MaxUint16:
+		err = binary.Write(w, binary.BigEndian, h.payloadLength)
+	case h.payloadLength > 125:
+		err = binary.Write(w, binary.BigEndian, uint16(h.payloadLength))
+	}
+	if err != nil {
+		return err
+	}
+
+	if h.masked {
+		err = binary.Write(w, binary.LittleEndian, h.maskKey)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// mask applies the WebSocket masking algorithm to p
 // with the given key.
 // See https://tools.ietf.org/html/rfc6455#section-5.3
 //
@@ -16,7 +172,7 @@ import (
 // to be in little endian.
 //
 // See https://github.com/golang/go/issues/31586
-func Mask(key uint32, b []byte) uint32 {
+func mask(key uint32, b []byte) uint32 {
 	if len(b) >= 8 {
 		key64 := uint64(key)<<32 | uint64(key)
 
diff --git a/internal/wsframe/mask_test.go b/frame_test.go
similarity index 51%
rename from internal/wsframe/mask_test.go
rename to frame_test.go
index fbd2989..0ed14ae 100644
--- a/internal/wsframe/mask_test.go
+++ b/frame_test.go
@@ -1,32 +1,108 @@
-package wsframe_test
+// +build !js
+
+package websocket
 
 import (
-	"crypto/rand"
+	"bufio"
+	"bytes"
 	"encoding/binary"
-	"github.com/gobwas/ws"
-	"github.com/google/go-cmp/cmp"
 	"math/bits"
-	"nhooyr.io/websocket/internal/wsframe"
+	"nhooyr.io/websocket/internal/assert"
 	"strconv"
 	"testing"
+	"time"
 	_ "unsafe"
+
+	"github.com/gobwas/ws"
+	_ "github.com/gorilla/websocket"
+	"math/rand"
 )
 
+func init() {
+	rand.Seed(time.Now().UnixNano())
+}
+
+func TestHeader(t *testing.T) {
+	t.Parallel()
+
+	t.Run("lengths", func(t *testing.T) {
+		t.Parallel()
+
+		lengths := []int{
+			124,
+			125,
+			126,
+			127,
+
+			65534,
+			65535,
+			65536,
+			65537,
+		}
+
+		for _, n := range lengths {
+			n := n
+			t.Run(strconv.Itoa(n), func(t *testing.T) {
+				t.Parallel()
+
+				testHeader(t, header{
+					payloadLength: int64(n),
+				})
+			})
+		}
+	})
+
+	t.Run("fuzz", func(t *testing.T) {
+		t.Parallel()
+
+		randBool := func() bool {
+			return rand.Intn(1) == 0
+		}
+
+		for i := 0; i < 10000; i++ {
+			h := header{
+				fin:    randBool(),
+				rsv1:   randBool(),
+				rsv2:   randBool(),
+				rsv3:   randBool(),
+				opcode: opcode(rand.Intn(16)),
+
+				masked:        randBool(),
+				maskKey:       rand.Uint32(),
+				payloadLength: rand.Int63(),
+			}
+
+			testHeader(t, h)
+		}
+	})
+}
+
+func testHeader(t *testing.T, h header) {
+	b := &bytes.Buffer{}
+	w := bufio.NewWriter(b)
+	r := bufio.NewReader(b)
+
+	err := writeFrameHeader(h, w)
+	assert.Success(t, err)
+	err = w.Flush()
+	assert.Success(t, err)
+
+	h2, err := readFrameHeader(r)
+	assert.Success(t, err)
+
+	assert.Equalf(t, h, h2, "written and read headers differ")
+}
+
 func Test_mask(t *testing.T) {
 	t.Parallel()
 
 	key := []byte{0xa, 0xb, 0xc, 0xff}
 	key32 := binary.LittleEndian.Uint32(key)
 	p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
-	gotKey32 := wsframe.Mask(key32, p)
+	gotKey32 := mask(key32, p)
 
-	if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) {
-		t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p))
-	}
-
-	if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) {
-		t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32))
-	}
+	assert.Equalf(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "unexpected mask")
+	assert.Equalf(t, bits.RotateLeft32(key32, -8), gotKey32, "unexpected mask key")
 }
 
 func basicMask(maskKey [4]byte, pos int, b []byte) int {
@@ -74,7 +150,7 @@ func Benchmark_mask(b *testing.B) {
 				b.ResetTimer()
 
 				for i := 0; i < b.N; i++ {
-					wsframe.Mask(key32, p)
+					mask(key32, p)
 				}
 			},
 		},
@@ -98,9 +174,7 @@ func Benchmark_mask(b *testing.B) {
 
 	var key [4]byte
 	_, err := rand.Read(key[:])
-	if err != nil {
-		b.Fatalf("failed to populate mask key: %v", err)
-	}
+	assert.Success(b, err)
 
 	for _, size := range sizes {
 		p := make([]byte, size)
diff --git a/internal/assert/assert.go b/internal/assert/assert.go
index 372d546..1d9aece 100644
--- a/internal/assert/assert.go
+++ b/internal/assert/assert.go
@@ -2,6 +2,7 @@ package assert
 
 import (
 	"reflect"
+	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -53,7 +54,7 @@ func structTypes(v reflect.Value, m map[reflect.Type]struct{}) {
 	}
 }
 
-func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) {
+func Equalf(t testing.TB, exp, act interface{}, f string, v ...interface{}) {
 	t.Helper()
 	diff := cmpDiff(exp, act)
 	if diff != "" {
@@ -61,7 +62,40 @@ func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) {
 	}
 }
 
-func Success(t *testing.T, err error) {
+func NotEqualf(t testing.TB, exp, act interface{}, f string, v ...interface{}) {
 	t.Helper()
-	Equalf(t, error(nil), err, "unexpected failure")
+	diff := cmpDiff(exp, act)
+	if diff == "" {
+		t.Fatalf(f+": %v", append(v, diff)...)
+	}
+}
+
+func Success(t testing.TB, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatalf("unexpected error: %+v", err)
+	}
+}
+
+func Error(t testing.TB, err error) {
+	t.Helper()
+	if err == nil {
+		t.Fatal("expected error")
+	}
+}
+
+func ErrorContains(t testing.TB, err error, sub string) {
+	t.Helper()
+	Error(t, err)
+	errs := err.Error()
+	if !strings.Contains(errs, sub) {
+		t.Fatalf("error string %q does not contain %q", errs, sub)
+	}
+}
+
+func Panicf(t testing.TB, f string, v ...interface{}) {
+	r := recover()
+	if r == nil {
+		t.Fatalf(f, v...)
+	}
 }
diff --git a/internal/atomicint/atomicint.go b/internal/atomicint/atomicint.go
deleted file mode 100644
index 668b3b4..0000000
--- a/internal/atomicint/atomicint.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package atomicint
-
-import (
-	"fmt"
-	"sync/atomic"
-)
-
-// See https://github.com/nhooyr/websocket/issues/153
-type Int64 struct {
-	v int64
-}
-
-func (v *Int64) Load() int64 {
-	return atomic.LoadInt64(&v.v)
-}
-
-func (v *Int64) Store(i int64) {
-	atomic.StoreInt64(&v.v, i)
-}
-
-func (v *Int64) String() string {
-	return fmt.Sprint(v.Load())
-}
-
-// Increment increments the value and returns the new value.
-func (v *Int64) Increment(delta int64) int64 {
-	return atomic.AddInt64(&v.v, delta)
-}
-
-func (v *Int64) CAS(old, new int64) (swapped bool) {
-	return atomic.CompareAndSwapInt64(&v.v, old, new)
-}
diff --git a/internal/bufpool/buf.go b/internal/bufpool/buf.go
index 324a17e..0f7d976 100644
--- a/internal/bufpool/buf.go
+++ b/internal/bufpool/buf.go
@@ -5,12 +5,12 @@ import (
 	"sync"
 )
 
-var bpool sync.Pool
+var pool sync.Pool
 
 // Get returns a buffer from the pool or creates a new one if
 // the pool is empty.
 func Get() *bytes.Buffer {
-	b, ok := bpool.Get().(*bytes.Buffer)
+	b, ok := pool.Get().(*bytes.Buffer)
 	if !ok {
 		b = &bytes.Buffer{}
 	}
@@ -20,5 +20,5 @@ func Get() *bytes.Buffer {
 // Put returns a buffer into the pool.
 func Put(b *bytes.Buffer) {
 	b.Reset()
-	bpool.Put(b)
+	pool.Put(b)
 }
diff --git a/internal/bufpool/bufio.go b/internal/bufpool/bufio.go
deleted file mode 100644
index 875bbf4..0000000
--- a/internal/bufpool/bufio.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package bufpool
-
-import (
-	"bufio"
-	"io"
-	"sync"
-)
-
-var readerPool = sync.Pool{
-	New: func() interface{} {
-		return bufio.NewReader(nil)
-	},
-}
-
-func GetReader(r io.Reader) *bufio.Reader {
-	br := readerPool.Get().(*bufio.Reader)
-	br.Reset(r)
-	return br
-}
-
-func PutReader(br *bufio.Reader) {
-	readerPool.Put(br)
-}
-
-var writerPool = sync.Pool{
-	New: func() interface{} {
-		return bufio.NewWriter(nil)
-	},
-}
-
-func GetWriter(w io.Writer) *bufio.Writer {
-	bw := writerPool.Get().(*bufio.Writer)
-	bw.Reset(w)
-	return bw
-}
-
-func PutWriter(bw *bufio.Writer) {
-	writerPool.Put(bw)
-}
-
diff --git a/internal/errd/errd.go b/internal/errd/errd.go
new file mode 100644
index 0000000..51b7b4f
--- /dev/null
+++ b/internal/errd/errd.go
@@ -0,0 +1,11 @@
+package errd
+
+import (
+	"fmt"
+)
+
+func Wrap(err *error, f string, v ...interface{}) {
+	if *err != nil {
+		*err = fmt.Errorf(f+ ": %w", append(v, *err)...)
+	}
+}
diff --git a/internal/wsecho/wsecho.go b/internal/wsecho/wsecho.go
deleted file mode 100644
index c408f07..0000000
--- a/internal/wsecho/wsecho.go
+++ /dev/null
@@ -1,55 +0,0 @@
-// +build !js
-
-package wsecho
-
-import (
-	"context"
-	"io"
-	"time"
-
-	"nhooyr.io/websocket"
-)
-
-// Loop echos every msg received from c until an error
-// occurs or the context expires.
-// The read limit is set to 1 << 30.
-func Loop(ctx context.Context, c *websocket.Conn) error {
-	defer c.Close(websocket.StatusInternalError, "")
-
-	c.SetReadLimit(1 << 30)
-
-	ctx, cancel := context.WithTimeout(ctx, time.Minute)
-	defer cancel()
-
-	b := make([]byte, 32<<10)
-	echo := func() error {
-		typ, r, err := c.Reader(ctx)
-		if err != nil {
-			return err
-		}
-
-		w, err := c.Writer(ctx, typ)
-		if err != nil {
-			return err
-		}
-
-		_, err = io.CopyBuffer(w, r, b)
-		if err != nil {
-			return err
-		}
-
-		err = w.Close()
-		if err != nil {
-			return err
-		}
-
-		return nil
-	}
-
-	for {
-		err := echo()
-		if err != nil {
-			return err
-		}
-	}
-}
diff --git a/internal/wsframe/frame.go b/internal/wsframe/frame.go
deleted file mode 100644
index 50ff8c1..0000000
--- a/internal/wsframe/frame.go
+++ /dev/null
@@ -1,194 +0,0 @@
-package wsframe
-
-import (
-	"encoding/binary"
-	"fmt"
-	"io"
-	"math"
-)
-
-// Opcode represents a WebSocket Opcode.
-type Opcode int
-
-// Opcode constants.
-const (
-	OpContinuation Opcode = iota
-	OpText
-	OpBinary
-	// 3 - 7 are reserved for further non-control frames.
-	_
-	_
-	_
-	_
-	_
-	OpClose
-	OpPing
-	OpPong
-	// 11-16 are reserved for further control frames.
-)
-
-func (o Opcode) Control() bool {
-	switch o {
-	case OpClose, OpPing, OpPong:
-		return true
-	}
-	return false
-}
-
-func (o Opcode) Data() bool {
-	switch o {
-	case OpText, OpBinary:
-		return true
-	}
-	return false
-}
-
-// First byte contains fin, rsv1, rsv2, rsv3.
-// Second byte contains mask flag and payload len.
-// Next 8 bytes are the maximum extended payload length.
-// Last 4 bytes are the mask key.
-// https://tools.ietf.org/html/rfc6455#section-5.2
-const maxHeaderSize = 1 + 1 + 8 + 4
-
-// Header represents a WebSocket frame Header.
-// See https://tools.ietf.org/html/rfc6455#section-5.2
-type Header struct {
-	Fin    bool
-	RSV1   bool
-	RSV2   bool
-	RSV3   bool
-	Opcode Opcode
-
-	PayloadLength int64
-
-	Masked  bool
-	MaskKey uint32
-}
-
-// bytes returns the bytes of the Header.
-// See https://tools.ietf.org/html/rfc6455#section-5.2
-func (h Header) Bytes(b []byte) []byte {
-	if b == nil {
-		b = make([]byte, maxHeaderSize)
-	}
-
-	b = b[:2]
-	b[0] = 0
-
-	if h.Fin {
-		b[0] |= 1 << 7
-	}
-	if h.RSV1 {
-		b[0] |= 1 << 6
-	}
-	if h.RSV2 {
-		b[0] |= 1 << 5
-	}
-	if h.RSV3 {
-		b[0] |= 1 << 4
-	}
-
-	b[0] |= byte(h.Opcode)
-
-	switch {
-	case h.PayloadLength < 0:
-		panic(fmt.Sprintf("websocket: invalid Header: negative length: %v", h.PayloadLength))
-	case h.PayloadLength <= 125:
-		b[1] = byte(h.PayloadLength)
-	case h.PayloadLength <= math.MaxUint16:
-		b[1] = 126
-		b = b[:len(b)+2]
-		binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.PayloadLength))
-	default:
-		b[1] = 127
-		b = b[:len(b)+8]
-		binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.PayloadLength))
-	}
-
-	if h.Masked {
-		b[1] |= 1 << 7
-		b = b[:len(b)+4]
-		binary.LittleEndian.PutUint32(b[len(b)-4:], h.MaskKey)
-	}
-
-	return b
-}
-
-func MakeReadHeaderBuf() []byte {
-	return make([]byte, maxHeaderSize-2)
-}
-
-// ReadHeader reads a Header from the reader.
-// See https://tools.ietf.org/html/rfc6455#section-5.2
-func ReadHeader(r io.Reader, b []byte) (Header, error) {
-	// We read the first two bytes first so that we know
-	// exactly how long the Header is.
-	b = b[:2]
-	_, err := io.ReadFull(r, b)
-	if err != nil {
-		return Header{}, err
-	}
-
-	var h Header
-	h.Fin = b[0]&(1<<7) != 0
-	h.RSV1 = b[0]&(1<<6) != 0
-	h.RSV2 = b[0]&(1<<5) != 0
-	h.RSV3 = b[0]&(1<<4) != 0
-
-	h.Opcode = Opcode(b[0] & 0xf)
-
-	var extra int
-
-	h.Masked = b[1]&(1<<7) != 0
-	if h.Masked {
-		extra += 4
-	}
-
-	payloadLength := b[1] &^ (1 << 7)
-	switch {
-	case payloadLength < 126:
-		h.PayloadLength = int64(payloadLength)
-	case payloadLength == 126:
-		extra += 2
-	case payloadLength == 127:
-		extra += 8
-	}
-
-	if extra == 0 {
-		return h, nil
-	}
-
-	b = b[:extra]
-	_, err = io.ReadFull(r, b)
-	if err != nil {
-		return Header{}, err
-	}
-
-	switch {
-	case payloadLength == 126:
-		h.PayloadLength = int64(binary.BigEndian.Uint16(b))
-		b = b[2:]
-	case payloadLength == 127:
-		h.PayloadLength = int64(binary.BigEndian.Uint64(b))
-		if h.PayloadLength < 0 {
-			return Header{}, fmt.Errorf("Header with negative payload length: %v", h.PayloadLength)
-		}
-		b = b[8:]
-	}
-
-	if h.Masked {
-		h.MaskKey = binary.LittleEndian.Uint32(b)
-	}
-
-	return h, nil
-}
-
-const MaxControlFramePayload = 125
-
-func ParseClosePayload(p []byte) (uint16, string, error) {
-	if len(p) < 2 {
-		return 0, "", fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
-	}
-
-	return binary.BigEndian.Uint16(p), string(p[2:]), nil
-}
diff --git a/internal/wsframe/frame_stringer.go b/internal/wsframe/frame_stringer.go
deleted file mode 100644
index b2e7f42..0000000
--- a/internal/wsframe/frame_stringer.go
+++ /dev/null
@@ -1,91 +0,0 @@
-// Code generated by "stringer -type=Opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT.
-
-package wsframe
-
-import "strconv"
-
-func _() {
-	// An "invalid array index" compiler error signifies that the constant values have changed.
-	// Re-run the stringer command to generate them again.
-	var x [1]struct{}
-	_ = x[OpContinuation-0]
-	_ = x[OpText-1]
-	_ = x[OpBinary-2]
-	_ = x[OpClose-8]
-	_ = x[OpPing-9]
-	_ = x[OpPong-10]
-}
-
-const (
-	_opcode_name_0 = "opContinuationopTextopBinary"
-	_opcode_name_1 = "opCloseopPingopPong"
-)
-
-var (
-	_opcode_index_0 = [...]uint8{0, 14, 20, 28}
-	_opcode_index_1 = [...]uint8{0, 7, 13, 19}
-)
-
-func (i Opcode) String() string {
-	switch {
-	case 0 <= i && i <= 2:
-		return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]]
-	case 8 <= i && i <= 10:
-		i -= 8
-		return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]]
-	default:
-		return "Opcode(" + strconv.FormatInt(int64(i), 10) + ")"
-	}
-}
-func _() {
-	// An "invalid array index" compiler error signifies that the constant values have changed.
-	// Re-run the stringer command to generate them again.
-	var x [1]struct{}
-	_ = x[MessageText-1]
-	_ = x[MessageBinary-2]
-}
-
-const _MessageType_name = "MessageTextMessageBinary"
-
-var _MessageType_index = [...]uint8{0, 11, 24}
-
-func (i MessageType) String() string {
-	i -= 1
-	if i < 0 || i >= MessageType(len(_MessageType_index)-1) {
-		return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")"
-	}
-	return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]]
-}
-func _() {
-	// An "invalid array index" compiler error signifies that the constant values have changed.
-	// Re-run the stringer command to generate them again.
-	var x [1]struct{}
-	_ = x[StatusNormalClosure-1000]
-	_ = x[StatusGoingAway-1001]
-	_ = x[StatusProtocolError-1002]
-	_ = x[StatusUnsupportedData-1003]
-	_ = x[statusReserved-1004]
-	_ = x[StatusNoStatusRcvd-1005]
-	_ = x[StatusAbnormalClosure-1006]
-	_ = x[StatusInvalidFramePayloadData-1007]
-	_ = x[StatusPolicyViolation-1008]
-	_ = x[StatusMessageTooBig-1009]
-	_ = x[StatusMandatoryExtension-1010]
-	_ = x[StatusInternalError-1011]
-	_ = x[StatusServiceRestart-1012]
-	_ = x[StatusTryAgainLater-1013]
-	_ = x[StatusBadGateway-1014]
-	_ = x[StatusTLSHandshake-1015]
-}
-
-const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake"
-
-var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312}
-
-func (i StatusCode) String() string {
-	i -= 1000
-	if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) {
-		return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")"
-	}
-	return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]]
-}
diff --git a/internal/wsframe/frame_test.go b/internal/wsframe/frame_test.go
deleted file mode 100644
index d6b66e7..0000000
--- a/internal/wsframe/frame_test.go
+++ /dev/null
@@ -1,157 +0,0 @@
-// +build !js
-
-package wsframe
-
-import (
-	"bytes"
-	"io"
-	"math/rand"
-	"strconv"
-	"testing"
-	"time"
-	_ "unsafe"
-
-	"github.com/google/go-cmp/cmp"
-	_ "github.com/gorilla/websocket"
-)
-
-func init() {
-	rand.Seed(time.Now().UnixNano())
-}
-
-func randBool() bool {
-	return rand.Intn(1) == 0
-}
-
-func TestHeader(t *testing.T) {
-	t.Parallel()
-
-	t.Run("eof", func(t *testing.T) {
-		t.Parallel()
-
-		testCases := []struct {
-			name  string
-			bytes []byte
-		}{
-			{
-				"start",
-				[]byte{0xff},
-			},
-			{
-				"middle",
-				[]byte{0xff, 0xff, 0xff},
-			},
-		}
-		for _, tc := range testCases {
-			tc := tc
-			t.Run(tc.name, func(t *testing.T) {
-				t.Parallel()
-
-				b := bytes.NewBuffer(tc.bytes)
-				_, err := ReadHeader(nil, b)
-				if io.ErrUnexpectedEOF != err {
-					t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err)
-				}
-			})
-		}
-	})
-
-	t.Run("writeNegativeLength", func(t *testing.T) {
-		t.Parallel()
-
-		defer func() {
-			r := recover()
-			if r == nil {
-				t.Fatal("failed to induce panic in writeHeader with negative payload length")
-			}
-		}()
-
-		Header{
-			PayloadLength: -1,
-		}.Bytes(nil)
-	})
-
-	t.Run("readNegativeLength", func(t *testing.T) {
-		t.Parallel()
-
-		b := Header{
-			PayloadLength: 1<<16 + 1,
-		}.Bytes(nil)
-
-		// Make length negative
-		b[2] |= 1 << 7
-
-		r := bytes.NewReader(b)
-		_, err := ReadHeader(nil, r)
-		if err == nil {
-			t.Fatalf("unexpected error value: %+v", err)
-		}
-	})
-
-	t.Run("lengths", func(t *testing.T) {
-		t.Parallel()
-
-		lengths := []int{
-			124,
-			125,
-			126,
-			4096,
-			16384,
-			65535,
-			65536,
-			65537,
-			131072,
-		}
-
-		for _, n := range lengths {
-			n := n
-			t.Run(strconv.Itoa(n), func(t *testing.T) {
-				t.Parallel()
-
-				testHeader(t, Header{
-					PayloadLength: int64(n),
-				})
-			})
-		}
-	})
-
-	t.Run("fuzz", func(t *testing.T) {
-		t.Parallel()
-
-		for i := 0; i < 10000; i++ {
-			h := Header{
-				Fin:    randBool(),
-				RSV1:   randBool(),
-				RSV2:   randBool(),
-				RSV3:   randBool(),
-				Opcode: Opcode(rand.Intn(1 << 4)),
-
-				Masked:        randBool(),
-				PayloadLength: rand.Int63(),
-			}
-
-			if h.Masked {
-				h.MaskKey = rand.Uint32()
-			}
-
-			testHeader(t, h)
-		}
-	})
-}
-
-func testHeader(t *testing.T, h Header) {
-	b := h.Bytes(nil)
-	r := bytes.NewReader(b)
-	h2, err := ReadHeader(r, nil)
-	if err != nil {
-		t.Logf("Header: %#v", h)
-		t.Logf("bytes: %b", b)
-		t.Fatalf("failed to read Header: %v", err)
-	}
-
-	if !cmp.Equal(h, h2, cmp.AllowUnexported(Header{})) {
-		t.Logf("Header: %#v", h)
-		t.Logf("bytes: %b", b)
-		t.Fatalf("parsed and read Header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(Header{})))
-	}
-}
diff --git a/internal/wsgrace/wsgrace.go b/internal/wsgrace/wsgrace.go
deleted file mode 100644
index 513af1f..0000000
--- a/internal/wsgrace/wsgrace.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package wsgrace
-
-import (
-	"context"
-	"fmt"
-	"net/http"
-	"sync/atomic"
-	"time"
-)
-
-// Grace wraps s.Handler to gracefully shutdown WebSocket connections.
-// The returned function must be used to close the server instead of s.Close.
-func Grace(s *http.Server) (closeFn func() error) {
-	h := s.Handler
-	var conns int64
-	s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		atomic.AddInt64(&conns, 1)
-		defer atomic.AddInt64(&conns, -1)
-
-		ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
-		defer cancel()
-
-		r = r.WithContext(ctx)
-
-		h.ServeHTTP(w, r)
-	})
-
-	return func() error {
-		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-		defer cancel()
-
-		err := s.Shutdown(ctx)
-		if err != nil {
-			return fmt.Errorf("server shutdown failed: %v", err)
-		}
-
-		t := time.NewTicker(time.Millisecond * 10)
-		defer t.Stop()
-		for {
-			select {
-			case <-t.C:
-				if atomic.LoadInt64(&conns) == 0 {
-					return nil
-				}
-			case <-ctx.Done():
-				return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err())
-			}
-		}
-	}
-}
diff --git a/js_test.go b/js_test.go
deleted file mode 100644
index 80af789..0000000
--- a/js_test.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package websocket_test
-
-import (
-	"context"
-	"fmt"
-	"net/http"
-	"nhooyr.io/websocket/internal/wsecho"
-	"os"
-	"os/exec"
-	"strings"
-	"testing"
-	"time"
-
-	"nhooyr.io/websocket"
-)
-
-func TestJS(t *testing.T) {
-	t.Parallel()
-
-	s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error {
-		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-			Subprotocols:       []string{"echo"},
-			InsecureSkipVerify: true,
-		})
-		if err != nil {
-			return err
-		}
-		defer c.Close(websocket.StatusInternalError, "")
-
-		err = wsecho.Loop(r.Context(), c)
-		if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
-			return err
-		}
-		return nil
-	}, false)
-	defer closeFn()
-
-	wsURL := strings.Replace(s.URL, "http", "ws", 1)
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-	defer cancel()
-
-	cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...")
-	cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL))
-
-	b, err := cmd.CombinedOutput()
-	if err != nil {
-		t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
-	}
-}
diff --git a/read.go b/read.go
new file mode 100644
index 0000000..97096f7
--- /dev/null
+++ b/read.go
@@ -0,0 +1,479 @@
+package websocket
+
+import (
+	"bufio"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"nhooyr.io/websocket/internal/errd"
+	"strings"
+	"sync/atomic"
+	"time"
+)
+
+// Reader waits until there is a WebSocket data message to read
+// from the connection.
+// It returns the type of the message and a reader to read it.
+// The passed context will also bound the reader.
+// Ensure you read to EOF otherwise the connection will hang.
+//
+// All returned errors will cause the connection
+// to be closed so you do not need to write your own error message.
+// This applies to the Read methods in the wsjson/wspb subpackages as well.
+//
+// You must read from the connection for control frames to be handled.
+// Thus if you expect messages to take a long time to be responded to,
+// you should handle such messages async to reading from the connection
+// to ensure control frames are promptly handled.
+//
+// If you do not expect any data messages from the peer, call CloseRead.
+//
+// Only one Reader may be open at a time.
+//
+// If you need a separate timeout on the Reader call and then the message
+// Read, use time.AfterFunc to cancel the context passed in early.
+// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
+// Most users should not need this.
+func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
+	typ, r, err := c.cr.reader(ctx)
+	if err != nil {
+		return 0, nil, fmt.Errorf("failed to get reader: %w", err)
+	}
+	return typ, r, nil
+}
+
+// Read is a convenience method to read a single message from the connection.
+//
+// See the Reader method to reuse buffers or for streaming.
+// The docs on Reader apply to this method as well.
+func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
+	typ, r, err := c.Reader(ctx)
+	if err != nil {
+		return 0, nil, err
+	}
+
+	b, err := ioutil.ReadAll(r)
+	return typ, b, err
+}
+
+// CloseRead will start a goroutine to read from the connection until it is closed or a data message
+// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
+// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
+// After calling this method, you cannot read any data messages from the connection.
+// The returned context will be cancelled when the connection is closed.
+//
+// Use this when you do not want to read data messages from the connection anymore but will
+// want to write messages to it.
+func (c *Conn) CloseRead(ctx context.Context) context.Context {
+	ctx, cancel := context.WithCancel(ctx)
+	go func() {
+		defer cancel()
+		c.Reader(ctx)
+		c.Close(StatusPolicyViolation, "unexpected data message")
+	}()
+	return ctx
+}
+
+// SetReadLimit sets the max number of bytes to read for a single message.
+// It applies to the Reader and Read methods.
+//
+// By default, the connection has a message read limit of 32768 bytes.
+//
+// When the limit is hit, the connection will be closed with StatusMessageTooBig.
+func (c *Conn) SetReadLimit(n int64) {
+	c.cr.mr.lr.limit.Store(n)
+}
+
+type connReader struct {
+	c       *Conn
+	br      *bufio.Reader
+	timeout chan context.Context
+
+	mu                mu
+	controlPayloadBuf [maxControlPayload]byte
+	mr                *msgReader
+}
+
+func (cr *connReader) init(c *Conn, br *bufio.Reader) {
+	cr.c = c
+	cr.br = br
+	cr.timeout = make(chan context.Context)
+
+	cr.mr = &msgReader{
+		cr:  cr,
+		fin: true,
+	}
+
+	cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768)
+	if c.deflateNegotiated() && cr.contextTakeover() {
+		cr.ensureFlateReader()
+	}
+}
+
+func (cr *connReader) ensureFlateReader() {
+	cr.mr.fr = getFlateReader(readerFunc(cr.mr.read))
+	cr.mr.lr.reset(cr.mr.fr)
+}
+
+func (cr *connReader) close() {
+	cr.mu.Lock(context.Background())
+	if cr.c.client {
+		putBufioReader(cr.br)
+	}
+	if cr.c.deflateNegotiated() && cr.contextTakeover() {
+		putFlateReader(cr.mr.fr)
+	}
+}
+
+func (cr *connReader) contextTakeover() bool {
+	if cr.c.client {
+		return cr.c.copts.serverNoContextTakeover
+	}
+	return cr.c.copts.clientNoContextTakeover
+}
+
+func (cr *connReader) rsv1Illegal(h header) bool {
+	// If compression is enabled, rsv1 is always illegal.
+	if !cr.c.deflateNegotiated() {
+		return true
+	}
+	// rsv1 is only allowed on data frames beginning messages.
+	if h.opcode != opText && h.opcode != opBinary {
+		return true
+	}
+	return false
+}
+
+func (cr *connReader) loop(ctx context.Context) (header, error) {
+	for {
+		h, err := cr.frameHeader(ctx)
+		if err != nil {
+			return header{}, err
+		}
+
+		if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 {
+			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
+			cr.c.cw.error(StatusProtocolError, err)
+			return header{}, err
+		}
+
+		if !cr.c.client && !h.masked {
+			return header{}, errors.New("received unmasked frame from client")
+		}
+
+		switch h.opcode {
+		case opClose, opPing, opPong:
+			err = cr.control(ctx, h)
+			if err != nil {
+				// Pass through CloseErrors when receiving a close frame.
+				if h.opcode == opClose && CloseStatus(err) != -1 {
+					return header{}, err
+				}
+				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
+			}
+		case opContinuation, opText, opBinary:
+			return h, nil
+		default:
+			err := fmt.Errorf("received unknown opcode %v", h.opcode)
+			cr.c.cw.error(StatusProtocolError, err)
+			return header{}, err
+		}
+	}
+}
+
+func (cr *connReader) frameHeader(ctx context.Context) (header, error) {
+	select {
+	case <-cr.c.closed:
+		return header{}, cr.c.closeErr
+	case cr.timeout <- ctx:
+	}
+
+	h, err := readFrameHeader(cr.br)
+	if err != nil {
+		select {
+		case <-cr.c.closed:
+			return header{}, cr.c.closeErr
+		case <-ctx.Done():
+			return header{}, ctx.Err()
+		default:
+			cr.c.close(err)
+			return header{}, err
+		}
+	}
+
+	select {
+	case <-cr.c.closed:
+		return header{}, cr.c.closeErr
+	case cr.timeout <- context.Background():
+	}
+
+	return h, nil
+}
+
+func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) {
+	select {
+	case <-cr.c.closed:
+		return 0, cr.c.closeErr
+	case cr.timeout <- ctx:
+	}
+
+	n, err := io.ReadFull(cr.br, p)
+	if err != nil {
+		select {
+		case <-cr.c.closed:
+			return n, cr.c.closeErr
+		case <-ctx.Done():
+			return n, ctx.Err()
+		default:
+			err = fmt.Errorf("failed to read frame payload: %w", err)
+			cr.c.close(err)
+			return n, err
+		}
+	}
+
+	select {
+	case <-cr.c.closed:
+		return n, cr.c.closeErr
+	case cr.timeout <- context.Background():
+	}
+
+	return n, err
+}
+
+func (cr *connReader) control(ctx context.Context, h header) error {
+	if h.payloadLength < 0 {
+		err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength)
+		cr.c.cw.error(StatusProtocolError, err)
+		return err
+	}
+
+	if h.payloadLength > maxControlPayload {
+		err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength)
+		cr.c.cw.error(StatusProtocolError, err)
+		return err
+	}
+
+	if !h.fin {
+		err := errors.New("received fragmented control frame")
+		cr.c.cw.error(StatusProtocolError, err)
+		return err
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
+	defer cancel()
+
+	b := cr.controlPayloadBuf[:h.payloadLength]
+	_, err := cr.framePayload(ctx, b)
+	if err != nil {
+		return err
+	}
+
+	if h.masked {
+		mask(h.maskKey, b)
+	}
+
+	switch h.opcode {
+	case opPing:
+		return cr.c.cw.control(ctx, opPong, b)
+	case opPong:
+		cr.c.activePingsMu.Lock()
+		pong, ok := cr.c.activePings[string(b)]
+		cr.c.activePingsMu.Unlock()
+		if ok {
+			close(pong)
+		}
+		return nil
+	}
+
+	ce, err := parseClosePayload(b)
+	if err != nil {
+		err = fmt.Errorf("received invalid close payload: %w", err)
+		cr.c.cw.error(StatusProtocolError, err)
+		return err
+	}
+
+	err = fmt.Errorf("received close frame: %w", ce)
+	cr.c.setCloseErr(err)
+	cr.c.cw.control(context.Background(), opClose, ce.bytes())
+	return err
+}
+
+func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) {
+	err := cr.mu.Lock(ctx)
+	if err != nil {
+		return 0, nil, err
+	}
+	defer cr.mu.Unlock()
+
+	if !cr.mr.fin {
+		return 0, nil, errors.New("previous message not read to completion")
+	}
+
+	h, err := cr.loop(ctx)
+	if err != nil {
+		return 0, nil, err
+	}
+
+	if h.opcode == opContinuation {
+		err := errors.New("received continuation frame without text or binary frame")
+		cr.c.cw.error(StatusProtocolError, err)
+		return 0, nil, err
+	}
+
+	cr.mr.reset(ctx, h)
+
+	return MessageType(h.opcode), cr.mr, nil
+}
+
+type msgReader struct {
+	cr *connReader
+	fr io.Reader
+	lr *limitReader
+
+	ctx context.Context
+
+	deflate     bool
+	deflateTail strings.Reader
+
+	payloadLength int64
+	maskKey       uint32
+	fin           bool
+}
+
+func (mr *msgReader) reset(ctx context.Context, h header) {
+	mr.ctx = ctx
+	mr.deflate = h.rsv1
+	if mr.deflate {
+		mr.deflateTail.Reset(deflateMessageTail)
+		if !mr.cr.contextTakeover() {
+			mr.cr.ensureFlateReader()
+		}
+	}
+	mr.setFrame(h)
+	mr.fin = false
+}
+
+func (mr *msgReader) setFrame(h header) {
+	mr.payloadLength = h.payloadLength
+	mr.maskKey = h.maskKey
+	mr.fin = h.fin
+}
+
+func (mr *msgReader) Read(p []byte) (_ int, err error) {
+	defer func() {
+		errd.Wrap(&err, "failed to read")
+		if errors.Is(err, io.EOF) {
+			err = io.EOF
+		}
+	}()
+
+	err = mr.cr.mu.Lock(mr.ctx)
+	if err != nil {
+		return 0, err
+	}
+	defer mr.cr.mu.Unlock()
+
+	if mr.payloadLength == 0 && mr.fin {
+		if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() {
+			if mr.fr != nil {
+				putFlateReader(mr.fr)
+				mr.fr = nil
+			}
+		}
+		return 0, io.EOF
+	}
+
+	return mr.lr.Read(p)
+}
+
+func (mr *msgReader) read(p []byte) (int, error) {
+	log.Println("compress", mr.deflate)
+
+	if mr.payloadLength == 0 {
+		h, err := mr.cr.loop(mr.ctx)
+		if err != nil {
+			return 0, err
+		}
+		if h.opcode != opContinuation {
+			err := errors.New("received new data message without finishing the previous message")
+			mr.cr.c.cw.error(StatusProtocolError, err)
+			return 0, err
+		}
+		mr.setFrame(h)
+	}
+
+	if int64(len(p)) > mr.payloadLength {
+		p = p[:mr.payloadLength]
+	}
+
+	n, err := mr.cr.framePayload(mr.ctx, p)
+	if err != nil {
+		return n, err
+	}
+
+	mr.payloadLength -= int64(n)
+
+	if !mr.cr.c.client {
+		mr.maskKey = mask(mr.maskKey, p)
+	}
+
+	return n, nil
+}
+
+type limitReader struct {
+	c     *Conn
+	r     io.Reader
+	limit atomicInt64
+	n     int64
+}
+
+func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
+	lr := &limitReader{
+		c: c,
+	}
+	lr.limit.Store(limit)
+	lr.reset(r)
+	return lr
+}
+
+func (lr *limitReader) reset(r io.Reader) {
+	lr.n = lr.limit.Load()
+	lr.r = r
+}
+
+func (lr *limitReader) Read(p []byte) (int, error) {
+	if lr.n <= 0 {
+		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
+		lr.c.cw.error(StatusMessageTooBig, err)
+		return 0, err
+	}
+
+	if int64(len(p)) > lr.n {
+		p = p[:lr.n]
+	}
+	n, err := lr.r.Read(p)
+	lr.n -= int64(n)
+	return n, err
+}
+
+type atomicInt64 struct {
+	i atomic.Value
+}
+
+func (v *atomicInt64) Load() int64 {
+	i, _ := v.i.Load().(int64)
+	return i
+}
+
+func (v *atomicInt64) Store(i int64) {
+	v.i.Store(i)
+}
+
+type readerFunc func(p []byte) (int, error)
+
+func (f readerFunc) Read(p []byte) (int, error) {
+	return f(p)
+}
diff --git a/reader.go b/reader.go
deleted file mode 100644
index fe71656..0000000
--- a/reader.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package websocket
-
-import (
-	"bufio"
-	"context"
-	"io"
-	"nhooyr.io/websocket/internal/atomicint"
-	"nhooyr.io/websocket/internal/wsframe"
-	"strings"
-)
-
-type reader struct {
-	// Acquired before performing any sort of read operation.
-	readLock chan struct{}
-
-	c *Conn
-
-	deflateReader io.Reader
-	br            *bufio.Reader
-
-	readClosed        *atomicint.Int64
-	readHeaderBuf     []byte
-	controlPayloadBuf []byte
-
-	msgCtx        context.Context
-	msgCompressed bool
-	frameHeader   wsframe.Header
-	frameMaskKey  uint32
-	frameEOF      bool
-	deflateTail   strings.Reader
-}
diff --git a/write.go b/write.go
new file mode 100644
index 0000000..5bb489b
--- /dev/null
+++ b/write.go
@@ -0,0 +1,348 @@
+package websocket
+
+import (
+	"bufio"
+	"compress/flate"
+	"context"
+	"crypto/rand"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"nhooyr.io/websocket/internal/errd"
+	"time"
+)
+
+// Writer returns a writer bounded by the context that will write
+// a WebSocket message of type dataType to the connection.
+//
+// You must close the writer once you have written the entire message.
+//
+// Only one writer can be open at a time, multiple calls will block until the previous writer
+// is closed.
+//
+// Never close the returned writer twice.
+func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
+	w, err := c.cw.writer(ctx, typ)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get writer: %w", err)
+	}
+	return w, nil
+}
+
+// Write writes a message to the connection.
+//
+// See the Writer method if you want to stream a message.
+//
+// If compression is disabled, then it is guaranteed to write the message
+// in a single frame.
+func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
+	_, err := c.cw.write(ctx, typ, p)
+	if err != nil {
+		return fmt.Errorf("failed to write msg: %w", err)
+	}
+	return nil
+}
+
+type connWriter struct {
+	c  *Conn
+	bw *bufio.Writer
+
+	writeBuf []byte
+
+	mw      *messageWriter
+	frameMu mu
+	h       header
+
+	timeout chan context.Context
+}
+
+func (cw *connWriter) init(c *Conn, bw *bufio.Writer) {
+	cw.c = c
+	cw.bw = bw
+
+	if cw.c.client {
+		cw.writeBuf = extractBufioWriterBuf(cw.bw, c.rwc)
+	}
+
+	cw.timeout = make(chan context.Context)
+
+	cw.mw = &messageWriter{
+		cw: cw,
+	}
+	cw.mw.tw = &trimLastFourBytesWriter{
+		w: writerFunc(cw.mw.write),
+	}
+	if cw.c.deflateNegotiated() && cw.mw.contextTakeover() {
+		cw.mw.ensureFlateWriter()
+	}
+}
+
+func (mw *messageWriter) ensureFlateWriter() {
+	mw.fw = getFlateWriter(mw.tw)
+}
+
+func (cw *connWriter) close() {
+	if cw.c.client {
+		cw.frameMu.Lock(context.Background())
+		putBufioWriter(cw.bw)
+	}
+	if cw.c.deflateNegotiated() && cw.mw.contextTakeover() {
+		cw.mw.mu.Lock(context.Background())
+		putFlateWriter(cw.mw.fw)
+	}
+}
+
+func (mw *messageWriter) contextTakeover() bool {
+	if mw.cw.c.client {
+		return mw.cw.c.copts.clientNoContextTakeover
+	}
+	return mw.cw.c.copts.serverNoContextTakeover
+}
+
+func (cw *connWriter) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
+	err := cw.mw.reset(ctx, typ)
+	if err != nil {
+		return nil, err
+	}
+	return cw.mw, nil
+}
+
+func (cw *connWriter) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
+	ww, err := cw.writer(ctx, typ)
+	if err != nil {
+		return 0, err
+	}
+
+	if !cw.c.deflateNegotiated() {
+		// Fast single frame path.
+		defer cw.mw.mu.Unlock()
+		return cw.frame(ctx, true, cw.mw.opcode, p)
+	}
+
+	n, err := ww.Write(p)
+	if err != nil {
+		return n, err
+	}
+
+	err = ww.Close()
+	return n, err
+}
+
+type messageWriter struct {
+	cw *connWriter
+
+	mu       mu
+	compress bool
+	tw       *trimLastFourBytesWriter
+	fw       *flate.Writer
+	ctx      context.Context
+	opcode   opcode
+	closed   bool
+}
+
+func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error {
+	err := mw.mu.Lock(ctx)
+	if err != nil {
+		return err
+	}
+
+	mw.closed = false
+	mw.ctx = ctx
+	mw.opcode = opcode(typ)
+	return nil
+}
+
+// Write writes the given bytes to the WebSocket connection.
+func (mw *messageWriter) Write(p []byte) (_ int, err error) {
+	defer errd.Wrap(&err, "failed to write")
+
+	if mw.closed {
+		return 0, errors.New("cannot use closed writer")
+	}
+
+	if mw.cw.c.deflateNegotiated() {
+		if !mw.compress {
+			if !mw.contextTakeover() {
+				mw.ensureFlateWriter()
+			}
+			mw.tw.reset()
+			mw.compress = true
+		}
+
+		return mw.fw.Write(p)
+	}
+
+	return mw.write(p)
+}
+
+func (mw *messageWriter) write(p []byte) (int, error) {
+	n, err := mw.cw.frame(mw.ctx, false, mw.opcode, p)
+	if err != nil {
+		return n, fmt.Errorf("failed to write data frame: %w", err)
+	}
+	mw.opcode = opContinuation
+	return n, nil
+}
+
+// Close flushes the frame to the connection.
+// This must be called for every messageWriter.
+func (mw *messageWriter) Close() (err error) {
+	defer errd.Wrap(&err, "failed to close writer")
+
+	if mw.closed {
+		return errors.New("cannot use closed writer")
+	}
+	mw.closed = true
+
+	if mw.cw.c.deflateNegotiated() {
+		err = mw.fw.Flush()
+		if err != nil {
+			return fmt.Errorf("failed to flush flate writer: %w", err)
+		}
+	}
+
+	_, err = mw.cw.frame(mw.ctx, true, mw.opcode, nil)
+	if err != nil {
+		return fmt.Errorf("failed to write fin frame: %w", err)
+	}
+
+	if mw.compress && !mw.contextTakeover() {
+		putFlateWriter(mw.fw)
+		mw.compress = false
+	}
+
+	mw.mu.Unlock()
+	return nil
+}
+
+func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) error {
+	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
+	defer cancel()
+
+	_, err := cw.frame(ctx, true, opcode, p)
+	if err != nil {
+		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
+	}
+	return nil
+}
+
+// frame handles all writes to the connection.
+func (cw *connWriter) frame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
+	err := cw.frameMu.Lock(ctx)
+	if err != nil {
+		return 0, err
+	}
+	defer cw.frameMu.Unlock()
+
+	select {
+	case <-cw.c.closed:
+		return 0, cw.c.closeErr
+	case cw.timeout <- ctx:
+	}
+
+	cw.h.fin = fin
+	cw.h.opcode = opcode
+	cw.h.masked = cw.c.client
+	cw.h.payloadLength = int64(len(p))
+
+	cw.h.rsv1 = false
+	if cw.mw.compress && (opcode == opText || opcode == opBinary) {
+		cw.h.rsv1 = true
+	}
+
+	if cw.h.masked {
+		err = binary.Read(rand.Reader, binary.LittleEndian, &cw.h.maskKey)
+		if err != nil {
+			return 0, fmt.Errorf("failed to generate masking key: %w", err)
+		}
+	}
+
+	err = writeFrameHeader(cw.h, cw.bw)
+	if err != nil {
+		return 0, err
+	}
+
+	n, err := cw.framePayload(p)
+	if err != nil {
+		return n, err
+	}
+
+	if cw.h.fin {
+		err = cw.bw.Flush()
+		if err != nil {
+			return n, fmt.Errorf("failed to flush: %w", err)
+		}
+	}
+
+	select {
+	case <-cw.c.closed:
+		return n, cw.c.closeErr
+	case cw.timeout <- context.Background():
+	}
+
+	return n, nil
+}
+
+func (cw *connWriter) framePayload(p []byte) (_ int, err error) {
+	defer errd.Wrap(&err, "failed to write frame payload")
+
+	if !cw.h.masked {
+		return cw.bw.Write(p)
+	}
+
+	var n int
+	maskKey := cw.h.maskKey
+	for len(p) > 0 {
+		// If the buffer is full, we need to flush.
+		if cw.bw.Available() == 0 {
+			err = cw.bw.Flush()
+			if err != nil {
+				return n, err
+			}
+		}
+
+		// Start of next write in the buffer.
+		i := cw.bw.Buffered()
+
+		j := len(p)
+		if j > cw.bw.Available() {
+			j = cw.bw.Available()
+		}
+
+		_, err := cw.bw.Write(p[:j])
+		if err != nil {
+			return n, err
+		}
+
+		maskKey = mask(maskKey, cw.writeBuf[i:cw.bw.Buffered()])
+
+		p = p[j:]
+		n += j
+	}
+
+	return n, nil
+}
+
+type writerFunc func(p []byte) (int, error)
+
+func (f writerFunc) Write(p []byte) (int, error) {
+	return f(p)
+}
+
+// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
+// and returns it.
+func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
+	var writeBuf []byte
+	bw.Reset(writerFunc(func(p2 []byte) (int, error) {
+		writeBuf = p2[:cap(p2)]
+		return len(p2), nil
+	}))
+
+	bw.WriteByte(0)
+	bw.Flush()
+
+	bw.Reset(w)
+
+	return writeBuf
+}
diff --git a/writer.go b/writer.go
deleted file mode 100644
index b31d57a..0000000
--- a/writer.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package websocket
-
-type writer struct {
-
-}
diff --git a/ws_js.go b/ws_js.go
index 4c06743..10ce0da 100644
--- a/ws_js.go
+++ b/ws_js.go
@@ -9,7 +9,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
-	"nhooyr.io/websocket/internal/atomicint"
+	"nhooyr.io/websocket/internal/wssync"
 	"reflect"
 	"runtime"
 	"sync"
@@ -24,10 +24,10 @@ type Conn struct {
 	ws wsjs.WebSocket
 
 	// read limit for a message in bytes.
-	msgReadLimit *atomicint.Int64
+	msgReadLimit *wssync.Int64
 
 	closingMu     sync.Mutex
-	isReadClosed  *atomicint.Int64
+	isReadClosed  *wssync.Int64
 	closeOnce     sync.Once
 	closed        chan struct{}
 	closeErrOnce  sync.Once
@@ -59,10 +59,10 @@ func (c *Conn) init() {
 	c.closed = make(chan struct{})
 	c.readSignal = make(chan struct{}, 1)
 
-	c.msgReadLimit = &atomicint.Int64{}
+	c.msgReadLimit = &wssync.Int64{}
 	c.msgReadLimit.Store(32768)
 
-	c.isReadClosed = &atomicint.Int64{}
+	c.isReadClosed = &wssync.Int64{}
 
 	c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
 		err := CloseError{
@@ -105,7 +105,7 @@ func (c *Conn) closeWithInternal() {
 // The maximum time spent waiting is bounded by the context.
 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
 	if c.isReadClosed.Load() == 1 {
-		return 0, nil, fmt.Errorf("websocket connection read closed")
+		return 0, nil, errors.New("websocket connection read closed")
 	}
 
 	typ, p, err := c.read(ctx)
diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go
index 9fa8b54..e818805 100644
--- a/wsjson/wsjson.go
+++ b/wsjson/wsjson.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"log"
 	"nhooyr.io/websocket"
 	"nhooyr.io/websocket/internal/bufpool"
 )
@@ -41,6 +42,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
 	err = json.Unmarshal(b.Bytes(), v)
 	if err != nil {
 		c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
+		log.Printf("%X", b.Bytes())
 		return fmt.Errorf("failed to unmarshal json: %w", err)
 	}
 
-- 
GitLab