diff --git a/compress_notjs.go b/compress_notjs.go
index 7c6b2fc013041efec8b7353de38001c2b8c59cf6..a61b7ba472dc85dd98b9a36ccecb2886945e1068 100644
--- a/compress_notjs.go
+++ b/compress_notjs.go
@@ -3,10 +3,11 @@
 package websocket
 
 import (
-	"compress/flate"
 	"io"
 	"net/http"
 	"sync"
+
+	"github.com/klauspost/compress/flate"
 )
 
 func (m CompressionMode) opts() *compressionOptions {
@@ -45,10 +46,16 @@ type trimLastFourBytesWriter struct {
 }
 
 func (tw *trimLastFourBytesWriter) reset() {
-	tw.tail = tw.tail[:0]
+	if tw != nil { // nil receiver is a no-op; re-slicing a nil tail to [:0] is legal, so no second check is needed
+		tw.tail = tw.tail[:0]
+	}
 }
 
 func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
+	if tw.tail == nil {
+		tw.tail = make([]byte, 0, 4)
+	}
+
 	extra := len(tw.tail) + len(p) - 4
 
 	if extra <= 0 {
@@ -65,7 +72,10 @@ func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
 		if err != nil {
 			return 0, err
 		}
-		tw.tail = tw.tail[extra:]
+
+		// Shift remaining bytes in tail over.
+		n := copy(tw.tail, tw.tail[extra:])
+		tw.tail = tw.tail[:n]
 	}
 
 	// If p is less than or equal to 4 bytes,
@@ -118,22 +128,39 @@ type slidingWindow struct {
 	buf []byte
 }
 
-var swPoolMu sync.Mutex
+var swPoolMu sync.RWMutex
 var swPool = map[int]*sync.Pool{}
 
-func (sw *slidingWindow) init(n int) {
-	if sw.buf != nil {
-		return
+// slidingWindowPool returns the pool registered in swPool for window size n,
+// creating and registering one if none exists yet.
+func slidingWindowPool(n int) *sync.Pool {
+	// Fast path: look the pool up under the read lock only.
+	swPoolMu.RLock()
+	p, ok := swPool[n]
+	swPoolMu.RUnlock()
+	if ok {
+		return p
 	}
 
 	swPoolMu.Lock()
 	defer swPoolMu.Unlock()
 
-	p, ok := swPool[n]
+	// Re-check under the write lock: another goroutine may have registered
+	// a pool for n between RUnlock and Lock, and it must not be overwritten.
+	p, ok = swPool[n]
 	if !ok {
 		p = &sync.Pool{}
 		swPool[n] = p
+	}
+	return p
+}
+
+func (sw *slidingWindow) init(n int) {
+	if sw.buf != nil {
+		return
 	}
+
+	p := slidingWindowPool(n)
 	buf, ok := p.Get().([]byte)
 	if ok {
 		sw.buf = buf[:0]
diff --git a/conn_notjs.go b/conn_notjs.go
index 178fcad02a61574b3e1a519ddaaebd39742854df..e6ff7df362d08e400ca26f1225a4bada8a7a56eb 100644
--- a/conn_notjs.go
+++ b/conn_notjs.go
@@ -39,16 +39,17 @@ type Conn struct {
 
 	// Read state.
 	readMu            *mu
-	readHeader        header
+	readHeaderBuf     [8]byte
 	readControlBuf    [maxControlPayload]byte
 	msgReader         *msgReader
 	readCloseFrameErr error
 
 	// Write state.
-	msgWriter    *msgWriter
-	writeFrameMu *mu
-	writeBuf     []byte
-	writeHeader  header
+	msgWriterState *msgWriterState
+	writeFrameMu   *mu
+	writeBuf       []byte
+	writeHeaderBuf [8]byte
+	writeHeader    header
 
 	closed     chan struct{}
 	closeMu    sync.Mutex
@@ -94,14 +95,14 @@ func newConn(cfg connConfig) *Conn {
 
 	c.msgReader = newMsgReader(c)
 
-	c.msgWriter = newMsgWriter(c)
+	c.msgWriterState = newMsgWriterState(c)
 	if c.client {
 		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
 	}
 
 	if c.flate() && c.flateThreshold == 0 {
 		c.flateThreshold = 256
-		if !c.msgWriter.flateContextTakeover() {
+		if !c.msgWriterState.flateContextTakeover() {
 			c.flateThreshold = 512
 		}
 	}
@@ -142,7 +143,7 @@ func (c *Conn) close(err error) {
 			c.writeFrameMu.Lock(context.Background())
 			putBufioWriter(c.bw)
 		}
-		c.msgWriter.close()
+		c.msgWriterState.close()
 
 		c.msgReader.close()
 		if c.client {
diff --git a/conn_test.go b/conn_test.go
index 265156e970cd1cf0c753706580c532f98d8b19f1..398ffd5181e50b5519b977bbd739162ad714d77b 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -5,7 +5,6 @@ package websocket_test
 import (
 	"bytes"
 	"context"
-	"crypto/rand"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -13,6 +12,7 @@ import (
 	"net/http/httptest"
 	"os"
 	"os/exec"
+	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -379,15 +379,15 @@ func BenchmarkConn(b *testing.B) {
 		mode websocket.CompressionMode
 	}{
 		{
-			name: "compressionDisabled",
+			name: "disabledCompress",
 			mode: websocket.CompressionDisabled,
 		},
 		{
-			name: "compression",
+			name: "compress",
 			mode: websocket.CompressionContextTakeover,
 		},
 		{
-			name: "noContextCompression",
+			name: "compressNoContext",
 			mode: websocket.CompressionNoContextTakeover,
 		},
 	}
@@ -395,44 +395,36 @@ func BenchmarkConn(b *testing.B) {
 		b.Run(bc.name, func(b *testing.B) {
 			bb, c1, c2 := newConnTest(b, &websocket.DialOptions{
 				CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode},
-			}, nil)
+			}, &websocket.AcceptOptions{
+				CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode},
+			})
 			defer bb.cleanup()
 
 			bb.goEchoLoop(c2)
 
-			const n = 32768
-			writeBuf := make([]byte, n)
-			readBuf := make([]byte, n)
-			writes := make(chan websocket.MessageType)
+			msg := []byte(strings.Repeat("1234", 128))
+			readBuf := make([]byte, len(msg))
+			writes := make(chan struct{})
 			defer close(writes)
 			werrs := make(chan error)
 
 			go func() {
-				for typ := range writes {
-					werrs <- c1.Write(bb.ctx, typ, writeBuf)
+				for range writes {
+					werrs <- c1.Write(bb.ctx, websocket.MessageText, msg)
 				}
 			}()
-			b.SetBytes(n)
+			b.SetBytes(int64(len(msg)))
 			b.ReportAllocs()
 			b.ResetTimer()
 			for i := 0; i < b.N; i++ {
-				_, err := rand.Reader.Read(writeBuf)
-				if err != nil {
-					b.Fatal(err)
-				}
-
-				expType := websocket.MessageBinary
-				if writeBuf[0]%2 == 1 {
-					expType = websocket.MessageText
-				}
-				writes <- expType
+				writes <- struct{}{}
 
 				typ, r, err := c1.Reader(bb.ctx)
 				if err != nil {
 					b.Fatal(err)
 				}
-				if expType != typ {
-					assert.Equal(b, "data type", expType, typ)
+				if websocket.MessageText != typ {
+					assert.Equal(b, "data type", websocket.MessageText, typ)
 				}
 
 				_, err = io.ReadFull(r, readBuf)
@@ -448,8 +440,8 @@ func BenchmarkConn(b *testing.B) {
 					assert.Equal(b, "n2", 0, n2)
 				}
 
-				if !bytes.Equal(writeBuf, readBuf) {
-					assert.Equal(b, "msg", writeBuf, readBuf)
+				if !bytes.Equal(msg, readBuf) {
+					assert.Equal(b, "msg", msg, readBuf)
 				}
 
 				err = <-werrs
@@ -464,3 +456,8 @@ func BenchmarkConn(b *testing.B) {
 		})
 	}
 }
+
+func TestCompression(t *testing.T) {
+	t.Parallel()
+
+}
diff --git a/frame.go b/frame.go
index 491ae75c33c5bf6ed4f0274191fa8a4b1ab20cff..4acaecf43ff6c9a8eedad12b0e5a52a29af14ae5 100644
--- a/frame.go
+++ b/frame.go
@@ -3,9 +3,12 @@ package websocket
 import (
 	"bufio"
 	"encoding/binary"
+	"io"
 	"math"
 	"math/bits"
 
+	"golang.org/x/xerrors"
+
 	"nhooyr.io/websocket/internal/errd"
 )
 
@@ -46,12 +49,12 @@ type header struct {
 
 // readFrameHeader reads a header from the reader.
 // See https://tools.ietf.org/html/rfc6455#section-5.2.
-func readFrameHeader(h *header, r *bufio.Reader) (err error) {
+func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
 	defer errd.Wrap(&err, "failed to read frame header")
 
 	b, err := r.ReadByte()
 	if err != nil {
-		return err
+		return header{}, err
 	}
 
 	h.fin = b&(1<<7) != 0
@@ -63,7 +66,7 @@ func readFrameHeader(h *header, r *bufio.Reader) (err error) {
 
 	b, err = r.ReadByte()
 	if err != nil {
-		return err
+		return header{}, err
 	}
 
 	h.masked = b&(1<<7) != 0
@@ -73,24 +76,29 @@ func readFrameHeader(h *header, r *bufio.Reader) (err error) {
 	case payloadLength < 126:
 		h.payloadLength = int64(payloadLength)
 	case payloadLength == 126:
-		var pl uint16
-		err = binary.Read(r, binary.BigEndian, &pl)
-		h.payloadLength = int64(pl)
+		_, err = io.ReadFull(r, readBuf[:2])
+		h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
 	case payloadLength == 127:
-		err = binary.Read(r, binary.BigEndian, &h.payloadLength)
+		_, err = io.ReadFull(r, readBuf[:8]) // consume exactly the 8 length bytes even if readBuf is longer
+		h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
 	}
 	if err != nil {
-		return err
+		return header{}, err
+	}
+
+	if h.payloadLength < 0 {
+		return header{}, xerrors.Errorf("received negative payload length: %v", h.payloadLength)
 	}
 
 	if h.masked {
-		err = binary.Read(r, binary.LittleEndian, &h.maskKey)
+		_, err = io.ReadFull(r, readBuf[:4])
 		if err != nil {
-			return err
+			return header{}, err
 		}
+		h.maskKey = binary.LittleEndian.Uint32(readBuf)
 	}
 
-	return nil
+	return h, nil
 }
 
 // maxControlPayload is the maximum length of a control frame payload.
@@ -99,7 +107,7 @@ const maxControlPayload = 125
 
 // writeFrameHeader writes the bytes of the header to w.
 // See https://tools.ietf.org/html/rfc6455#section-5.2
-func writeFrameHeader(h header, w *bufio.Writer) (err error) {
+func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
 	defer errd.Wrap(&err, "failed to write frame header")
 
 	var b byte
@@ -143,16 +151,19 @@ func writeFrameHeader(h header, w *bufio.Writer) (err error) {
 
 	switch {
 	case h.payloadLength > math.MaxUint16:
-		err = binary.Write(w, binary.BigEndian, h.payloadLength)
+		binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
+		_, err = w.Write(buf[:8]) // emit exactly the 8 length bytes even if buf is longer
 	case h.payloadLength > 125:
-		err = binary.Write(w, binary.BigEndian, uint16(h.payloadLength))
+		binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
+		_, err = w.Write(buf[:2])
 	}
 	if err != nil {
 		return err
 	}
 
 	if h.masked {
-		err = binary.Write(w, binary.LittleEndian, h.maskKey)
+		binary.LittleEndian.PutUint32(buf, h.maskKey)
+		_, err = w.Write(buf[:4])
 		if err != nil {
 			return err
 		}
diff --git a/frame_test.go b/frame_test.go
index 38f1599a890cb52bda60a5d194c15e1449e5dbac..76826248d040e4696e0603b4d28f6a0159b7e363 100644
--- a/frame_test.go
+++ b/frame_test.go
@@ -80,14 +80,13 @@ func testHeader(t *testing.T, h header) {
 	w := bufio.NewWriter(b)
 	r := bufio.NewReader(b)
 
-	err := writeFrameHeader(h, w)
+	err := writeFrameHeader(h, w, make([]byte, 8))
 	assert.Success(t, err)
 
 	err = w.Flush()
 	assert.Success(t, err)
 
-	var h2 header
-	err = readFrameHeader(&h2, r)
+	h2, err := readFrameHeader(r, make([]byte, 8))
 	assert.Success(t, err)
 
 	assert.Equal(t, "read header", h, h2)
diff --git a/go.mod b/go.mod
index cb37239165d515dd37e5d3d08dcc4c31d6e8d44c..a10c7b1e3e53dde31226a820586acdb7f6624fa8 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
 	github.com/golang/protobuf v1.3.3
 	github.com/google/go-cmp v0.4.0
 	github.com/gorilla/websocket v1.4.1
+	github.com/klauspost/compress v1.10.0
 	golang.org/x/time v0.0.0-20191024005414-555d28b269f0
 	golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
 )
diff --git a/go.sum b/go.sum
index 8cbc66ce14e6244c32295b76f4df68b7306a99fc..e4bbd62d337c4edbd4e71bd740271cf2544a0467 100644
--- a/go.sum
+++ b/go.sum
@@ -10,6 +10,8 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
 github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y=
+github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
diff --git a/internal/xsync/go.go b/internal/xsync/go.go
index d88ac622c5ddd0ebb202a1c5759aa098a14a3c1c..712739aa2fa5b6484629c47a5aa41c1a80f1bd7b 100644
--- a/internal/xsync/go.go
+++ b/internal/xsync/go.go
@@ -6,7 +6,7 @@ import (
 
 // Go allows running a function in another goroutine
 // and waiting for its error.
-func Go(fn func() error) <- chan error {
+func Go(fn func() error) <-chan error {
 	errs := make(chan error, 1)
 	go func() {
 		defer func() {
diff --git a/read.go b/read.go
index bf7fa6d928835a98b50a43e3f2765406e04e1201..bbad30d14e59f43be73cab8176b9c0cdd14f6216 100644
--- a/read.go
+++ b/read.go
@@ -3,6 +3,7 @@
 package websocket
 
 import (
+	"bufio"
 	"context"
 	"io"
 	"io/ioutil"
@@ -81,8 +82,9 @@ func newMsgReader(c *Conn) *msgReader {
 		c:   c,
 		fin: true,
 	}
+	mr.readFunc = mr.read
 
-	mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit+1)
+	mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
 	return mr
 }
 
@@ -90,13 +92,16 @@ func (mr *msgReader) resetFlate() {
 	if mr.flateContextTakeover() {
 		mr.dict.init(32768)
 	}
+	if mr.flateBufio == nil {
+		mr.flateBufio = getBufioReader(mr.readFunc)
+	}
 
-	mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
+	mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
 	mr.limitReader.r = mr.flateReader
 	mr.flateTail.Reset(deflateMessageTail)
 }
 
-func (mr *msgReader) returnFlateReader() {
+func (mr *msgReader) putFlateReader() {
 	if mr.flateReader != nil {
 		putFlateReader(mr.flateReader)
 		mr.flateReader = nil
@@ -105,9 +110,11 @@ func (mr *msgReader) returnFlateReader() {
 
 func (mr *msgReader) close() {
 	mr.c.readMu.Lock(context.Background())
-	mr.returnFlateReader()
-
+	mr.putFlateReader()
 	mr.dict.close()
+	if mr.flateBufio != nil {
+		putBufioReader(mr.flateBufio)
+	}
 }
 
 func (mr *msgReader) flateContextTakeover() bool {
@@ -173,7 +180,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
 	case c.readTimeout <- ctx:
 	}
 
-	err := readFrameHeader(&c.readHeader, c.br)
+	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
 	if err != nil {
 		select {
 		case <-c.closed:
@@ -192,7 +199,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
 	case c.readTimeout <- context.Background():
 	}
 
-	return c.readHeader, nil
+	return h, nil
 }
 
 func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
@@ -317,6 +324,7 @@ type msgReader struct {
 	ctx         context.Context
 	flate       bool
 	flateReader io.Reader
+	flateBufio  *bufio.Reader
 	flateTail   strings.Reader
 	limitReader *limitReader
 	dict        slidingWindow
@@ -324,12 +332,15 @@ type msgReader struct {
 	fin           bool
 	payloadLength int64
 	maskKey       uint32
+
+	// Cached readerFunc(mr.read), to avoid continuous allocations.
+	readFunc readerFunc
 }
 
 func (mr *msgReader) reset(ctx context.Context, h header) {
 	mr.ctx = ctx
 	mr.flate = h.rsv1
-	mr.limitReader.reset(readerFunc(mr.read))
+	mr.limitReader.reset(mr.readFunc)
 
 	if mr.flate {
 		mr.resetFlate()
@@ -346,15 +357,15 @@ func (mr *msgReader) setFrame(h header) {
 
 func (mr *msgReader) Read(p []byte) (n int, err error) {
 	defer func() {
-		errd.Wrap(&err, "failed to read")
 		if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
 			err = io.EOF
 		}
 		if xerrors.Is(err, io.EOF) {
 			err = io.EOF
-
-			mr.returnFlateReader()
+			mr.putFlateReader()
+			return
 		}
+		errd.Wrap(&err, "failed to read")
 	}()
 
 	err = mr.c.readMu.Lock(mr.ctx)
@@ -372,44 +383,46 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
 }
 
 func (mr *msgReader) read(p []byte) (int, error) {
-	if mr.payloadLength == 0 {
-		if mr.fin {
-			if mr.flate {
-				return mr.flateTail.Read(p)
+	for {
+		if mr.payloadLength == 0 {
+			if mr.fin {
+				if mr.flate {
+					return mr.flateTail.Read(p)
+				}
+				return 0, io.EOF
 			}
-			return 0, io.EOF
-		}
 
-		h, err := mr.c.readLoop(mr.ctx)
-		if err != nil {
-			return 0, err
-		}
-		if h.opcode != opContinuation {
-			err := xerrors.New("received new data message without finishing the previous message")
-			mr.c.writeError(StatusProtocolError, err)
-			return 0, err
+			h, err := mr.c.readLoop(mr.ctx)
+			if err != nil {
+				return 0, err
+			}
+			if h.opcode != opContinuation {
+				err := xerrors.New("received new data message without finishing the previous message")
+				mr.c.writeError(StatusProtocolError, err)
+				return 0, err
+			}
+			mr.setFrame(h)
+
+			continue
 		}
-		mr.setFrame(h)
 
-		return mr.read(p)
-	}
+		if int64(len(p)) > mr.payloadLength {
+			p = p[:mr.payloadLength]
+		}
 
-	if int64(len(p)) > mr.payloadLength {
-		p = p[:mr.payloadLength]
-	}
+		n, err := mr.c.readFramePayload(mr.ctx, p)
+		if err != nil {
+			return n, err
+		}
 
-	n, err := mr.c.readFramePayload(mr.ctx, p)
-	if err != nil {
-		return n, err
-	}
+		mr.payloadLength -= int64(n)
 
-	mr.payloadLength -= int64(n)
+		if !mr.c.client {
+			mr.maskKey = mask(mr.maskKey, p)
+		}
 
-	if !mr.c.client {
-		mr.maskKey = mask(mr.maskKey, p)
+		return n, nil
 	}
-
-	return n, nil
 }
 
 type limitReader struct {
diff --git a/write.go b/write.go
index 9d4b670f84985570b2334cd90be5169771ca7289..ec3b7d059d4d31bc0597d0f9a5ca42d3a2d654f0 100644
--- a/write.go
+++ b/write.go
@@ -4,7 +4,6 @@ package websocket
 
 import (
 	"bufio"
-	"compress/flate"
 	"context"
 	"crypto/rand"
 	"encoding/binary"
@@ -12,6 +11,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/klauspost/compress/flate"
+	kflate "github.com/klauspost/compress/flate"
 	"golang.org/x/xerrors"
 
 	"nhooyr.io/websocket/internal/errd"
@@ -24,8 +25,6 @@ import (
 //
 // Only one writer can be open at a time, multiple calls will block until the previous writer
 // is closed.
-//
-// Never close the returned writer twice.
 func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
 	w, err := c.writer(ctx, typ)
 	if err != nil {
@@ -49,6 +48,26 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
 }
 
 type msgWriter struct {
+	mw     *msgWriterState
+	closed bool
+}
+
+func (mw *msgWriter) Write(p []byte) (int, error) {
+	if mw.closed {
+		return 0, xerrors.New("cannot use closed writer")
+	}
+	return mw.mw.Write(p)
+}
+
+func (mw *msgWriter) Close() error {
+	if mw.closed {
+		return xerrors.New("cannot use closed writer")
+	}
+	mw.closed = true
+	return mw.mw.Close()
+}
+
+type msgWriterState struct {
 	c *Conn
 
 	mu      *mu
@@ -56,36 +75,42 @@ type msgWriter struct {
 
 	ctx    context.Context
 	opcode opcode
-	closed bool
 	flate  bool
 
 	trimWriter  *trimLastFourBytesWriter
 	flateWriter *flate.Writer
+	dict        slidingWindow
 }
 
-func newMsgWriter(c *Conn) *msgWriter {
-	mw := &msgWriter{
+func newMsgWriterState(c *Conn) *msgWriterState {
+	mw := &msgWriterState{
 		c:  c,
 		mu: newMu(c),
 	}
 	return mw
 }
 
-func (mw *msgWriter) ensureFlate() {
+const stateless = true
+
+func (mw *msgWriterState) ensureFlate() {
 	if mw.trimWriter == nil {
 		mw.trimWriter = &trimLastFourBytesWriter{
 			w: writerFunc(mw.write),
 		}
 	}
 
-	if mw.flateWriter == nil {
-		mw.flateWriter = getFlateWriter(mw.trimWriter)
+	if stateless {
+		mw.dict.init(8192)
+	} else {
+		if mw.flateWriter == nil {
+			mw.flateWriter = getFlateWriter(mw.trimWriter)
+		}
 	}
 
 	mw.flate = true
 }
 
-func (mw *msgWriter) flateContextTakeover() bool {
+func (mw *msgWriterState) flateContextTakeover() bool {
 	if mw.c.client {
 		return !mw.c.copts.clientNoContextTakeover
 	}
@@ -93,11 +118,14 @@ func (mw *msgWriter) flateContextTakeover() bool {
 }
 
 func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
-	err := c.msgWriter.reset(ctx, typ)
+	err := c.msgWriterState.reset(ctx, typ)
 	if err != nil {
 		return nil, err
 	}
-	return c.msgWriter, nil
+	return &msgWriter{
+		mw:     c.msgWriterState,
+		closed: false,
+	}, nil
 }
 
 func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
@@ -107,8 +135,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
 	}
 
 	if !c.flate() {
-		defer c.msgWriter.mu.Unlock()
-		return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
+		defer c.msgWriterState.mu.Unlock()
+		return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p)
 	}
 
 	n, err := mw.Write(p)
@@ -120,25 +148,22 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
 	return n, err
 }
 
-func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
+func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
 	err := mw.mu.Lock(ctx)
 	if err != nil {
 		return err
 	}
 
-	mw.closed = false
 	mw.ctx = ctx
 	mw.opcode = opcode(typ)
 	mw.flate = false
 
-	if mw.trimWriter != nil {
-		mw.trimWriter.reset()
-	}
+	mw.trimWriter.reset()
 
 	return nil
 }
 
-func (mw *msgWriter) returnFlateWriter() {
+func (mw *msgWriterState) putFlateWriter() {
 	if mw.flateWriter != nil {
 		putFlateWriter(mw.flateWriter)
 		mw.flateWriter = nil
@@ -146,16 +171,12 @@ func (mw *msgWriter) returnFlateWriter() {
 }
 
 // Write writes the given bytes to the WebSocket connection.
-func (mw *msgWriter) Write(p []byte) (_ int, err error) {
+func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
 	defer errd.Wrap(&err, "failed to write")
 
 	mw.writeMu.Lock()
 	defer mw.writeMu.Unlock()
 
-	if mw.closed {
-		return 0, xerrors.New("cannot use closed writer")
-	}
-
 	if mw.c.flate() {
 		// Only enables flate if the length crosses the
 		// threshold on the first frame
@@ -165,13 +186,21 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 	}
 
 	if mw.flate {
+		if stateless {
+			err = kflate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf)
+			if err != nil {
+				return 0, err
+			}
+			mw.dict.write(p)
+			return len(p), nil
+		}
 		return mw.flateWriter.Write(p)
 	}
 
 	return mw.write(p)
 }
 
-func (mw *msgWriter) write(p []byte) (int, error) {
+func (mw *msgWriterState) write(p []byte) (int, error) {
 	n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
 	if err != nil {
 		return n, xerrors.Errorf("failed to write data frame: %w", err)
@@ -181,42 +210,36 @@ func (mw *msgWriter) write(p []byte) (int, error) {
 }
 
 // Close flushes the frame to the connection.
-func (mw *msgWriter) Close() (err error) {
+func (mw *msgWriterState) Close() (err error) {
 	defer errd.Wrap(&err, "failed to close writer")
 
 	mw.writeMu.Lock()
 	defer mw.writeMu.Unlock()
 
-	if mw.closed {
-		return xerrors.New("cannot use closed writer")
-	}
-
-	if mw.flate {
+	if mw.flate && !stateless {
 		err = mw.flateWriter.Flush()
 		if err != nil {
-			return xerrors.Errorf("failed to flush flate writer: %w", err)
+			return xerrors.Errorf("failed to flush flate: %w", err)
 		}
 	}
 
-	// We set closed after flushing the flate writer to ensure Write
-	// can succeed.
-	mw.closed = true
-
 	_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
 	if err != nil {
 		return xerrors.Errorf("failed to write fin frame: %w", err)
 	}
 
 	if mw.flate && !mw.flateContextTakeover() {
-		mw.returnFlateWriter()
+		mw.dict.close()
+		mw.putFlateWriter()
 	}
 	mw.mu.Unlock()
 	return nil
 }
 
-func (mw *msgWriter) close() {
+func (mw *msgWriterState) close() {
 	mw.writeMu.Lock()
-	mw.returnFlateWriter()
+	mw.putFlateWriter()
+	mw.dict.close()
 }
 
 func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
@@ -250,10 +273,11 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
 
 	if c.client {
 		c.writeHeader.masked = true
-		err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey)
+		_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
 		if err != nil {
 			return 0, xerrors.Errorf("failed to generate masking key: %w", err)
 		}
+		c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
 	}
 
 	c.writeHeader.rsv1 = false
@@ -261,7 +285,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
 		c.writeHeader.rsv1 = true
 	}
 
-	err = writeFrameHeader(c.writeHeader, c.bw)
+	err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
 	if err != nil {
 		return 0, err
 	}