From a02cbef5605d23c97972fbea8dd16488cf437b7a Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 13 Oct 2023 03:34:15 -0700 Subject: [PATCH] compress.go: Fix context takeover --- accept.go | 1 + ci/bench.sh | 4 ++-- compress.go | 16 ++++++---------- conn.go | 1 + conn_test.go | 4 ++-- dial_test.go | 3 ++- export_test.go | 6 ++++-- internal/util/util.go | 7 +++++++ internal/xsync/go.go | 3 ++- read.go | 27 ++++++++++++++++----------- write.go | 11 +++-------- ws_js.go | 1 + 12 files changed, 47 insertions(+), 37 deletions(-) diff --git a/accept.go b/accept.go index ff2033e..6c63e73 100644 --- a/accept.go +++ b/accept.go @@ -269,6 +269,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi if strings.HasPrefix(p, "client_max_window_bits") { // We cannot adjust the read sliding window so cannot make use of this. + // By not responding to it, we tell the client we're ignoring it. continue } diff --git a/ci/bench.sh b/ci/bench.sh index 31bf2f1..8f99278 100755 --- a/ci/bench.sh +++ b/ci/bench.sh @@ -2,8 +2,8 @@ set -eu cd -- "$(dirname "$0")/.." -go test --bench=. "$@" ./... +go test --run=^$ --bench=. "$@" ./... ( cd ./internal/thirdparty - go test --bench=. "$@" ./... + go test --run=^$ --bench=. "$@" ./... ) diff --git a/compress.go b/compress.go index e6722fc..61e6e26 100644 --- a/compress.go +++ b/compress.go @@ -31,7 +31,7 @@ const ( CompressionDisabled CompressionMode = iota // CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection. - // It reusing the sliding window from previous messages. + // It reuses the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. // It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover. // @@ -80,7 +80,7 @@ func (copts *compressionOptions) setHeader(h http.Header) { // They are removed when sending to avoid the overhead as // WebSocket framing tell's when the message has ended but then // we need to add them back otherwise flate.Reader keeps -// trying to return more bytes. +// trying to read more bytes. const deflateMessageTail = "\x00\x00\xff\xff" type trimLastFourBytesWriter struct { @@ -201,23 +201,19 @@ func (sw *slidingWindow) init(n int) { } p := slidingWindowPool(n) - buf, ok := p.Get().(*[]byte) + sw2, ok := p.Get().(*slidingWindow) if ok { - sw.buf = (*buf)[:0] + *sw = *sw2 } else { sw.buf = make([]byte, 0, n) } } func (sw *slidingWindow) close() { - if sw.buf == nil { - return - } - + sw.buf = sw.buf[:0] swPoolMu.Lock() - swPool[cap(sw.buf)].Put(&sw.buf) + swPool[cap(sw.buf)].Put(sw) swPoolMu.Unlock() - sw.buf = nil } func (sw *slidingWindow) write(p []byte) { diff --git a/conn.go b/conn.go index 17a6b96..81a57c7 100644 --- a/conn.go +++ b/conn.go @@ -292,4 +292,5 @@ func (m *mu) unlock() { } type noCopy struct{} + func (*noCopy) Lock() {} diff --git a/conn_test.go b/conn_test.go index 59661b7..7a6a0c3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -267,7 +267,7 @@ func TestConn(t *testing.T) { t.Run("HTTPClient.Timeout", func(t *testing.T) { tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ - HTTPClient: &http.Client{Timeout: time.Second*5}, + HTTPClient: &http.Client{Timeout: time.Second * 5}, }, nil) tt.goEchoLoop(c2) @@ -458,7 +458,7 @@ func BenchmarkConn(b *testing.B) { typ, r, err := c1.Reader(bb.ctx) if err != nil { - b.Fatal(err) + b.Fatal(i, err) } if websocket.MessageText != typ { assert.Equal(b, "data type", websocket.MessageText, typ) diff --git a/dial_test.go b/dial_test.go index 8680147..e072db2 100644 --- a/dial_test.go +++ b/dial_test.go @@ -15,6 +15,7 @@ import ( "time" "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/util" ) func TestBadDials(t *testing.T) { @@ -27,7 +28,7 @@ func TestBadDials(t *testing.T) { name string url string opts *DialOptions - rand readerFunc + rand util.ReaderFunc nilCtx bool }{ { diff --git a/export_test.go b/export_test.go index d618a15..8731b6d 100644 --- a/export_test.go +++ b/export_test.go @@ -3,9 +3,11 @@ package websocket +import "nhooyr.io/websocket/internal/util" + func (c *Conn) RecordBytesWritten() *int { var bytesWritten int - c.bw.Reset(writerFunc(func(p []byte) (int, error) { + c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) { bytesWritten += len(p) return c.rwc.Write(p) })) @@ -14,7 +16,7 @@ func (c *Conn) RecordBytesWritten() *int { func (c *Conn) RecordBytesRead() *int { var bytesRead int - c.br.Reset(readerFunc(func(p []byte) (int, error) { + c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) { n, err := c.rwc.Read(p) bytesRead += n return n, err diff --git a/internal/util/util.go b/internal/util/util.go index f23fb67..aa21070 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -6,3 +6,10 @@ type WriterFunc func(p []byte) (int, error) func (f WriterFunc) Write(p []byte) (int, error) { return f(p) } + +// ReaderFunc is used to implement one off io.Readers. +type ReaderFunc func(p []byte) (int, error) + +func (f ReaderFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/internal/xsync/go.go b/internal/xsync/go.go index 7a61f27..5229b12 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -2,6 +2,7 @@ package xsync import ( "fmt" + "runtime/debug" ) // Go allows running a function in another goroutine @@ -13,7 +14,7 @@ func Go(fn func() error) <-chan error { r := recover() if r != nil { select { - case errs <- fmt.Errorf("panic in go fn: %v", r): + case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()): default: } } diff --git a/read.go b/read.go index 7bc6f20..d321786 100644 --- a/read.go +++ b/read.go @@ -13,6 +13,7 @@ import ( "time" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" "nhooyr.io/websocket/internal/xsync" ) @@ -101,13 +102,20 @@ func newMsgReader(c *Conn) *msgReader { func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { + if mr.dict == nil { + mr.dict = &slidingWindow{} + } mr.dict.init(32768) } if mr.flateBufio == nil { mr.flateBufio = getBufioReader(mr.readFunc) } - mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + if mr.flateContextTakeover() { + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + } else { + mr.flateReader = getFlateReader(mr.flateBufio, nil) + } mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } @@ -122,7 +130,10 @@ func (mr *msgReader) putFlateReader() { func (mr *msgReader) close() { mr.c.readMu.forceLock() mr.putFlateReader() - mr.dict.close() + if mr.dict != nil { + mr.dict.close() + mr.dict = nil + } if mr.flateBufio != nil { putBufioReader(mr.flateBufio) } @@ -348,14 +359,14 @@ type msgReader struct { flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader - dict slidingWindow + dict *slidingWindow fin bool payloadLength int64 maskKey uint32 - // readerFunc(mr.Read) to avoid continuous allocations. - readFunc readerFunc + // util.ReaderFunc(mr.Read) to avoid continuous allocations. + readFunc util.ReaderFunc } func (mr *msgReader) reset(ctx context.Context, h header) { @@ -484,9 +495,3 @@ func (lr *limitReader) Read(p []byte) (int, error) { } return n, err } - -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} diff --git a/write.go b/write.go index 7921eac..500609d 100644 --- a/write.go +++ b/write.go @@ -16,6 +16,7 @@ import ( "compress/flate" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" ) // Writer returns a writer bounded by the context that will write @@ -93,7 +94,7 @@ func newMsgWriterState(c *Conn) *msgWriterState { func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), + w: util.WriterFunc(mw.write), } } @@ -380,17 +381,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { return n, nil } -type writerFunc func(p []byte) (int, error) - -func (f writerFunc) Write(p []byte) (int, error) { - return f(p) -} - // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer // and returns it. func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { var writeBuf []byte - bw.Reset(writerFunc(func(p2 []byte) (int, error) { + bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) { writeBuf = p2[:cap(p2)] return len(p2), nil })) diff --git a/ws_js.go b/ws_js.go index 05f2202..9f0e19e 100644 --- a/ws_js.go +++ b/ws_js.go @@ -566,4 +566,5 @@ func (m *mu) unlock() { } type noCopy struct{} + func (*noCopy) Lock() {} -- GitLab