From a02cbef5605d23c97972fbea8dd16488cf437b7a Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Fri, 13 Oct 2023 03:34:15 -0700
Subject: [PATCH] compress.go: Fix context takeover

---
 accept.go             |  1 +
 ci/bench.sh           |  4 ++--
 compress.go           | 16 ++++++----------
 conn.go               |  1 +
 conn_test.go          |  4 ++--
 dial_test.go          |  3 ++-
 export_test.go        |  6 ++++--
 internal/util/util.go |  7 +++++++
 internal/xsync/go.go  |  3 ++-
 read.go               | 27 ++++++++++++++++-----------
 write.go              | 11 +++--------
 ws_js.go              |  1 +
 12 files changed, 47 insertions(+), 37 deletions(-)

diff --git a/accept.go b/accept.go
index ff2033e..6c63e73 100644
--- a/accept.go
+++ b/accept.go
@@ -269,6 +269,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
 
 		if strings.HasPrefix(p, "client_max_window_bits") {
 			// We cannot adjust the read sliding window so cannot make use of this.
+			// By not responding to it, we tell the client we're ignoring it.
 			continue
 		}
 
diff --git a/ci/bench.sh b/ci/bench.sh
index 31bf2f1..8f99278 100755
--- a/ci/bench.sh
+++ b/ci/bench.sh
@@ -2,8 +2,8 @@
 set -eu
 cd -- "$(dirname "$0")/.."
 
-go test --bench=. "$@" ./...
+go test --run=^$ --bench=. "$@" ./...
 (
   cd ./internal/thirdparty
-  go test --bench=. "$@" ./...
+  go test --run=^$ --bench=. "$@" ./...
 )
diff --git a/compress.go b/compress.go
index e6722fc..61e6e26 100644
--- a/compress.go
+++ b/compress.go
@@ -31,7 +31,7 @@ const (
 	CompressionDisabled CompressionMode = iota
 
 	// CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection.
-	// It reusing the sliding window from previous messages.
+	// It reuses the sliding window from previous messages.
 	// As most WebSocket protocols are repetitive, this can be very efficient.
 	// It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover.
 	//
@@ -80,7 +80,7 @@ func (copts *compressionOptions) setHeader(h http.Header) {
 // They are removed when sending to avoid the overhead as
 // WebSocket framing tell's when the message has ended but then
 // we need to add them back otherwise flate.Reader keeps
-// trying to return more bytes.
+// trying to read more bytes.
 const deflateMessageTail = "\x00\x00\xff\xff"
 
 type trimLastFourBytesWriter struct {
@@ -201,23 +201,19 @@ func (sw *slidingWindow) init(n int) {
 	}
 
 	p := slidingWindowPool(n)
-	buf, ok := p.Get().(*[]byte)
+	sw2, ok := p.Get().(*slidingWindow)
 	if ok {
-		sw.buf = (*buf)[:0]
+		*sw = *sw2
 	} else {
 		sw.buf = make([]byte, 0, n)
 	}
 }
 
 func (sw *slidingWindow) close() {
-	if sw.buf == nil {
-		return
-	}
-
+	sw.buf = sw.buf[:0]
 	swPoolMu.Lock()
-	swPool[cap(sw.buf)].Put(&sw.buf)
+	swPool[cap(sw.buf)].Put(sw)
 	swPoolMu.Unlock()
-	sw.buf = nil
 }
 
 func (sw *slidingWindow) write(p []byte) {
diff --git a/conn.go b/conn.go
index 17a6b96..81a57c7 100644
--- a/conn.go
+++ b/conn.go
@@ -292,4 +292,5 @@ func (m *mu) unlock() {
 }
 
 type noCopy struct{}
+
 func (*noCopy) Lock() {}
diff --git a/conn_test.go b/conn_test.go
index 59661b7..7a6a0c3 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -267,7 +267,7 @@ func TestConn(t *testing.T) {
 
 	t.Run("HTTPClient.Timeout", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
-			HTTPClient: &http.Client{Timeout: time.Second*5},
+			HTTPClient: &http.Client{Timeout: time.Second * 5},
 		}, nil)
 
 		tt.goEchoLoop(c2)
@@ -458,7 +458,7 @@ func BenchmarkConn(b *testing.B) {
 
 				typ, r, err := c1.Reader(bb.ctx)
 				if err != nil {
-					b.Fatal(err)
+					b.Fatal(i, err)
 				}
 				if websocket.MessageText != typ {
 					assert.Equal(b, "data type", websocket.MessageText, typ)
diff --git a/dial_test.go b/dial_test.go
index 8680147..e072db2 100644
--- a/dial_test.go
+++ b/dial_test.go
@@ -15,6 +15,7 @@ import (
 	"time"
 
 	"nhooyr.io/websocket/internal/test/assert"
+	"nhooyr.io/websocket/internal/util"
 )
 
 func TestBadDials(t *testing.T) {
@@ -27,7 +28,7 @@ func TestBadDials(t *testing.T) {
 			name   string
 			url    string
 			opts   *DialOptions
-			rand   readerFunc
+			rand   util.ReaderFunc
 			nilCtx bool
 		}{
 			{
diff --git a/export_test.go b/export_test.go
index d618a15..8731b6d 100644
--- a/export_test.go
+++ b/export_test.go
@@ -3,9 +3,11 @@
 
 package websocket
 
+import "nhooyr.io/websocket/internal/util"
+
 func (c *Conn) RecordBytesWritten() *int {
 	var bytesWritten int
-	c.bw.Reset(writerFunc(func(p []byte) (int, error) {
+	c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) {
 		bytesWritten += len(p)
 		return c.rwc.Write(p)
 	}))
@@ -14,7 +16,7 @@ func (c *Conn) RecordBytesWritten() *int {
 
 func (c *Conn) RecordBytesRead() *int {
 	var bytesRead int
-	c.br.Reset(readerFunc(func(p []byte) (int, error) {
+	c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) {
 		n, err := c.rwc.Read(p)
 		bytesRead += n
 		return n, err
diff --git a/internal/util/util.go b/internal/util/util.go
index f23fb67..aa21070 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -6,3 +6,10 @@ type WriterFunc func(p []byte) (int, error)
 func (f WriterFunc) Write(p []byte) (int, error) {
 	return f(p)
 }
+
+// ReaderFunc is used to implement one off io.Readers.
+type ReaderFunc func(p []byte) (int, error)
+
+func (f ReaderFunc) Read(p []byte) (int, error) {
+	return f(p)
+}
diff --git a/internal/xsync/go.go b/internal/xsync/go.go
index 7a61f27..5229b12 100644
--- a/internal/xsync/go.go
+++ b/internal/xsync/go.go
@@ -2,6 +2,7 @@ package xsync
 
 import (
 	"fmt"
+	"runtime/debug"
 )
 
 // Go allows running a function in another goroutine
@@ -13,7 +14,7 @@ func Go(fn func() error) <-chan error {
 			r := recover()
 			if r != nil {
 				select {
-				case errs <- fmt.Errorf("panic in go fn: %v", r):
+				case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()):
 				default:
 				}
 			}
diff --git a/read.go b/read.go
index 7bc6f20..d321786 100644
--- a/read.go
+++ b/read.go
@@ -13,6 +13,7 @@ import (
 	"time"
 
 	"nhooyr.io/websocket/internal/errd"
+	"nhooyr.io/websocket/internal/util"
 	"nhooyr.io/websocket/internal/xsync"
 )
 
@@ -101,13 +102,20 @@ func newMsgReader(c *Conn) *msgReader {
 
 func (mr *msgReader) resetFlate() {
 	if mr.flateContextTakeover() {
+		if mr.dict == nil {
+			mr.dict = &slidingWindow{}
+		}
 		mr.dict.init(32768)
 	}
 	if mr.flateBufio == nil {
 		mr.flateBufio = getBufioReader(mr.readFunc)
 	}
 
-	mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
+	if mr.flateContextTakeover() {
+		mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
+	} else {
+		mr.flateReader = getFlateReader(mr.flateBufio, nil)
+	}
 	mr.limitReader.r = mr.flateReader
 	mr.flateTail.Reset(deflateMessageTail)
 }
@@ -122,7 +130,10 @@ func (mr *msgReader) putFlateReader() {
 func (mr *msgReader) close() {
 	mr.c.readMu.forceLock()
 	mr.putFlateReader()
-	mr.dict.close()
+	if mr.dict != nil {
+		mr.dict.close()
+		mr.dict = nil
+	}
 	if mr.flateBufio != nil {
 		putBufioReader(mr.flateBufio)
 	}
@@ -348,14 +359,14 @@ type msgReader struct {
 	flateBufio  *bufio.Reader
 	flateTail   strings.Reader
 	limitReader *limitReader
-	dict        slidingWindow
+	dict        *slidingWindow
 
 	fin           bool
 	payloadLength int64
 	maskKey       uint32
 
-	// readerFunc(mr.Read) to avoid continuous allocations.
-	readFunc readerFunc
+	// util.ReaderFunc(mr.Read) to avoid continuous allocations.
+	readFunc util.ReaderFunc
 }
 
 func (mr *msgReader) reset(ctx context.Context, h header) {
@@ -484,9 +495,3 @@ func (lr *limitReader) Read(p []byte) (int, error) {
 	}
 	return n, err
 }
-
-type readerFunc func(p []byte) (int, error)
-
-func (f readerFunc) Read(p []byte) (int, error) {
-	return f(p)
-}
diff --git a/write.go b/write.go
index 7921eac..500609d 100644
--- a/write.go
+++ b/write.go
@@ -16,6 +16,7 @@ import (
 	"compress/flate"
 
 	"nhooyr.io/websocket/internal/errd"
+	"nhooyr.io/websocket/internal/util"
 )
 
 // Writer returns a writer bounded by the context that will write
@@ -93,7 +94,7 @@ func newMsgWriterState(c *Conn) *msgWriterState {
 func (mw *msgWriterState) ensureFlate() {
 	if mw.trimWriter == nil {
 		mw.trimWriter = &trimLastFourBytesWriter{
-			w: writerFunc(mw.write),
+			w: util.WriterFunc(mw.write),
 		}
 	}
 
@@ -380,17 +381,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
 	return n, nil
 }
 
-type writerFunc func(p []byte) (int, error)
-
-func (f writerFunc) Write(p []byte) (int, error) {
-	return f(p)
-}
-
 // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
 // and returns it.
 func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
 	var writeBuf []byte
-	bw.Reset(writerFunc(func(p2 []byte) (int, error) {
+	bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
 		writeBuf = p2[:cap(p2)]
 		return len(p2), nil
 	}))
diff --git a/ws_js.go b/ws_js.go
index 05f2202..9f0e19e 100644
--- a/ws_js.go
+++ b/ws_js.go
@@ -566,4 +566,5 @@ func (m *mu) unlock() {
 }
 
 type noCopy struct{}
+
 func (*noCopy) Lock() {}
-- 
GitLab