From b6b56b7499ee09561b87ad3de17709a59f839952 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Wed, 5 Feb 2020 00:21:26 -0600 Subject: [PATCH] Both modes seem to work :) --- accept.go | 14 ++++---- assert_test.go | 3 +- compress.go | 58 +++++++++++++++------------------ compress_test.go | 45 ++++++++++++++++++++++++++ conn.go | 41 ++++++++++++++---------- conn_test.go | 7 ++-- dial.go | 13 ++++---- read.go | 74 ++++++++++++++++++++++-------------------- write.go | 83 ++++++++++++++++++++++++------------------------ 9 files changed, 196 insertions(+), 142 deletions(-) create mode 100644 compress_test.go diff --git a/accept.go b/accept.go index ac7f2de..0394fa6 100644 --- a/accept.go +++ b/accept.go @@ -111,12 +111,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) return newConn(connConfig{ - subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), - rwc: netConn, - client: false, - copts: copts, - br: brw.Reader, - bw: brw.Writer, + subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + rwc: netConn, + client: false, + copts: copts, + flateThreshold: opts.CompressionOptions.Threshold, + + br: brw.Reader, + bw: brw.Writer, }), nil } diff --git a/assert_test.go b/assert_test.go index cd78fbb..5307ee8 100644 --- a/assert_test.go +++ b/assert_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest/assert" "nhooyr.io/websocket" @@ -33,7 +34,7 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { - t.Helper() + slog.Helper() var act interface{} err := wsjson.Read(ctx, c, &act) diff --git a/compress.go b/compress.go index fd2535c..efd89b3 100644 --- a/compress.go +++ b/compress.go @@ -148,12 +148,12 @@ func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { var flateReaderPool sync.Pool -func getFlateReader(r io.Reader) io.Reader { +func getFlateReader(r io.Reader, dict []byte) io.Reader { fr, ok := flateReaderPool.Get().(io.Reader) if !ok { - return flate.NewReader(r) + return flate.NewReaderDict(r, dict) } - fr.(flate.Resetter).Reset(r, nil) + fr.(flate.Resetter).Reset(r, dict) return fr } @@ -163,10 +163,10 @@ func putFlateReader(fr io.Reader) { var flateWriterPool sync.Pool -func getFlateWriter(w io.Writer, dict []byte) *flate.Writer { +func getFlateWriter(w io.Writer) *flate.Writer { fw, ok := flateWriterPool.Get().(*flate.Writer) if !ok { - fw, _ = flate.NewWriterDict(w, flate.BestSpeed, dict) + fw, _ = flate.NewWriter(w, flate.BestSpeed) return fw } fw.Reset(w) @@ -177,40 +177,32 @@ func putFlateWriter(w *flate.Writer) { flateWriterPool.Put(w) } -type slidingWindowReader struct { - window []byte - - r io.Reader +type slidingWindow struct { + r io.Reader + buf []byte } -func (r slidingWindowReader) Read(p []byte) (int, error) { - n, err := r.r.Read(p) - p = p[:n] - - r.append(p) - - return n, err +func newSlidingWindow(n int) *slidingWindow { + return &slidingWindow{ + buf: make([]byte, 0, n), + } } -func (r slidingWindowReader) append(p []byte) { - if len(r.window) <= cap(r.window) { - r.window = append(r.window, p...) +func (w *slidingWindow) write(p []byte) { + if len(p) >= cap(w.buf) { + w.buf = w.buf[:cap(w.buf)] + p = p[len(p)-cap(w.buf):] + copy(w.buf, p) + return } - if len(p) > cap(r.window) { - p = p[len(p)-cap(r.window):] + left := cap(w.buf) - len(w.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(w.buf, w.buf[spaceNeeded:]) + w.buf = w.buf[:len(w.buf)-spaceNeeded] } - // p now contains at max the last window bytes - // so we need to be able to append all of it to r.window. - // Shift as many bytes from r.window as needed. - - // Maximum window size minus current window minus extra gives - // us the number of bytes that need to be shifted. - off := len(r.window) + len(p) - cap(r.window) - - r.window = append(r.window[:0], r.window[off:]...) - copy(r.window, r.window[off:]) - copy(r.window[len(r.window)-len(p):], p) - return + w.buf = append(w.buf, p...) } diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 0000000..6edfcb1 --- /dev/null +++ b/compress_test.go @@ -0,0 +1,45 @@ +package websocket + +import ( + "crypto/rand" + "encoding/base64" + "math/big" + "strings" + "testing" + + "cdr.dev/slog/sloggers/slogtest/assert" +) + +func Test_slidingWindow(t *testing.T) { + t.Parallel() + + const testCount = 99 + const maxWindow = 99999 + for i := 0; i < testCount; i++ { + input := randStr(t, maxWindow) + windowLength := randInt(t, maxWindow) + r := newSlidingWindow(windowLength) + r.write([]byte(input)) + + if cap(r.buf) != windowLength { + t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength) + } + assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf))) + } +} + +func randStr(t *testing.T, max int) string { + n := randInt(t, max) + + b := make([]byte, n) + _, err := rand.Read(b) + assert.Success(t, "rand.Read", err) + + return base64.StdEncoding.EncodeToString(b) +} + +func randInt(t *testing.T, max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + assert.Success(t, "rand.Int", err) + return int(x.Int64()) +} diff --git a/conn.go b/conn.go index ab93e4e..2d36123 100644 --- a/conn.go +++ b/conn.go @@ -38,12 +38,13 @@ const ( // On any error from any method, the connection is closed // with an appropriate reason. type Conn struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - br *bufio.Reader - bw *bufio.Writer + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer readTimeout chan context.Context writeTimeout chan context.Context @@ -71,10 +72,11 @@ type Conn struct { } type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int br *bufio.Reader bw *bufio.Writer @@ -82,10 +84,11 @@ type connConfig struct { func newConn(cfg connConfig) *Conn { c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, br: cfg.br, bw: cfg.bw, @@ -96,6 +99,12 @@ func newConn(cfg connConfig) *Conn { closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } + if c.flateThreshold == 0 { + c.flateThreshold = 256 + if c.writeNoContextTakeOver() { + c.flateThreshold = 512 + } + } c.readMu = newMu(c) c.writeFrameMu = newMu(c) @@ -145,12 +154,10 @@ func (c *Conn) close(err error) { } c.msgWriter.close() + c.msgReader.close() if c.client { - c.readMu.Lock(context.Background()) putBufioReader(c.br) - c.readMu.Unlock() } - c.msgReader.close() }() } diff --git a/conn_test.go b/conn_test.go index a65c332..7186da8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -27,13 +27,15 @@ func TestConn(t *testing.T) { Subprotocols: []string{"echo"}, InsecureSkipVerify: true, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionNoContextTakeover, + Mode: websocket.CompressionContextTakeover, + Threshold: 1, }, }) assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") err = echoLoop(r.Context(), c) + t.Logf("server: %v", err) assertCloseStatus(t, websocket.StatusNormalClosure, err) }, false) defer closeFn() @@ -46,7 +48,8 @@ func TestConn(t *testing.T) { opts := &websocket.DialOptions{ Subprotocols: []string{"echo"}, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionNoContextTakeover, + Mode: websocket.CompressionContextTakeover, + Threshold: 1, }, } opts.HTTPClient = s.Client() diff --git a/dial.go b/dial.go index f53d30e..4557602 100644 --- a/dial.go +++ b/dial.go @@ -99,12 +99,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( } return newConn(connConfig{ - subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - rwc: rwc, - client: true, - copts: copts, - br: getBufioReader(rwc), - bw: getBufioWriter(rwc), + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + rwc: rwc, + client: true, + copts: copts, + flateThreshold: opts.CompressionOptions.Threshold, + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), }), resp, nil } diff --git a/read.go b/read.go index 4b94f06..73ec0b3 100644 --- a/read.go +++ b/read.go @@ -72,25 +72,40 @@ func (c *Conn) SetReadLimit(n int64) { c.msgReader.limitReader.limit.Store(n) } +const defaultReadLimit = 32768 + func newMsgReader(c *Conn) *msgReader { mr := &msgReader{ c: c, fin: true, } - mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) + mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit) return mr } -func (mr *msgReader) initFlateReader() { - mr.flateReader = getFlateReader(readerFunc(mr.read)) +func (mr *msgReader) ensureFlate() { + if mr.flateContextTakeover() && mr.dict == nil { + mr.dict = newSlidingWindow(32768) + } + + if mr.flateContextTakeover() { + mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) + } else { + mr.flateReader = getFlateReader(readerFunc(mr.read), nil) + } mr.limitReader.r = mr.flateReader } +func (mr *msgReader) returnFlateReader() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil + } +} + func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) - defer mr.c.readMu.Unlock() - mr.returnFlateReader() } @@ -299,10 +314,11 @@ type msgReader struct { c *Conn ctx context.Context - deflate bool + flate bool flateReader io.Reader - deflateTail strings.Reader + flateTail strings.Reader limitReader *limitReader + dict *slidingWindow fin bool payloadLength int64 @@ -311,12 +327,10 @@ type msgReader struct { func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx - mr.deflate = h.rsv1 - if mr.deflate { - if !mr.flateContextTakeover() { - mr.initFlateReader() - } - mr.deflateTail.Reset(deflateMessageTail) + mr.flate = h.rsv1 + if mr.flate { + mr.ensureFlate() + mr.flateTail.Reset(deflateMessageTail) } mr.limitReader.reset() @@ -331,18 +345,10 @@ func (mr *msgReader) setFrame(h header) { func (mr *msgReader) Read(p []byte) (n int, err error) { defer func() { - r := recover() - if r != nil { - if r != "ANMOL" { - panic(r) - } + errd.Wrap(&err, "failed to read") + if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { err = io.EOF - if !mr.flateContextTakeover() { - mr.returnFlateReader() - } } - - errd.Wrap(&err, "failed to read") if xerrors.Is(err, io.EOF) { err = io.EOF } @@ -354,25 +360,23 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } defer mr.c.readMu.Unlock() - return mr.limitReader.Read(p) -} - -func (mr *msgReader) returnFlateReader() { - if mr.flateReader != nil { - putFlateReader(mr.flateReader) - mr.flateReader = nil + n, err = mr.limitReader.Read(p) + if mr.flateContextTakeover() { + p = p[:n] + mr.dict.write(p) } + return n, err } func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { if mr.fin { - if mr.deflate { - if mr.deflateTail.Len() == 0 { - panic("ANMOL") + if mr.flate { + n, err := mr.flateTail.Read(p) + if xerrors.Is(err, io.EOF) { + mr.returnFlateReader() } - n, _ := mr.deflateTail.Read(p) - return n, nil + return n, err } return 0, io.EOF } diff --git a/write.go b/write.go index db47ddb..a7fa5f5 100644 --- a/write.go +++ b/write.go @@ -37,8 +37,8 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // // See the Writer method if you want to stream a message. // -// If compression is disabled, then it is guaranteed to write the message -// in a single frame. +// If compression is disabled or the threshold is not met, then it +// will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { @@ -47,20 +47,38 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { return nil } +type msgWriter struct { + c *Conn + + mu *mu + + ctx context.Context + opcode opcode + closed bool + flate bool + + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer +} + func newMsgWriter(c *Conn) *msgWriter { mw := &msgWriter{ c: c, mu: newMu(c), } - mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), - } return mw } -func (mw *msgWriter) ensureFlateWriter() { +func (mw *msgWriter) ensureFlate() { if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter, nil) + if mw.trimWriter == nil { + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + } + + mw.flateWriter = getFlateWriter(mw.trimWriter) + mw.flate = true } } @@ -85,8 +103,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return 0, err } - if !c.flate() { - // Fast single frame path. + if !c.flate() || len(p) < c.flateThreshold { defer c.msgWriter.mu.Unlock() return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } @@ -100,20 +117,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return n, err } -type msgWriter struct { - c *Conn - - mu *mu - - ctx context.Context - opcode opcode - closed bool - - flate bool - trimWriter *trimLastFourBytesWriter - flateWriter *flate.Writer -} - func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.Lock(ctx) if err != nil { @@ -127,6 +130,13 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { return nil } +func (mw *msgWriter) returnFlateWriter() { + if mw.flateWriter != nil { + putFlateWriter(mw.flateWriter) + mw.flateWriter = nil + } +} + // Write writes the given bytes to the WebSocket connection. func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") @@ -135,16 +145,10 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, xerrors.New("cannot use closed writer") } - if mw.c.flate() { - if !mw.flate { - mw.flate = true - - if !mw.flateContextTakeover() { - mw.ensureFlateWriter() - } - mw.trimWriter.reset() - } - + // TODO can make threshold detection robust across writes by writing to buffer + if mw.flate || + mw.c.flate() && len(p) >= mw.c.flateThreshold { + mw.ensureFlate() return mw.flateWriter.Write(p) } @@ -181,21 +185,16 @@ func (mw *msgWriter) Close() (err error) { return xerrors.Errorf("failed to write fin frame: %w", err) } - if mw.c.flate() && !mw.flateContextTakeover() && mw.flateWriter != nil { - putFlateWriter(mw.flateWriter) - mw.flateWriter = nil + if mw.c.flate() && !mw.flateContextTakeover() { + mw.returnFlateWriter() } - mw.mu.Unlock() return nil } func (mw *msgWriter) close() { - if mw.flateWriter != nil && mw.flateContextTakeover() { - mw.mu.Lock(context.Background()) - putFlateWriter(mw.flateWriter) - mw.flateWriter = nil - } + mw.mu.Lock(context.Background()) + mw.returnFlateWriter() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { -- GitLab