From 43cb01eaf9fad1e2052a18b69b777db62820aae7 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Fri, 29 Nov 2019 00:00:52 -0500
Subject: [PATCH] Refactor read.go/write.go

---
 README.md                 |  43 +++---
 assert_test.go            |  13 +-
 close.go                  |  64 +++++----
 conn.go                   |  92 ++++++++++---
 conn_test.go              |   3 +-
 internal/assert/assert.go |   2 +-
 read.go                   | 266 ++++++++++++++++----------------------
 write.go                  | 215 +++++++++++++-----------------
 wsjson/wsjson.go          |   1 -
 9 files changed, 345 insertions(+), 354 deletions(-)

diff --git a/README.md b/README.md
index efb4a59..f0babdf 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ go get nhooyr.io/websocket
 - Concurrent writes
 - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close)
 - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper
-- [Pings](https://godoc.org/nhooyr.io/websocket#Conn.Ping)
+- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping)
 - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
 - Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm)
 
@@ -88,26 +88,27 @@ c.Close(websocket.StatusNormalClosure, "")
 [gorilla/websocket](https://github.com/gorilla/websocket) is a widely used and mature library.
 
 Advantages of nhooyr.io/websocket:
-  - Minimal and idiomatic API
-    - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side.
-  - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper
-  - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
-  - Full [context.Context](https://blog.golang.org/context) support
-  - Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing
-    - Will enable easy HTTP/2 support in the future
-    - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client.
-  - Concurrent writes
-  - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
-  - Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API
-    - gorilla/websocket requires registering a pong callback and then sending a Ping
-  - Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
-  - Transparent buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages
-  - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
-    - Gorilla's implementation depends on unsafe and is slower
-  - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
+
+- Minimal and idiomatic API
+  - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side.
+- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper
+- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
+- Full [context.Context](https://blog.golang.org/context) support
+- Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing
+  - Will enable easy HTTP/2 support in the future
+  - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client.
+- Concurrent writes
+- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
+- Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API
+  - gorilla/websocket requires registering a pong callback and then sending a Ping
+- Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
+- Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages
+- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
+  - Gorilla's implementation depends on unsafe and is slower
+- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
   - Gorilla only supports no context takeover mode
-  - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper
-  - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
+- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper
+- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
 
 #### golang.org/x/net/websocket
 
@@ -120,7 +121,7 @@ to nhooyr.io/websocket.
 #### gobwas/ws
 
 [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
-in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). 
+in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
 
 However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use.
 
diff --git a/assert_test.go b/assert_test.go
index b6e50a4..e431993 100644
--- a/assert_test.go
+++ b/assert_test.go
@@ -4,12 +4,11 @@ import (
 	"context"
 	"crypto/rand"
 	"io"
-	"strings"
-	"testing"
-
 	"nhooyr.io/websocket"
 	"nhooyr.io/websocket/internal/assert"
 	"nhooyr.io/websocket/wsjson"
+	"strings"
+	"testing"
 )
 
 func randBytes(t *testing.T, n int) []byte {
@@ -21,12 +20,15 @@ func randBytes(t *testing.T, n int) []byte {
 
 func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) {
 	t.Helper()
+	defer c.Close(websocket.StatusInternalError, "")
 
 	exp := randString(t, n)
 	err := wsjson.Write(ctx, c, exp)
 	assert.Success(t, err)
 
 	assertJSONRead(t, ctx, c, exp)
+
+	c.Close(websocket.StatusNormalClosure, "")
 }
 
 func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) {
@@ -74,5 +76,10 @@ func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) {
 
 func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) {
 	t.Helper()
+	defer func() {
+		if t.Failed() {
+			t.Logf("error: %+v", err)
+		}
+	}()
 	assert.Equal(t, exp, websocket.CloseStatus(err), "StatusCode")
 }
diff --git a/close.go b/close.go
index a02dc7d..4c474b7 100644
--- a/close.go
+++ b/close.go
@@ -7,9 +7,6 @@ import (
 	"fmt"
 	"log"
 	"nhooyr.io/websocket/internal/errd"
-	"time"
-
-	"nhooyr.io/websocket/internal/bpool"
 )
 
 // StatusCode represents a WebSocket status code.
@@ -103,59 +100,58 @@ func (c *Conn) Close(code StatusCode, reason string) error {
 func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
 	defer errd.Wrap(&err, "failed to close WebSocket")
 
-	err = c.cw.sendClose(code, reason)
+	err = c.writeClose(code, reason)
 	if err != nil {
 		return err
 	}
 
-	return c.cr.waitClose()
+	return c.waitClose()
 }
 
-func (cw *connWriter) error(code StatusCode, err error) {
-	cw.c.setCloseErr(err)
-	cw.sendClose(code, err.Error())
-	cw.c.closeWithErr(nil)
+func (c *Conn) writeError(code StatusCode, err error) {
+	c.setCloseErr(err)
+	c.writeClose(code, err.Error())
+	c.closeWithErr(nil)
 }
 
-func (cw *connWriter) sendClose(code StatusCode, reason string) error {
+func (c *Conn) writeClose(code StatusCode, reason string) error {
 	ce := CloseError{
 		Code:   code,
 		Reason: reason,
 	}
 
-	cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
+	c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
 
 	var p []byte
 	if ce.Code != StatusNoStatusRcvd {
 		p = ce.bytes()
 	}
 
-	return cw.control(context.Background(), opClose, p)
+	return c.writeControl(context.Background(), opClose, p)
 }
 
-func (cr *connReader) waitClose() error {
-	defer cr.c.closeWithErr(nil)
+func (c *Conn) waitClose() error {
+	defer c.closeWithErr(nil)
 
 	return nil
 
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-	defer cancel()
-
-	err := cr.mu.Lock(ctx)
-	if err != nil {
-		return err
-	}
-	defer cr.mu.Unlock()
-
-	b := bpool.Get()
-	buf := b.Bytes()
-	buf = buf[:cap(buf)]
-	defer bpool.Put(b)
-
-	for {
-		// TODO
-		return nil
-	}
+	// ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	// defer cancel()
+	//
+	// err := cr.mu.Lock(ctx)
+	// if err != nil {
+	// 	return err
+	// }
+	// defer cr.mu.Unlock()
+	//
+	// b := bpool.Get()
+	// buf := b.Bytes()
+	// buf = buf[:cap(buf)]
+	// defer bpool.Put(b)
+	//
+	// for {
+	// 	return nil
+	// }
 }
 
 func parseClosePayload(p []byte) (CloseError, error) {
@@ -230,11 +226,11 @@ func (ce CloseError) bytesErr() ([]byte, error) {
 
 func (c *Conn) setCloseErr(err error) {
 	c.closeMu.Lock()
-	c.setCloseErrNoLock(err)
+	c.setCloseErrLocked(err)
 	c.closeMu.Unlock()
 }
 
-func (c *Conn) setCloseErrNoLock(err error) {
+func (c *Conn) setCloseErrLocked(err error) {
 	if c.closeErr == nil {
 		c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
 	}
diff --git a/conn.go b/conn.go
index d900179..dc067d1 100644
--- a/conn.go
+++ b/conn.go
@@ -30,7 +30,7 @@ const (
 // All methods may be called concurrently except for Reader and Read.
 //
 // You must always read from the connection. Otherwise control
-// frames will not be handled. See the docs on Reader and CloseRead.
+// frames will not be handled. See Reader and CloseRead.
 //
 // Be sure to call Close on the connection when you
 // are finished with it to release associated resources.
@@ -42,9 +42,22 @@ type Conn struct {
 	rwc         io.ReadWriteCloser
 	client      bool
 	copts       *compressionOptions
+	br          *bufio.Reader
+	bw          *bufio.Writer
 
-	cr connReader
-	cw connWriter
+	readTimeout  chan context.Context
+	writeTimeout chan context.Context
+
+	// Read state.
+	readMu         mu
+	readControlBuf [maxControlPayload]byte
+	msgReader      *msgReader
+
+	// Write state.
+	msgWriter    *msgWriter
+	writeFrameMu mu
+	writeBuf     []byte
+	writeHeader  header
 
 	closed chan struct{}
 
@@ -63,8 +76,8 @@ type connConfig struct {
 	client      bool
 	copts       *compressionOptions
 
-	bw *bufio.Writer
 	br *bufio.Reader
+	bw *bufio.Writer
 }
 
 func newConn(cfg connConfig) *Conn {
@@ -73,13 +86,23 @@ func newConn(cfg connConfig) *Conn {
 		rwc:         cfg.rwc,
 		client:      cfg.client,
 		copts:       cfg.copts,
+
+		br: cfg.br,
+		bw: cfg.bw,
+
+		readTimeout: make(chan context.Context),
+		writeTimeout: make(chan context.Context),
+
+		closed: make(chan struct{}),
+		activePings: make(map[string]chan<- struct{}),
 	}
 
-	c.cr.init(c, cfg.br)
-	c.cw.init(c, cfg.bw)
+	c.msgReader = newMsgReader(c)
 
-	c.closed = make(chan struct{})
-	c.activePings = make(map[string]chan<- struct{})
+	c.msgWriter = newMsgWriter(c)
+	if c.client {
+		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
+	}
 
 	runtime.SetFinalizer(c, func(c *Conn) {
 		c.closeWithErr(errors.New("connection garbage collected"))
@@ -90,6 +113,34 @@ func newConn(cfg connConfig) *Conn {
 	return c
 }
 
+func newMsgReader(c *Conn) *msgReader {
+	mr := &msgReader{
+		c:   c,
+		fin: true,
+	}
+
+	mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768)
+	if c.deflateNegotiated() && mr.contextTakeover() {
+		mr.ensureFlateReader()
+	}
+
+	return mr
+}
+
+func newMsgWriter(c *Conn) *msgWriter {
+	mw := &msgWriter{
+		c: c,
+	}
+	mw.trimWriter = &trimLastFourBytesWriter{
+		w: writerFunc(mw.write),
+	}
+	if c.deflateNegotiated() && mw.contextTakeover() {
+		mw.ensureFlateWriter()
+	}
+
+	return mw
+}
+
 // Subprotocol returns the negotiated subprotocol.
 // An empty string means the default protocol.
 func (c *Conn) Subprotocol() string {
@@ -105,7 +156,7 @@ func (c *Conn) closeWithErr(err error) {
 	}
 	close(c.closed)
 	runtime.SetFinalizer(c, nil)
-	c.setCloseErrNoLock(err)
+	c.setCloseErrLocked(err)
 
 	// Have to close after c.closed is closed to ensure any goroutine that wakes up
 	// from the connection being closed also sees that c.closed is closed and returns
@@ -113,8 +164,18 @@ func (c *Conn) closeWithErr(err error) {
 	c.rwc.Close()
 
 	go func() {
-		c.cr.close()
-		c.cw.close()
+		if c.client {
+			c.writeFrameMu.Lock(context.Background())
+			putBufioWriter(c.bw)
+		}
+		c.msgWriter.close()
+
+		if c.client {
+			c.readMu.Lock(context.Background())
+			putBufioReader(c.br)
+			c.readMu.Unlock()
+		}
+		c.msgReader.close()
 	}()
 }
 
@@ -127,13 +188,12 @@ func (c *Conn) timeoutLoop() {
 		case <-c.closed:
 			return
 
-		case writeCtx = <-c.cw.timeout:
-		case readCtx = <-c.cr.timeout:
+		case writeCtx = <-c.writeTimeout:
+		case readCtx = <-c.readTimeout:
 
 		case <-readCtx.Done():
 			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
-			c.cw.error(StatusPolicyViolation, errors.New("timed out"))
-			return
+			go c.writeError(StatusPolicyViolation, errors.New("timed out"))
 		case <-writeCtx.Done():
 			c.closeWithErr(fmt.Errorf("write timed out: %w", writeCtx.Err()))
 			return
@@ -175,7 +235,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
 		c.activePingsMu.Unlock()
 	}()
 
-	err := c.cw.control(ctx, opPing, []byte(p))
+	err := c.writeControl(ctx, opPing, []byte(p))
 	if err != nil {
 		return err
 	}
diff --git a/conn_test.go b/conn_test.go
index 6b8a778..cf2334f 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -25,6 +25,7 @@ func TestConn(t *testing.T) {
 			c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 				Subprotocols:       []string{"echo"},
 				InsecureSkipVerify: true,
+				// CompressionMode: websocket.CompressionDisabled,
 			})
 			assert.Success(t, err)
 			defer c.Close(websocket.StatusInternalError, "")
@@ -41,12 +42,12 @@ func TestConn(t *testing.T) {
 
 		opts := &websocket.DialOptions{
 			Subprotocols: []string{"echo"},
+			// CompressionMode: websocket.CompressionDisabled,
 		}
 		opts.HTTPClient = s.Client()
 
 		c, _, err := websocket.Dial(ctx, wsURL, opts)
 		assert.Success(t, err)
-
 		assertJSONEcho(t, ctx, c, 2)
 	})
 }
diff --git a/internal/assert/assert.go b/internal/assert/assert.go
index 4ebdb51..b448711 100644
--- a/internal/assert/assert.go
+++ b/internal/assert/assert.go
@@ -23,7 +23,7 @@ func NotEqual(t testing.TB, exp, act interface{}, name string) {
 func Success(t testing.TB, err error) {
 	t.Helper()
 	if err != nil {
-		t.Fatalf("unexpected error : %+v", err)
+		t.Fatalf("unexpected error: %+v", err)
 	}
 }
 
diff --git a/read.go b/read.go
index 7dba832..d8691d6 100644
--- a/read.go
+++ b/read.go
@@ -1,7 +1,6 @@
 package websocket
 
 import (
-	"bufio"
 	"context"
 	"errors"
 	"fmt"
@@ -14,41 +13,22 @@ import (
 	"nhooyr.io/websocket/internal/errd"
 )
 
-// Reader waits until there is a WebSocket data message to read
-// from the connection.
-// It returns the type of the message and a reader to read it.
+// Reader reads from the connection until until there is a WebSocket
+// data message to be read. It will handle ping, pong and close frames as appropriate.
+//
+// It returns the type of the message and an io.Reader to read it.
 // The passed context will also bound the reader.
 // Ensure you read to EOF otherwise the connection will hang.
 //
-// All returned errors will cause the connection
-// to be closed so you do not need to write your own error message.
-// This applies to the Read methods in the wsjson/wspb subpackages as well.
-//
-// You must read from the connection for control frames to be handled.
-// Thus if you expect messages to take a long time to be responded to,
-// you should handle such messages async to reading from the connection
-// to ensure control frames are promptly handled.
-//
-// If you do not expect any data messages from the peer, call CloseRead.
+// Call CloseRead if you do not expect any data messages from the peer.
 //
 // Only one Reader may be open at a time.
-//
-// If you need a separate timeout on the Reader call and then the message
-// Read, use time.AfterFunc to cancel the context passed in early.
-// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
-// Most users should not need this.
 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
-	typ, r, err := c.cr.reader(ctx)
-	if err != nil {
-		return 0, nil, fmt.Errorf("failed to get reader: %w", err)
-	}
-	return typ, r, nil
+	return c.reader(ctx)
 }
 
-// Read is a convenience method to read a single message from the connection.
-//
-// See the Reader method to reuse buffers or for streaming.
-// The docs on Reader apply to this method as well.
+// Read is a convenience method around Reader to read a single message
+// from the connection.
 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
 	typ, r, err := c.Reader(ctx)
 	if err != nil {
@@ -59,14 +39,17 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
 	return typ, b, err
 }
 
-// CloseRead will start a goroutine to read from the connection until it is closed or a data message
-// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
-// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
-// After calling this method, you cannot read any data messages from the connection.
+// CloseRead starts a goroutine to read from the connection until it is closed
+// or a data message is received.
+//
+// Once CloseRead is called you cannot read any messages from the connection.
 // The returned context will be cancelled when the connection is closed.
 //
-// Use this when you do not want to read data messages from the connection anymore but will
-// want to write messages to it.
+// If a data message is received, the connection will be closed with StatusPolicyViolation.
+//
+// Call CloseRead when you do not expect to read any more messages.
+// Since it actively reads from the connection, it will ensure that ping, pong and close
+// frames are responded to.
 func (c *Conn) CloseRead(ctx context.Context) context.Context {
 	ctx, cancel := context.WithCancel(ctx)
 	go func() {
@@ -84,60 +67,32 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
 //
 // When the limit is hit, the connection will be closed with StatusMessageTooBig.
 func (c *Conn) SetReadLimit(n int64) {
-	c.cr.mr.lr.limit.Store(n)
-}
-
-type connReader struct {
-	c       *Conn
-	br      *bufio.Reader
-	timeout chan context.Context
-
-	mu                mu
-	controlPayloadBuf [maxControlPayload]byte
-	mr                *msgReader
-}
-
-func (cr *connReader) init(c *Conn, br *bufio.Reader) {
-	cr.c = c
-	cr.br = br
-	cr.timeout = make(chan context.Context)
-
-	cr.mr = &msgReader{
-		cr:  cr,
-		fin: true,
-	}
-
-	cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768)
-	if c.deflateNegotiated() && cr.contextTakeover() {
-		cr.ensureFlateReader()
-	}
+	c.msgReader.limitReader.setLimit(n)
 }
 
-func (cr *connReader) ensureFlateReader() {
-	cr.mr.fr = getFlateReader(readerFunc(cr.mr.read))
-	cr.mr.lr.reset(cr.mr.fr)
+func (mr *msgReader) ensureFlateReader() {
+	mr.flateReader = getFlateReader(readerFunc(mr.read))
+	mr.limitReader.reset(mr.flateReader)
 }
 
-func (cr *connReader) close() {
-	cr.mu.Lock(context.Background())
-	if cr.c.client {
-		putBufioReader(cr.br)
-	}
-	if cr.c.deflateNegotiated() && cr.contextTakeover() {
-		putFlateReader(cr.mr.fr)
+func (mr *msgReader) close() {
+	if mr.c.deflateNegotiated() && mr.contextTakeover() {
+		mr.c.readMu.Lock(context.Background())
+		putFlateReader(mr.flateReader)
+		mr.c.readMu.Unlock()
 	}
 }
 
-func (cr *connReader) contextTakeover() bool {
-	if cr.c.client {
-		return cr.c.copts.serverNoContextTakeover
+func (mr *msgReader) contextTakeover() bool {
+	if mr.c.client {
+		return mr.c.copts.serverNoContextTakeover
 	}
-	return cr.c.copts.clientNoContextTakeover
+	return mr.c.copts.clientNoContextTakeover
 }
 
-func (cr *connReader) rsv1Illegal(h header) bool {
+func (c *Conn) readRSV1Illegal(h header) bool {
 	// If compression is enabled, rsv1 is always illegal.
-	if !cr.c.deflateNegotiated() {
+	if !c.deflateNegotiated() {
 		return true
 	}
 	// rsv1 is only allowed on data frames beginning messages.
@@ -147,26 +102,26 @@ func (cr *connReader) rsv1Illegal(h header) bool {
 	return false
 }
 
-func (cr *connReader) loop(ctx context.Context) (header, error) {
+func (c *Conn) readLoop(ctx context.Context) (header, error) {
 	for {
-		h, err := cr.frameHeader(ctx)
+		h, err := c.readFrameHeader(ctx)
 		if err != nil {
 			return header{}, err
 		}
 
-		if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 {
+		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
 			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
-			cr.c.cw.error(StatusProtocolError, err)
+			c.writeError(StatusProtocolError, err)
 			return header{}, err
 		}
 
-		if !cr.c.client && !h.masked {
+		if !c.client && !h.masked {
 			return header{}, errors.New("received unmasked frame from client")
 		}
 
 		switch h.opcode {
 		case opClose, opPing, opPong:
-			err = cr.control(ctx, h)
+			err = c.handleControl(ctx, h)
 			if err != nil {
 				// Pass through CloseErrors when receiving a close frame.
 				if h.opcode == opClose && CloseStatus(err) != -1 {
@@ -178,95 +133,89 @@ func (cr *connReader) loop(ctx context.Context) (header, error) {
 			return h, nil
 		default:
 			err := fmt.Errorf("received unknown opcode %v", h.opcode)
-			cr.c.cw.error(StatusProtocolError, err)
+			c.writeError(StatusProtocolError, err)
 			return header{}, err
 		}
 	}
 }
 
-func (cr *connReader) frameHeader(ctx context.Context) (header, error) {
+func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
 	select {
-	case <-cr.c.closed:
-		return header{}, cr.c.closeErr
-	case cr.timeout <- ctx:
+	case <-c.closed:
+		return header{}, c.closeErr
+	case c.readTimeout <- ctx:
 	}
 
-	h, err := readFrameHeader(cr.br)
+	h, err := readFrameHeader(c.br)
 	if err != nil {
 		select {
-		case <-cr.c.closed:
-			return header{}, cr.c.closeErr
+		case <-c.closed:
+			return header{}, c.closeErr
 		case <-ctx.Done():
 			return header{}, ctx.Err()
 		default:
-			cr.c.closeWithErr(err)
+			c.closeWithErr(err)
 			return header{}, err
 		}
 	}
 
 	select {
-	case <-cr.c.closed:
-		return header{}, cr.c.closeErr
-	case cr.timeout <- context.Background():
+	case <-c.closed:
+		return header{}, c.closeErr
+	case c.readTimeout <- context.Background():
 	}
 
 	return h, nil
 }
 
-func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) {
+func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
 	select {
-	case <-cr.c.closed:
-		return 0, cr.c.closeErr
-	case cr.timeout <- ctx:
+	case <-c.closed:
+		return 0, c.closeErr
+	case c.readTimeout <- ctx:
 	}
 
-	n, err := io.ReadFull(cr.br, p)
+	n, err := io.ReadFull(c.br, p)
 	if err != nil {
 		select {
-		case <-cr.c.closed:
-			return n, cr.c.closeErr
+		case <-c.closed:
+			return n, c.closeErr
 		case <-ctx.Done():
 			return n, ctx.Err()
 		default:
 			err = fmt.Errorf("failed to read frame payload: %w", err)
-			cr.c.closeWithErr(err)
+			c.closeWithErr(err)
 			return n, err
 		}
 	}
 
 	select {
-	case <-cr.c.closed:
-		return n, cr.c.closeErr
-	case cr.timeout <- context.Background():
+	case <-c.closed:
+		return n, c.closeErr
+	case c.readTimeout <- context.Background():
 	}
 
 	return n, err
 }
 
-func (cr *connReader) control(ctx context.Context, h header) error {
-	if h.payloadLength < 0 {
-		err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength)
-		cr.c.cw.error(StatusProtocolError, err)
-		return err
-	}
-
-	if h.payloadLength > maxControlPayload {
-		err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength)
-		cr.c.cw.error(StatusProtocolError, err)
+func (c *Conn) handleControl(ctx context.Context, h header) error {
+	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
+		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
+		c.writeError(StatusProtocolError, err)
 		return err
 	}
 
 	if !h.fin {
 		err := errors.New("received fragmented control frame")
-		cr.c.cw.error(StatusProtocolError, err)
+		c.writeError(StatusProtocolError, err)
 		return err
 	}
 
 	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
 	defer cancel()
 
-	b := cr.controlPayloadBuf[:h.payloadLength]
-	_, err := cr.framePayload(ctx, b)
+	b := c.readControlBuf[:h.payloadLength]
+	_, err := c.readFramePayload(ctx, b)
 	if err != nil {
 		return err
 	}
@@ -277,11 +226,11 @@ func (cr *connReader) control(ctx context.Context, h header) error {
 
 	switch h.opcode {
 	case opPing:
-		return cr.c.cw.control(ctx, opPong, b)
+		return c.writeControl(ctx, opPong, b)
 	case opPong:
-		cr.c.activePingsMu.Lock()
-		pong, ok := cr.c.activePings[string(b)]
-		cr.c.activePingsMu.Unlock()
+		c.activePingsMu.Lock()
+		pong, ok := c.activePings[string(b)]
+		c.activePingsMu.Unlock()
 		if ok {
 			close(pong)
 		}
@@ -291,53 +240,56 @@ func (cr *connReader) control(ctx context.Context, h header) error {
 	ce, err := parseClosePayload(b)
 	if err != nil {
 		err = fmt.Errorf("received invalid close payload: %w", err)
-		cr.c.cw.error(StatusProtocolError, err)
+		c.writeError(StatusProtocolError, err)
 		return err
 	}
 
 	err = fmt.Errorf("received close frame: %w", ce)
-	cr.c.setCloseErr(err)
-	cr.c.cw.control(context.Background(), opClose, ce.bytes())
+	c.setCloseErr(err)
+	c.writeControl(context.Background(), opClose, ce.bytes())
 	return err
 }
 
-func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) {
-	err := cr.mu.Lock(ctx)
+func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
+	defer errd.Wrap(&err, "failed to get reader")
+
+	err = c.readMu.Lock(ctx)
 	if err != nil {
 		return 0, nil, err
 	}
-	defer cr.mu.Unlock()
+	defer c.readMu.Unlock()
 
-	if !cr.mr.fin {
+	if !c.msgReader.fin {
 		return 0, nil, errors.New("previous message not read to completion")
 	}
 
-	h, err := cr.loop(ctx)
+	h, err := c.readLoop(ctx)
 	if err != nil {
 		return 0, nil, err
 	}
 
 	if h.opcode == opContinuation {
 		err := errors.New("received continuation frame without text or binary frame")
-		cr.c.cw.error(StatusProtocolError, err)
+		c.writeError(StatusProtocolError, err)
 		return 0, nil, err
 	}
 
-	cr.mr.reset(ctx, h)
+	c.msgReader.reset(ctx, h)
 
-	return MessageType(h.opcode), cr.mr, nil
+	return MessageType(h.opcode), c.msgReader, nil
 }
 
 type msgReader struct {
-	cr *connReader
-	fr io.Reader
-	lr *limitReader
+	c *Conn
 
 	ctx context.Context
 
 	deflate     bool
+	flateReader io.Reader
 	deflateTail strings.Reader
 
+	limitReader *limitReader
+
 	payloadLength int64
 	maskKey       uint32
 	fin           bool
@@ -348,8 +300,8 @@ func (mr *msgReader) reset(ctx context.Context, h header) {
 	mr.deflate = h.rsv1
 	if mr.deflate {
 		mr.deflateTail.Reset(deflateMessageTail)
-		if !mr.cr.contextTakeover() {
-			mr.cr.ensureFlateReader()
+		if !mr.contextTakeover() {
+			mr.ensureFlateReader()
 		}
 	}
 	mr.setFrame(h)
@@ -370,34 +322,42 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) {
 		}
 	}()
 
-	err = mr.cr.mu.Lock(mr.ctx)
+	err = mr.c.readMu.Lock(mr.ctx)
 	if err != nil {
 		return 0, err
 	}
-	defer mr.cr.mu.Unlock()
+	defer mr.c.readMu.Unlock()
 
 	if mr.payloadLength == 0 && mr.fin {
-		if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() {
-			if mr.fr != nil {
-				putFlateReader(mr.fr)
-				mr.fr = nil
+		if mr.c.deflateNegotiated() && !mr.contextTakeover() {
+			if mr.flateReader != nil {
+				putFlateReader(mr.flateReader)
+				mr.flateReader = nil
 			}
 		}
 		return 0, io.EOF
 	}
 
-	return mr.lr.Read(p)
+	return mr.limitReader.Read(p)
 }
 
 func (mr *msgReader) read(p []byte) (int, error) {
 	if mr.payloadLength == 0 {
-		h, err := mr.cr.loop(mr.ctx)
+		if mr.fin {
+			if mr.deflate {
+				n, _ := mr.deflateTail.Read(p[:4])
+				return n, nil
+			}
+			return 0, io.EOF
+		}
+
+		h, err := mr.c.readLoop(mr.ctx)
 		if err != nil {
 			return 0, err
 		}
 		if h.opcode != opContinuation {
 			err := errors.New("received new data message without finishing the previous message")
-			mr.cr.c.cw.error(StatusProtocolError, err)
+			mr.c.writeError(StatusProtocolError, err)
 			return 0, err
 		}
 		mr.setFrame(h)
@@ -407,14 +367,14 @@ func (mr *msgReader) read(p []byte) (int, error) {
 		p = p[:mr.payloadLength]
 	}
 
-	n, err := mr.cr.framePayload(mr.ctx, p)
+	n, err := mr.c.readFramePayload(mr.ctx, p)
 	if err != nil {
 		return n, err
 	}
 
 	mr.payloadLength -= int64(n)
 
-	if !mr.cr.c.client {
+	if !mr.c.client {
 		mr.maskKey = mask(mr.maskKey, p)
 	}
 
@@ -442,10 +402,14 @@ func (lr *limitReader) reset(r io.Reader) {
 	lr.r = r
 }
 
+func (lr *limitReader) setLimit(limit int64) {
+	lr.limit.Store(limit)
+}
+
 func (lr *limitReader) Read(p []byte) (int, error) {
 	if lr.n <= 0 {
 		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
-		lr.c.cw.error(StatusMessageTooBig, err)
+		lr.c.writeError(StatusMessageTooBig, err)
 		return 0, err
 	}
 
diff --git a/write.go b/write.go
index 9cafc5c..0ddf11e 100644
--- a/write.go
+++ b/write.go
@@ -24,7 +24,7 @@ import (
 //
 // Never close the returned writer twice.
 func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
-	w, err := c.cw.writer(ctx, typ)
+	w, err := c.writer(ctx, typ)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get writer: %w", err)
 	}
@@ -38,111 +38,68 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
 // If compression is disabled, then it is guaranteed to write the message
 // in a single frame.
 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
-	_, err := c.cw.write(ctx, typ, p)
+	_, err := c.write(ctx, typ, p)
 	if err != nil {
 		return fmt.Errorf("failed to write msg: %w", err)
 	}
 	return nil
 }
 
-type connWriter struct {
-	c  *Conn
-	bw *bufio.Writer
-
-	writeBuf []byte
-
-	mw      *messageWriter
-	frameMu mu
-	h       header
-
-	timeout chan context.Context
+func (mw *msgWriter) ensureFlateWriter() {
+	mw.flateWriter = getFlateWriter(mw.trimWriter)
 }
 
-func (cw *connWriter) init(c *Conn, bw *bufio.Writer) {
-	cw.c = c
-	cw.bw = bw
-
-	if cw.c.client {
-		cw.writeBuf = extractBufioWriterBuf(cw.bw, c.rwc)
-	}
-
-	cw.timeout = make(chan context.Context)
-
-	cw.mw = &messageWriter{
-		cw: cw,
+func (mw *msgWriter) contextTakeover() bool {
+	if mw.c.client {
+		return mw.c.copts.clientNoContextTakeover
 	}
-	cw.mw.tw = &trimLastFourBytesWriter{
-		w: writerFunc(cw.mw.write),
-	}
-	if cw.c.deflateNegotiated() && cw.mw.contextTakeover() {
-		cw.mw.ensureFlateWriter()
-	}
-}
-
-func (mw *messageWriter) ensureFlateWriter() {
-	mw.fw = getFlateWriter(mw.tw)
+	return mw.c.copts.serverNoContextTakeover
 }
 
-func (cw *connWriter) close() {
-	if cw.c.client {
-		cw.frameMu.Lock(context.Background())
-		putBufioWriter(cw.bw)
-	}
-	if cw.c.deflateNegotiated() && cw.mw.contextTakeover() {
-		cw.mw.mu.Lock(context.Background())
-		putFlateWriter(cw.mw.fw)
-	}
-}
-
-func (mw *messageWriter) contextTakeover() bool {
-	if mw.cw.c.client {
-		return mw.cw.c.copts.clientNoContextTakeover
-	}
-	return mw.cw.c.copts.serverNoContextTakeover
-}
-
-func (cw *connWriter) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
-	err := cw.mw.reset(ctx, typ)
+func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
+	err := c.msgWriter.reset(ctx, typ)
 	if err != nil {
 		return nil, err
 	}
-	return cw.mw, nil
+	return c.msgWriter, nil
 }
 
-func (cw *connWriter) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
-	ww, err := cw.writer(ctx, typ)
+func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
+	mw, err := c.writer(ctx, typ)
 	if err != nil {
 		return 0, err
 	}
 
-	if !cw.c.deflateNegotiated() {
+	if !c.deflateNegotiated() {
 		// Fast single frame path.
-		defer cw.mw.mu.Unlock()
-		return cw.frame(ctx, true, cw.mw.opcode, p)
+		defer c.msgWriter.mu.Unlock()
+		return c.writeFrame(ctx, true, c.msgWriter.opcode, p)
 	}
 
-	n, err := ww.Write(p)
+	n, err := mw.Write(p)
 	if err != nil {
 		return n, err
 	}
 
-	err = ww.Close()
+	err = mw.Close()
 	return n, err
 }
 
-type messageWriter struct {
-	cw *connWriter
+type msgWriter struct {
+	c *Conn
 
-	mu       mu
-	compress bool
-	tw       *trimLastFourBytesWriter
-	fw       *flate.Writer
-	ctx      context.Context
-	opcode   opcode
-	closed   bool
+	mu      mu
+
+	deflate bool
+	ctx     context.Context
+	opcode  opcode
+	closed  bool
+
+	trimWriter   *trimLastFourBytesWriter
+	flateWriter  *flate.Writer
 }
 
-func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error {
+func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
 	err := mw.mu.Lock(ctx)
 	if err != nil {
 		return err
@@ -155,30 +112,30 @@ func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error {
 }
 
 // Write writes the given bytes to the WebSocket connection.
-func (mw *messageWriter) Write(p []byte) (_ int, err error) {
+func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 	defer errd.Wrap(&err, "failed to write")
 
 	if mw.closed {
 		return 0, errors.New("cannot use closed writer")
 	}
 
-	if mw.cw.c.deflateNegotiated() {
-		if !mw.compress {
+	if mw.c.deflateNegotiated() {
+		if !mw.deflate {
 			if !mw.contextTakeover() {
 				mw.ensureFlateWriter()
 			}
-			mw.tw.reset()
-			mw.compress = true
+			mw.trimWriter.reset()
+			mw.deflate = true
 		}
 
-		return mw.fw.Write(p)
+		return mw.flateWriter.Write(p)
 	}
 
 	return mw.write(p)
 }
 
-func (mw *messageWriter) write(p []byte) (int, error) {
-	n, err := mw.cw.frame(mw.ctx, false, mw.opcode, p)
+func (mw *msgWriter) write(p []byte) (int, error) {
+	n, err := mw.c.writeFrame(mw.ctx, false, mw.opcode, p)
 	if err != nil {
 		return n, fmt.Errorf("failed to write data frame: %w", err)
 	}
@@ -187,8 +144,7 @@ func (mw *messageWriter) write(p []byte) (int, error) {
 }
 
 // Close flushes the frame to the connection.
-// This must be called for every messageWriter.
-func (mw *messageWriter) Close() (err error) {
+func (mw *msgWriter) Close() (err error) {
 	defer errd.Wrap(&err, "failed to close writer")
 
 	if mw.closed {
@@ -196,32 +152,39 @@ func (mw *messageWriter) Close() (err error) {
 	}
 	mw.closed = true
 
-	if mw.cw.c.deflateNegotiated() {
-		err = mw.fw.Flush()
+	if mw.c.deflateNegotiated() {
+		err = mw.flateWriter.Flush()
 		if err != nil {
 			return fmt.Errorf("failed to flush flate writer: %w", err)
 		}
 	}
 
-	_, err = mw.cw.frame(mw.ctx, true, mw.opcode, nil)
+	_, err = mw.c.writeFrame(mw.ctx, true, mw.opcode, nil)
 	if err != nil {
 		return fmt.Errorf("failed to write fin frame: %w", err)
 	}
 
-	if mw.compress && !mw.contextTakeover() {
-		putFlateWriter(mw.fw)
-		mw.compress = false
+	if mw.deflate && !mw.contextTakeover() {
+		putFlateWriter(mw.flateWriter)
+		mw.deflate = false
 	}
 
 	mw.mu.Unlock()
 	return nil
 }
 
-func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) error {
+func (cw *msgWriter) close() {
+	if cw.c.deflateNegotiated() && cw.contextTakeover() {
+		cw.mu.Lock(context.Background())
+		putFlateWriter(cw.flateWriter)
+	}
+}
+
+func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
 	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
 	defer cancel()
 
-	_, err := cw.frame(ctx, true, opcode, p)
+	_, err := c.writeFrame(ctx, true, opcode, p)
 	if err != nil {
 		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
 	}
@@ -229,94 +192,94 @@ func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) erro
 }
 
 // frame handles all writes to the connection.
-func (cw *connWriter) frame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
-	err := cw.frameMu.Lock(ctx)
+func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
+	err := c.writeFrameMu.Lock(ctx)
 	if err != nil {
 		return 0, err
 	}
-	defer cw.frameMu.Unlock()
+	defer c.writeFrameMu.Unlock()
 
 	select {
-	case <-cw.c.closed:
-		return 0, cw.c.closeErr
-	case cw.timeout <- ctx:
+	case <-c.closed:
+		return 0, c.closeErr
+	case c.writeTimeout <- ctx:
 	}
 
-	cw.h.fin = fin
-	cw.h.opcode = opcode
-	cw.h.masked = cw.c.client
-	cw.h.payloadLength = int64(len(p))
-
-	cw.h.rsv1 = false
-	if cw.mw.compress && (opcode == opText || opcode == opBinary) {
-		cw.h.rsv1 = true
-	}
+	c.writeHeader.fin = fin
+	c.writeHeader.opcode = opcode
+	c.writeHeader.payloadLength = int64(len(p))
 
-	if cw.h.masked {
-		err = binary.Read(rand.Reader, binary.LittleEndian, &cw.h.maskKey)
+	if c.client {
+		c.writeHeader.masked = true
+		err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey)
 		if err != nil {
 			return 0, fmt.Errorf("failed to generate masking key: %w", err)
 		}
 	}
 
-	err = writeFrameHeader(cw.h, cw.bw)
+	c.writeHeader.rsv1 = false
+	if c.msgWriter.deflate && (opcode == opText || opcode == opBinary) {
+		c.writeHeader.rsv1 = true
+	}
+
+	err = writeFrameHeader(c.writeHeader, c.bw)
 	if err != nil {
 		return 0, err
 	}
 
-	n, err := cw.framePayload(p)
+	n, err := c.writeFramePayload(p)
 	if err != nil {
 		return n, err
 	}
 
-	if cw.h.fin {
-		err = cw.bw.Flush()
+	if c.writeHeader.fin {
+		err = c.bw.Flush()
 		if err != nil {
 			return n, fmt.Errorf("failed to flush: %w", err)
 		}
 	}
 
 	select {
-	case <-cw.c.closed:
-		return n, cw.c.closeErr
-	case cw.timeout <- context.Background():
+	case <-c.closed:
+		return n, c.closeErr
+	case c.writeTimeout <- context.Background():
 	}
 
 	return n, nil
 }
 
-func (cw *connWriter) framePayload(p []byte) (_ int, err error) {
+func (c *Conn) writeFramePayload(p []byte) (_ int, err error) {
 	defer errd.Wrap(&err, "failed to write frame payload")
 
-	if !cw.h.masked {
-		return cw.bw.Write(p)
+	if !c.writeHeader.masked {
+		return c.bw.Write(p)
 	}
 
 	var n int
-	maskKey := cw.h.maskKey
+	maskKey := c.writeHeader.maskKey
 	for len(p) > 0 {
 		// If the buffer is full, we need to flush.
-		if cw.bw.Available() == 0 {
-			err = cw.bw.Flush()
+		if c.bw.Available() == 0 {
+			err = c.bw.Flush()
 			if err != nil {
 				return n, err
 			}
 		}
 
 		// Start of next write in the buffer.
-		i := cw.bw.Buffered()
+		i := c.bw.Buffered()
 
 		j := len(p)
-		if j > cw.bw.Available() {
-			j = cw.bw.Available()
+		if j > c.bw.Available() {
+			j = c.bw.Available()
 		}
 
-		_, err := cw.bw.Write(p[:j])
+		_, err := c.bw.Write(p[:j])
 		if err != nil {
 			return n, err
 		}
 
-		maskKey = mask(maskKey, cw.writeBuf[i:cw.bw.Buffered()])
+		maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
 
 		p = p[j:]
 		n += j
diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go
index 99996a6..36dd2df 100644
--- a/wsjson/wsjson.go
+++ b/wsjson/wsjson.go
@@ -5,7 +5,6 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-
 	"nhooyr.io/websocket"
 	"nhooyr.io/websocket/internal/bpool"
 	"nhooyr.io/websocket/internal/errd"
-- 
GitLab