good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit a02cbef5 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

compress.go: Fix context takeover

parent 4e15d756
Branches
Tags
No related merge requests found
...@@ -269,6 +269,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi ...@@ -269,6 +269,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
if strings.HasPrefix(p, "client_max_window_bits") { if strings.HasPrefix(p, "client_max_window_bits") {
// We cannot adjust the read sliding window so cannot make use of this. // We cannot adjust the read sliding window so cannot make use of this.
// By not responding to it, we tell the client we're ignoring it.
continue continue
} }
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
go test --bench=. "$@" ./... go test --run=^$ --bench=. "$@" ./...
( (
cd ./internal/thirdparty cd ./internal/thirdparty
go test --bench=. "$@" ./... go test --run=^$ --bench=. "$@" ./...
) )
...@@ -31,7 +31,7 @@ const ( ...@@ -31,7 +31,7 @@ const (
CompressionDisabled CompressionMode = iota CompressionDisabled CompressionMode = iota
// CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection. // CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection.
// It reusing the sliding window from previous messages. // It reuses the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient. // As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover. // It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover.
// //
...@@ -80,7 +80,7 @@ func (copts *compressionOptions) setHeader(h http.Header) { ...@@ -80,7 +80,7 @@ func (copts *compressionOptions) setHeader(h http.Header) {
// They are removed when sending to avoid the overhead as // They are removed when sending to avoid the overhead as
// WebSocket framing tell's when the message has ended but then // WebSocket framing tell's when the message has ended but then
// we need to add them back otherwise flate.Reader keeps // we need to add them back otherwise flate.Reader keeps
// trying to return more bytes. // trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff" const deflateMessageTail = "\x00\x00\xff\xff"
type trimLastFourBytesWriter struct { type trimLastFourBytesWriter struct {
...@@ -201,23 +201,19 @@ func (sw *slidingWindow) init(n int) { ...@@ -201,23 +201,19 @@ func (sw *slidingWindow) init(n int) {
} }
p := slidingWindowPool(n) p := slidingWindowPool(n)
buf, ok := p.Get().(*[]byte) sw2, ok := p.Get().(*slidingWindow)
if ok { if ok {
sw.buf = (*buf)[:0] *sw = *sw2
} else { } else {
sw.buf = make([]byte, 0, n) sw.buf = make([]byte, 0, n)
} }
} }
func (sw *slidingWindow) close() { func (sw *slidingWindow) close() {
if sw.buf == nil { sw.buf = sw.buf[:0]
return
}
swPoolMu.Lock() swPoolMu.Lock()
swPool[cap(sw.buf)].Put(&sw.buf) swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock() swPoolMu.Unlock()
sw.buf = nil
} }
func (sw *slidingWindow) write(p []byte) { func (sw *slidingWindow) write(p []byte) {
......
...@@ -292,4 +292,5 @@ func (m *mu) unlock() { ...@@ -292,4 +292,5 @@ func (m *mu) unlock() {
} }
type noCopy struct{} type noCopy struct{}
func (*noCopy) Lock() {} func (*noCopy) Lock() {}
...@@ -458,7 +458,7 @@ func BenchmarkConn(b *testing.B) { ...@@ -458,7 +458,7 @@ func BenchmarkConn(b *testing.B) {
typ, r, err := c1.Reader(bb.ctx) typ, r, err := c1.Reader(bb.ctx)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(i, err)
} }
if websocket.MessageText != typ { if websocket.MessageText != typ {
assert.Equal(b, "data type", websocket.MessageText, typ) assert.Equal(b, "data type", websocket.MessageText, typ)
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"time" "time"
"nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/util"
) )
func TestBadDials(t *testing.T) { func TestBadDials(t *testing.T) {
...@@ -27,7 +28,7 @@ func TestBadDials(t *testing.T) { ...@@ -27,7 +28,7 @@ func TestBadDials(t *testing.T) {
name string name string
url string url string
opts *DialOptions opts *DialOptions
rand readerFunc rand util.ReaderFunc
nilCtx bool nilCtx bool
}{ }{
{ {
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
package websocket package websocket
import "nhooyr.io/websocket/internal/util"
func (c *Conn) RecordBytesWritten() *int { func (c *Conn) RecordBytesWritten() *int {
var bytesWritten int var bytesWritten int
c.bw.Reset(writerFunc(func(p []byte) (int, error) { c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) {
bytesWritten += len(p) bytesWritten += len(p)
return c.rwc.Write(p) return c.rwc.Write(p)
})) }))
...@@ -14,7 +16,7 @@ func (c *Conn) RecordBytesWritten() *int { ...@@ -14,7 +16,7 @@ func (c *Conn) RecordBytesWritten() *int {
func (c *Conn) RecordBytesRead() *int { func (c *Conn) RecordBytesRead() *int {
var bytesRead int var bytesRead int
c.br.Reset(readerFunc(func(p []byte) (int, error) { c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) {
n, err := c.rwc.Read(p) n, err := c.rwc.Read(p)
bytesRead += n bytesRead += n
return n, err return n, err
......
...@@ -6,3 +6,10 @@ type WriterFunc func(p []byte) (int, error) ...@@ -6,3 +6,10 @@ type WriterFunc func(p []byte) (int, error)
func (f WriterFunc) Write(p []byte) (int, error) { func (f WriterFunc) Write(p []byte) (int, error) {
return f(p) return f(p)
} }
// ReaderFunc is used to implement one off io.Readers.
type ReaderFunc func(p []byte) (int, error)
func (f ReaderFunc) Read(p []byte) (int, error) {
return f(p)
}
...@@ -2,6 +2,7 @@ package xsync ...@@ -2,6 +2,7 @@ package xsync
import ( import (
"fmt" "fmt"
"runtime/debug"
) )
// Go allows running a function in another goroutine // Go allows running a function in another goroutine
...@@ -13,7 +14,7 @@ func Go(fn func() error) <-chan error { ...@@ -13,7 +14,7 @@ func Go(fn func() error) <-chan error {
r := recover() r := recover()
if r != nil { if r != nil {
select { select {
case errs <- fmt.Errorf("panic in go fn: %v", r): case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()):
default: default:
} }
} }
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"time" "time"
"nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/util"
"nhooyr.io/websocket/internal/xsync" "nhooyr.io/websocket/internal/xsync"
) )
...@@ -101,13 +102,20 @@ func newMsgReader(c *Conn) *msgReader { ...@@ -101,13 +102,20 @@ func newMsgReader(c *Conn) *msgReader {
func (mr *msgReader) resetFlate() { func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() { if mr.flateContextTakeover() {
if mr.dict == nil {
mr.dict = &slidingWindow{}
}
mr.dict.init(32768) mr.dict.init(32768)
} }
if mr.flateBufio == nil { if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc) mr.flateBufio = getBufioReader(mr.readFunc)
} }
if mr.flateContextTakeover() {
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
} else {
mr.flateReader = getFlateReader(mr.flateBufio, nil)
}
mr.limitReader.r = mr.flateReader mr.limitReader.r = mr.flateReader
mr.flateTail.Reset(deflateMessageTail) mr.flateTail.Reset(deflateMessageTail)
} }
...@@ -122,7 +130,10 @@ func (mr *msgReader) putFlateReader() { ...@@ -122,7 +130,10 @@ func (mr *msgReader) putFlateReader() {
func (mr *msgReader) close() { func (mr *msgReader) close() {
mr.c.readMu.forceLock() mr.c.readMu.forceLock()
mr.putFlateReader() mr.putFlateReader()
if mr.dict != nil {
mr.dict.close() mr.dict.close()
mr.dict = nil
}
if mr.flateBufio != nil { if mr.flateBufio != nil {
putBufioReader(mr.flateBufio) putBufioReader(mr.flateBufio)
} }
...@@ -348,14 +359,14 @@ type msgReader struct { ...@@ -348,14 +359,14 @@ type msgReader struct {
flateBufio *bufio.Reader flateBufio *bufio.Reader
flateTail strings.Reader flateTail strings.Reader
limitReader *limitReader limitReader *limitReader
dict slidingWindow dict *slidingWindow
fin bool fin bool
payloadLength int64 payloadLength int64
maskKey uint32 maskKey uint32
// readerFunc(mr.Read) to avoid continuous allocations. // util.ReaderFunc(mr.Read) to avoid continuous allocations.
readFunc readerFunc readFunc util.ReaderFunc
} }
func (mr *msgReader) reset(ctx context.Context, h header) { func (mr *msgReader) reset(ctx context.Context, h header) {
...@@ -484,9 +495,3 @@ func (lr *limitReader) Read(p []byte) (int, error) { ...@@ -484,9 +495,3 @@ func (lr *limitReader) Read(p []byte) (int, error) {
} }
return n, err return n, err
} }
type readerFunc func(p []byte) (int, error)
func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"compress/flate" "compress/flate"
"nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/util"
) )
// Writer returns a writer bounded by the context that will write // Writer returns a writer bounded by the context that will write
...@@ -93,7 +94,7 @@ func newMsgWriterState(c *Conn) *msgWriterState { ...@@ -93,7 +94,7 @@ func newMsgWriterState(c *Conn) *msgWriterState {
func (mw *msgWriterState) ensureFlate() { func (mw *msgWriterState) ensureFlate() {
if mw.trimWriter == nil { if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{ mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write), w: util.WriterFunc(mw.write),
} }
} }
...@@ -380,17 +381,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { ...@@ -380,17 +381,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
return n, nil return n, nil
} }
type writerFunc func(p []byte) (int, error)
func (f writerFunc) Write(p []byte) (int, error) {
return f(p)
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it. // and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte var writeBuf []byte
bw.Reset(writerFunc(func(p2 []byte) (int, error) { bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)] writeBuf = p2[:cap(p2)]
return len(p2), nil return len(p2), nil
})) }))
......
...@@ -566,4 +566,5 @@ func (m *mu) unlock() { ...@@ -566,4 +566,5 @@ func (m *mu) unlock() {
} }
type noCopy struct{} type noCopy struct{}
func (*noCopy) Lock() {} func (*noCopy) Lock() {}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment