From b6b56b7499ee09561b87ad3de17709a59f839952 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Wed, 5 Feb 2020 00:21:26 -0600
Subject: [PATCH] Both modes seem to work :)

---
 accept.go        | 14 ++++----
 assert_test.go   |  3 +-
 compress.go      | 58 +++++++++++++++------------------
 compress_test.go | 45 ++++++++++++++++++++++++++
 conn.go          | 41 ++++++++++++++----------
 conn_test.go     |  7 ++--
 dial.go          | 13 ++++----
 read.go          | 74 ++++++++++++++++++++++--------------------
 write.go         | 83 ++++++++++++++++++++++++------------------------
 9 files changed, 196 insertions(+), 142 deletions(-)
 create mode 100644 compress_test.go

diff --git a/accept.go b/accept.go
index ac7f2de..0394fa6 100644
--- a/accept.go
+++ b/accept.go
@@ -111,12 +111,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
 	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
 
 	return newConn(connConfig{
-		subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
-		rwc:         netConn,
-		client:      false,
-		copts:       copts,
-		br:          brw.Reader,
-		bw:          brw.Writer,
+		subprotocol:    w.Header().Get("Sec-WebSocket-Protocol"),
+		rwc:            netConn,
+		client:         false,
+		copts:          copts,
+		flateThreshold: opts.CompressionOptions.Threshold,
+
+		br: brw.Reader,
+		bw: brw.Writer,
 	}), nil
 }
 
diff --git a/assert_test.go b/assert_test.go
index cd78fbb..5307ee8 100644
--- a/assert_test.go
+++ b/assert_test.go
@@ -6,6 +6,7 @@ import (
 	"strings"
 	"testing"
 
+	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest/assert"
 
 	"nhooyr.io/websocket"
@@ -33,7 +34,7 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int)
 }
 
 func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) {
-	t.Helper()
+	slog.Helper()
 
 	var act interface{}
 	err := wsjson.Read(ctx, c, &act)
diff --git a/compress.go b/compress.go
index fd2535c..efd89b3 100644
--- a/compress.go
+++ b/compress.go
@@ -148,12 +148,12 @@ func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
 
 var flateReaderPool sync.Pool
 
-func getFlateReader(r io.Reader) io.Reader {
+func getFlateReader(r io.Reader, dict []byte) io.Reader {
 	fr, ok := flateReaderPool.Get().(io.Reader)
 	if !ok {
-		return flate.NewReader(r)
+		return flate.NewReaderDict(r, dict)
 	}
-	fr.(flate.Resetter).Reset(r, nil)
+	fr.(flate.Resetter).Reset(r, dict)
 	return fr
 }
 
@@ -163,10 +163,10 @@ func putFlateReader(fr io.Reader) {
 
 var flateWriterPool sync.Pool
 
-func getFlateWriter(w io.Writer, dict []byte) *flate.Writer {
+func getFlateWriter(w io.Writer) *flate.Writer {
 	fw, ok := flateWriterPool.Get().(*flate.Writer)
 	if !ok {
-		fw, _ = flate.NewWriterDict(w, flate.BestSpeed, dict)
+		fw, _ = flate.NewWriter(w, flate.BestSpeed)
 		return fw
 	}
 	fw.Reset(w)
@@ -177,40 +177,32 @@ func putFlateWriter(w *flate.Writer) {
 	flateWriterPool.Put(w)
 }
 
-type slidingWindowReader struct {
-	window []byte
-
-	r io.Reader
+type slidingWindow struct {
+	r   io.Reader
+	buf []byte
 }
 
-func (r slidingWindowReader) Read(p []byte) (int, error) {
-	n, err := r.r.Read(p)
-	p = p[:n]
-
-	r.append(p)
-
-	return n, err
+func newSlidingWindow(n int) *slidingWindow {
+	return &slidingWindow{
+		buf: make([]byte, 0, n),
+	}
 }
 
-func (r slidingWindowReader) append(p []byte) {
-	if len(r.window) <= cap(r.window) {
-		r.window = append(r.window, p...)
+func (w *slidingWindow) write(p []byte) {
+	if len(p) >= cap(w.buf) {
+		w.buf = w.buf[:cap(w.buf)]
+		p = p[len(p)-cap(w.buf):]
+		copy(w.buf, p)
+		return
 	}
 
-	if len(p) > cap(r.window) {
-		p = p[len(p)-cap(r.window):]
+	left := cap(w.buf) - len(w.buf)
+	if left < len(p) {
+		// We need to shift spaceNeeded bytes from the end to make room for p at the end.
+		spaceNeeded := len(p) - left
+		copy(w.buf, w.buf[spaceNeeded:])
+		w.buf = w.buf[:len(w.buf)-spaceNeeded]
 	}
 
-	// p now contains at max the last window bytes
-	// so we need to be able to append all of it to r.window.
-	// Shift as many bytes from r.window as needed.
-
-	// Maximum window size minus current window minus extra gives
-	// us the number of bytes that need to be shifted.
-	off := len(r.window) + len(p) - cap(r.window)
-
-	r.window = append(r.window[:0], r.window[off:]...)
-	copy(r.window, r.window[off:])
-	copy(r.window[len(r.window)-len(p):], p)
-	return
+	w.buf = append(w.buf, p...)
 }
diff --git a/compress_test.go b/compress_test.go
new file mode 100644
index 0000000..6edfcb1
--- /dev/null
+++ b/compress_test.go
@@ -0,0 +1,45 @@
+package websocket
+
+import (
+	"crypto/rand"
+	"encoding/base64"
+	"math/big"
+	"strings"
+	"testing"
+
+	"cdr.dev/slog/sloggers/slogtest/assert"
+)
+
+func Test_slidingWindow(t *testing.T) {
+	t.Parallel()
+
+	const testCount = 99
+	const maxWindow = 99999
+	for i := 0; i < testCount; i++ {
+		input := randStr(t, maxWindow)
+		windowLength := randInt(t, maxWindow)
+		r := newSlidingWindow(windowLength)
+		r.write([]byte(input))
+
+		if cap(r.buf) != windowLength {
+			t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength)
+		}
+		assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf)))
+	}
+}
+
+func randStr(t *testing.T, max int) string {
+	n := randInt(t, max)
+
+	b := make([]byte, n)
+	_, err := rand.Read(b)
+	assert.Success(t, "rand.Read", err)
+
+	return base64.StdEncoding.EncodeToString(b)
+}
+
+func randInt(t *testing.T, max int) int {
+	x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
+	assert.Success(t, "rand.Int", err)
+	return int(x.Int64())
+}
diff --git a/conn.go b/conn.go
index ab93e4e..2d36123 100644
--- a/conn.go
+++ b/conn.go
@@ -38,12 +38,13 @@ const (
 // On any error from any method, the connection is closed
 // with an appropriate reason.
 type Conn struct {
-	subprotocol string
-	rwc         io.ReadWriteCloser
-	client      bool
-	copts       *compressionOptions
-	br          *bufio.Reader
-	bw          *bufio.Writer
+	subprotocol    string
+	rwc            io.ReadWriteCloser
+	client         bool
+	copts          *compressionOptions
+	flateThreshold int
+	br             *bufio.Reader
+	bw             *bufio.Writer
 
 	readTimeout  chan context.Context
 	writeTimeout chan context.Context
@@ -71,10 +72,11 @@ type Conn struct {
 }
 
 type connConfig struct {
-	subprotocol string
-	rwc         io.ReadWriteCloser
-	client      bool
-	copts       *compressionOptions
+	subprotocol    string
+	rwc            io.ReadWriteCloser
+	client         bool
+	copts          *compressionOptions
+	flateThreshold int
 
 	br *bufio.Reader
 	bw *bufio.Writer
@@ -82,10 +84,11 @@ type connConfig struct {
 
 func newConn(cfg connConfig) *Conn {
 	c := &Conn{
-		subprotocol: cfg.subprotocol,
-		rwc:         cfg.rwc,
-		client:      cfg.client,
-		copts:       cfg.copts,
+		subprotocol:    cfg.subprotocol,
+		rwc:            cfg.rwc,
+		client:         cfg.client,
+		copts:          cfg.copts,
+		flateThreshold: cfg.flateThreshold,
 
 		br: cfg.br,
 		bw: cfg.bw,
@@ -96,6 +99,12 @@ func newConn(cfg connConfig) *Conn {
 		closed:      make(chan struct{}),
 		activePings: make(map[string]chan<- struct{}),
 	}
+	if c.flateThreshold == 0 {
+		c.flateThreshold = 256
+		if c.writeNoContextTakeOver() {
+			c.flateThreshold = 512
+		}
+	}
 
 	c.readMu = newMu(c)
 	c.writeFrameMu = newMu(c)
@@ -145,12 +154,10 @@ func (c *Conn) close(err error) {
 		}
 		c.msgWriter.close()
 
+		c.msgReader.close()
 		if c.client {
-			c.readMu.Lock(context.Background())
 			putBufioReader(c.br)
-			c.readMu.Unlock()
 		}
-		c.msgReader.close()
 	}()
 }
 
diff --git a/conn_test.go b/conn_test.go
index a65c332..7186da8 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -27,13 +27,15 @@ func TestConn(t *testing.T) {
 				Subprotocols:       []string{"echo"},
 				InsecureSkipVerify: true,
 				CompressionOptions: websocket.CompressionOptions{
-					Mode: websocket.CompressionNoContextTakeover,
+					Mode:      websocket.CompressionContextTakeover,
+					Threshold: 1,
 				},
 			})
 			assert.Success(t, "accept", err)
 			defer c.Close(websocket.StatusInternalError, "")
 
 			err = echoLoop(r.Context(), c)
+			t.Logf("server: %v", err)
 			assertCloseStatus(t, websocket.StatusNormalClosure, err)
 		}, false)
 		defer closeFn()
@@ -46,7 +48,8 @@ func TestConn(t *testing.T) {
 		opts := &websocket.DialOptions{
 			Subprotocols: []string{"echo"},
 			CompressionOptions: websocket.CompressionOptions{
-				Mode: websocket.CompressionNoContextTakeover,
+				Mode:      websocket.CompressionContextTakeover,
+				Threshold: 1,
 			},
 		}
 		opts.HTTPClient = s.Client()
diff --git a/dial.go b/dial.go
index f53d30e..4557602 100644
--- a/dial.go
+++ b/dial.go
@@ -99,12 +99,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
 	}
 
 	return newConn(connConfig{
-		subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
-		rwc:         rwc,
-		client:      true,
-		copts:       copts,
-		br:          getBufioReader(rwc),
-		bw:          getBufioWriter(rwc),
+		subprotocol:    resp.Header.Get("Sec-WebSocket-Protocol"),
+		rwc:            rwc,
+		client:         true,
+		copts:          copts,
+		flateThreshold: opts.CompressionOptions.Threshold,
+		br:             getBufioReader(rwc),
+		bw:             getBufioWriter(rwc),
 	}), resp, nil
 }
 
diff --git a/read.go b/read.go
index 4b94f06..73ec0b3 100644
--- a/read.go
+++ b/read.go
@@ -72,25 +72,40 @@ func (c *Conn) SetReadLimit(n int64) {
 	c.msgReader.limitReader.limit.Store(n)
 }
 
+const defaultReadLimit = 32768
+
 func newMsgReader(c *Conn) *msgReader {
 	mr := &msgReader{
 		c:   c,
 		fin: true,
 	}
 
-	mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768)
+	mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit)
 	return mr
 }
 
-func (mr *msgReader) initFlateReader() {
-	mr.flateReader = getFlateReader(readerFunc(mr.read))
+func (mr *msgReader) ensureFlate() {
+	if mr.flateContextTakeover() && mr.dict == nil {
+		mr.dict = newSlidingWindow(32768)
+	}
+
+	if mr.flateContextTakeover() {
+		mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
+	} else {
+		mr.flateReader = getFlateReader(readerFunc(mr.read), nil)
+	}
 	mr.limitReader.r = mr.flateReader
 }
 
+func (mr *msgReader) returnFlateReader() {
+	if mr.flateReader != nil {
+		putFlateReader(mr.flateReader)
+		mr.flateReader = nil
+	}
+}
+
 func (mr *msgReader) close() {
 	mr.c.readMu.Lock(context.Background())
-	defer mr.c.readMu.Unlock()
-
 	mr.returnFlateReader()
 }
 
@@ -299,10 +314,11 @@ type msgReader struct {
 	c *Conn
 
 	ctx         context.Context
-	deflate     bool
+	flate       bool
 	flateReader io.Reader
-	deflateTail strings.Reader
+	flateTail   strings.Reader
 	limitReader *limitReader
+	dict        *slidingWindow
 
 	fin           bool
 	payloadLength int64
@@ -311,12 +327,10 @@ type msgReader struct {
 
 func (mr *msgReader) reset(ctx context.Context, h header) {
 	mr.ctx = ctx
-	mr.deflate = h.rsv1
-	if mr.deflate {
-		if !mr.flateContextTakeover() {
-			mr.initFlateReader()
-		}
-		mr.deflateTail.Reset(deflateMessageTail)
+	mr.flate = h.rsv1
+	if mr.flate {
+		mr.ensureFlate()
+		mr.flateTail.Reset(deflateMessageTail)
 	}
 
 	mr.limitReader.reset()
@@ -331,18 +345,10 @@ func (mr *msgReader) setFrame(h header) {
 
 func (mr *msgReader) Read(p []byte) (n int, err error) {
 	defer func() {
-		r := recover()
-		if r != nil {
-			if r != "ANMOL" {
-				panic(r)
-			}
+		errd.Wrap(&err, "failed to read")
+		if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
 			err = io.EOF
-			if !mr.flateContextTakeover() {
-				mr.returnFlateReader()
-			}
 		}
-
-		errd.Wrap(&err, "failed to read")
 		if xerrors.Is(err, io.EOF) {
 			err = io.EOF
 		}
@@ -354,25 +360,23 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
 	}
 	defer mr.c.readMu.Unlock()
 
-	return mr.limitReader.Read(p)
-}
-
-func (mr *msgReader) returnFlateReader() {
-	if mr.flateReader != nil {
-		putFlateReader(mr.flateReader)
-		mr.flateReader = nil
+	n, err = mr.limitReader.Read(p)
+	if mr.flateContextTakeover() {
+		p = p[:n]
+		mr.dict.write(p)
 	}
+	return n, err
 }
 
 func (mr *msgReader) read(p []byte) (int, error) {
 	if mr.payloadLength == 0 {
 		if mr.fin {
-			if mr.deflate {
-				if mr.deflateTail.Len() == 0 {
-					panic("ANMOL")
+			if mr.flate {
+				n, err := mr.flateTail.Read(p)
+				if xerrors.Is(err, io.EOF) {
+					mr.returnFlateReader()
 				}
-				n, _ := mr.deflateTail.Read(p)
-				return n, nil
+				return n, err
 			}
 			return 0, io.EOF
 		}
diff --git a/write.go b/write.go
index db47ddb..a7fa5f5 100644
--- a/write.go
+++ b/write.go
@@ -37,8 +37,8 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
 //
 // See the Writer method if you want to stream a message.
 //
-// If compression is disabled, then it is guaranteed to write the message
-// in a single frame.
+// If compression is disabled or the threshold is not met, then it
+// will write the message in a single frame.
 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
 	_, err := c.write(ctx, typ, p)
 	if err != nil {
@@ -47,20 +47,38 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
 	return nil
 }
 
+type msgWriter struct {
+	c *Conn
+
+	mu *mu
+
+	ctx    context.Context
+	opcode opcode
+	closed bool
+	flate  bool
+
+	trimWriter  *trimLastFourBytesWriter
+	flateWriter *flate.Writer
+}
+
 func newMsgWriter(c *Conn) *msgWriter {
 	mw := &msgWriter{
 		c:  c,
 		mu: newMu(c),
 	}
-	mw.trimWriter = &trimLastFourBytesWriter{
-		w: writerFunc(mw.write),
-	}
 	return mw
 }
 
-func (mw *msgWriter) ensureFlateWriter() {
+func (mw *msgWriter) ensureFlate() {
 	if mw.flateWriter == nil {
-		mw.flateWriter = getFlateWriter(mw.trimWriter, nil)
+		if mw.trimWriter == nil {
+			mw.trimWriter = &trimLastFourBytesWriter{
+				w: writerFunc(mw.write),
+			}
+		}
+
+		mw.flateWriter = getFlateWriter(mw.trimWriter)
+		mw.flate = true
 	}
 }
 
@@ -85,8 +103,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
 		return 0, err
 	}
 
-	if !c.flate() {
-		// Fast single frame path.
+	if !c.flate() || len(p) < c.flateThreshold {
 		defer c.msgWriter.mu.Unlock()
 		return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
 	}
@@ -100,20 +117,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
 	return n, err
 }
 
-type msgWriter struct {
-	c *Conn
-
-	mu *mu
-
-	ctx    context.Context
-	opcode opcode
-	closed bool
-
-	flate       bool
-	trimWriter  *trimLastFourBytesWriter
-	flateWriter *flate.Writer
-}
-
 func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
 	err := mw.mu.Lock(ctx)
 	if err != nil {
@@ -127,6 +130,13 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
 	return nil
 }
 
+func (mw *msgWriter) returnFlateWriter() {
+	if mw.flateWriter != nil {
+		putFlateWriter(mw.flateWriter)
+		mw.flateWriter = nil
+	}
+}
+
 // Write writes the given bytes to the WebSocket connection.
 func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 	defer errd.Wrap(&err, "failed to write")
@@ -135,16 +145,10 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 		return 0, xerrors.New("cannot use closed writer")
 	}
 
-	if mw.c.flate() {
-		if !mw.flate {
-			mw.flate = true
-
-			if !mw.flateContextTakeover() {
-				mw.ensureFlateWriter()
-			}
-			mw.trimWriter.reset()
-		}
-
+	// TODO can make threshold detection robust across writes by writing to buffer
+	if mw.flate ||
+		mw.c.flate() && len(p) >= mw.c.flateThreshold {
+		mw.ensureFlate()
 		return mw.flateWriter.Write(p)
 	}
 
@@ -181,21 +185,16 @@ func (mw *msgWriter) Close() (err error) {
 		return xerrors.Errorf("failed to write fin frame: %w", err)
 	}
 
-	if mw.c.flate() && !mw.flateContextTakeover() && mw.flateWriter != nil {
-		putFlateWriter(mw.flateWriter)
-		mw.flateWriter = nil
+	if mw.c.flate() && !mw.flateContextTakeover() {
+		mw.returnFlateWriter()
 	}
-
 	mw.mu.Unlock()
 	return nil
 }
 
 func (mw *msgWriter) close() {
-	if mw.flateWriter != nil && mw.flateContextTakeover() {
-		mw.mu.Lock(context.Background())
-		putFlateWriter(mw.flateWriter)
-		mw.flateWriter = nil
-	}
+	mw.mu.Lock(context.Background())
+	mw.returnFlateWriter()
 }
 
 func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
-- 
GitLab