From 503b4696fcbad5c2c18e364fcc31540a7c5e43e9 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Thu, 13 Feb 2020 01:57:19 -0500
Subject: [PATCH] Simplify sliding window API

---
 compress_notjs.go | 43 +++++++++++++++++++++++++------------------
 compress_test.go  | 11 ++++++-----
 conn_test.go      |  7 +++----
 read.go           | 16 +++++-----------
 4 files changed, 39 insertions(+), 38 deletions(-)

diff --git a/compress_notjs.go b/compress_notjs.go
index 3f0d8b9..2076136 100644
--- a/compress_notjs.go
+++ b/compress_notjs.go
@@ -120,41 +120,48 @@ type slidingWindow struct {
 
 var swPool = map[int]*sync.Pool{}
 
-func newSlidingWindow(n int) *slidingWindow {
+func (sw *slidingWindow) init(n int) {
+	if sw.buf != nil {
+		return
+	}
+
 	p, ok := swPool[n]
 	if !ok {
 		p = &sync.Pool{}
 		swPool[n] = p
 	}
-	sw, ok := p.Get().(*slidingWindow)
+	buf, ok := p.Get().([]byte)
 	if ok {
-		return sw
-	}
-	return &slidingWindow{
-		buf: make([]byte, 0, n),
+		sw.buf = buf[:0]
+	} else {
+		sw.buf = make([]byte, 0, n)
 	}
 }
 
-func returnSlidingWindow(sw *slidingWindow) {
-	sw.buf = sw.buf[:0]
-	swPool[cap(sw.buf)].Put(sw)
+func (sw *slidingWindow) close() {
+	if sw.buf == nil {
+		return
+	}
+
+	swPool[cap(sw.buf)].Put(sw.buf)
+	sw.buf = nil
 }
 
-func (w *slidingWindow) write(p []byte) {
-	if len(p) >= cap(w.buf) {
-		w.buf = w.buf[:cap(w.buf)]
-		p = p[len(p)-cap(w.buf):]
-		copy(w.buf, p)
+func (sw *slidingWindow) write(p []byte) {
+	if len(p) >= cap(sw.buf) {
+		sw.buf = sw.buf[:cap(sw.buf)]
+		p = p[len(p)-cap(sw.buf):]
+		copy(sw.buf, p)
 		return
 	}
 
-	left := cap(w.buf) - len(w.buf)
+	left := cap(sw.buf) - len(sw.buf)
 	if left < len(p) {
 		// We need to shift spaceNeeded bytes from the end to make room for p at the end.
 		spaceNeeded := len(p) - left
-		copy(w.buf, w.buf[spaceNeeded:])
-		w.buf = w.buf[:len(w.buf)-spaceNeeded]
+		copy(sw.buf, sw.buf[spaceNeeded:])
+		sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
 	}
 
-	w.buf = append(w.buf, p...)
+	sw.buf = append(sw.buf, p...)
 }
diff --git a/compress_test.go b/compress_test.go
index 364d542..2c4c896 100644
--- a/compress_test.go
+++ b/compress_test.go
@@ -21,12 +21,13 @@ func Test_slidingWindow(t *testing.T) {
 
 			input := xrand.String(maxWindow)
 			windowLength := xrand.Int(maxWindow)
-			r := newSlidingWindow(windowLength)
-			r.write([]byte(input))
+			var sw slidingWindow
+			sw.init(windowLength)
+			sw.write([]byte(input))
 
-			assert.Equal(t, "window length", windowLength, cap(r.buf))
-			if !strings.HasSuffix(input, string(r.buf)) {
-				t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf)
+			assert.Equal(t, "window length", windowLength, cap(sw.buf))
+			if !strings.HasSuffix(input, string(sw.buf)) {
+				t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf)
 			}
 		})
 	}
diff --git a/conn_test.go b/conn_test.go
index e1e6c35..25b0809 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -351,13 +351,12 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
 	ctx, cancel := context.WithCancel(tt.ctx)
 
 	discardLoopErr := xsync.Go(func() error {
+		defer c.Close(websocket.StatusInternalError, "")
+
 		for {
 			_, _, err := c.Read(ctx)
-			if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
-				return nil
-			}
 			if err != nil {
-				return err
+				return assertCloseStatus(websocket.StatusNormalClosure, err)
 			}
 		}
 	})
diff --git a/read.go b/read.go
index 49c03b4..dd73ac9 100644
--- a/read.go
+++ b/read.go
@@ -87,15 +87,11 @@ func newMsgReader(c *Conn) *msgReader {
 }
 
 func (mr *msgReader) resetFlate() {
-	if mr.flateContextTakeover() && mr.dict == nil {
-		mr.dict = newSlidingWindow(32768)
-	}
-
 	if mr.flateContextTakeover() {
-		mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
-	} else {
-		mr.flateReader = getFlateReader(readerFunc(mr.read), nil)
+		mr.dict.init(32768)
 	}
+
+	mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
 	mr.limitReader.r = mr.flateReader
 	mr.flateTail.Reset(deflateMessageTail)
 }
@@ -111,9 +107,7 @@ func (mr *msgReader) close() {
 	mr.c.readMu.Lock(context.Background())
 	mr.returnFlateReader()
 
-	if mr.dict != nil {
-		returnSlidingWindow(mr.dict)
-	}
+	mr.dict.close()
 }
 
 func (mr *msgReader) flateContextTakeover() bool {
@@ -325,7 +319,7 @@ type msgReader struct {
 	flateReader io.Reader
 	flateTail   strings.Reader
 	limitReader *limitReader
-	dict        *slidingWindow
+	dict        slidingWindow
 
 	fin           bool
 	payloadLength int64
-- 
GitLab