From 503b4696fcbad5c2c18e364fcc31540a7c5e43e9 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Thu, 13 Feb 2020 01:57:19 -0500 Subject: [PATCH] Simplify sliding window API --- compress_notjs.go | 43 +++++++++++++++++++++++++------------------ compress_test.go | 11 ++++++----- conn_test.go | 7 +++---- read.go | 16 +++++----------- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/compress_notjs.go b/compress_notjs.go index 3f0d8b9..2076136 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -120,41 +120,48 @@ type slidingWindow struct { var swPool = map[int]*sync.Pool{} -func newSlidingWindow(n int) *slidingWindow { +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + p, ok := swPool[n] if !ok { p = &sync.Pool{} swPool[n] = p } - sw, ok := p.Get().(*slidingWindow) + buf, ok := p.Get().([]byte) if ok { - return sw - } - return &slidingWindow{ - buf: make([]byte, 0, n), + sw.buf = buf[:0] + } else { + sw.buf = make([]byte, 0, n) } } -func returnSlidingWindow(sw *slidingWindow) { - sw.buf = sw.buf[:0] - swPool[cap(sw.buf)].Put(sw) +func (sw *slidingWindow) close() { + if sw.buf == nil { + return + } + + swPool[cap(sw.buf)].Put(sw.buf) + sw.buf = nil } -func (w *slidingWindow) write(p []byte) { - if len(p) >= cap(w.buf) { - w.buf = w.buf[:cap(w.buf)] - p = p[len(p)-cap(w.buf):] - copy(w.buf, p) +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) return } - left := cap(w.buf) - len(w.buf) + left := cap(sw.buf) - len(sw.buf) if left < len(p) { // We need to shift spaceNeeded bytes from the end to make room for p at the end. spaceNeeded := len(p) - left - copy(w.buf, w.buf[spaceNeeded:]) - w.buf = w.buf[:len(w.buf)-spaceNeeded] + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] } - w.buf = append(w.buf, p...) + sw.buf = append(sw.buf, p...) } diff --git a/compress_test.go b/compress_test.go index 364d542..2c4c896 100644 --- a/compress_test.go +++ b/compress_test.go @@ -21,12 +21,13 @@ func Test_slidingWindow(t *testing.T) { input := xrand.String(maxWindow) windowLength := xrand.Int(maxWindow) - r := newSlidingWindow(windowLength) - r.write([]byte(input)) + var sw slidingWindow + sw.init(windowLength) + sw.write([]byte(input)) - assert.Equal(t, "window length", windowLength, cap(r.buf)) - if !strings.HasSuffix(input, string(r.buf)) { - t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf) + assert.Equal(t, "window length", windowLength, cap(sw.buf)) + if !strings.HasSuffix(input, string(sw.buf)) { + t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) } }) } diff --git a/conn_test.go b/conn_test.go index e1e6c35..25b0809 100644 --- a/conn_test.go +++ b/conn_test.go @@ -351,13 +351,12 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) discardLoopErr := xsync.Go(func() error { + defer c.Close(websocket.StatusInternalError, "") + for { _, _, err := c.Read(ctx) - if websocket.CloseStatus(err) == websocket.StatusNormalClosure { - return nil - } if err != nil { - return err + return assertCloseStatus(websocket.StatusNormalClosure, err) } } }) diff --git a/read.go b/read.go index 49c03b4..dd73ac9 100644 --- a/read.go +++ b/read.go @@ -87,15 +87,11 @@ func newMsgReader(c *Conn) *msgReader { } func (mr *msgReader) resetFlate() { - if mr.flateContextTakeover() && mr.dict == nil { - mr.dict = newSlidingWindow(32768) - } - if mr.flateContextTakeover() { - mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) - } else { - mr.flateReader = getFlateReader(readerFunc(mr.read), nil) + mr.dict.init(32768) } + + mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } @@ -111,9 +107,7 @@ func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) mr.returnFlateReader() - if mr.dict != nil { - returnSlidingWindow(mr.dict) - } + mr.dict.close() } func (mr *msgReader) flateContextTakeover() bool { @@ -325,7 +319,7 @@ type msgReader struct { flateReader io.Reader flateTail strings.Reader limitReader *limitReader - dict *slidingWindow + dict slidingWindow fin bool payloadLength int64 -- GitLab