diff --git a/compress_notjs.go b/compress_notjs.go index 7c6b2fc013041efec8b7353de38001c2b8c59cf6..a61b7ba472dc85dd98b9a36ccecb2886945e1068 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -3,10 +3,11 @@ package websocket import ( - "compress/flate" "io" "net/http" "sync" + + "github.com/klauspost/compress/flate" ) func (m CompressionMode) opts() *compressionOptions { @@ -45,10 +46,16 @@ type trimLastFourBytesWriter struct { } func (tw *trimLastFourBytesWriter) reset() { - tw.tail = tw.tail[:0] + if tw != nil && tw.tail != nil { + tw.tail = tw.tail[:0] + } } func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + if tw.tail == nil { + tw.tail = make([]byte, 0, 4) + } + extra := len(tw.tail) + len(p) - 4 if extra <= 0 { @@ -65,7 +72,10 @@ func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { if err != nil { return 0, err } - tw.tail = tw.tail[extra:] + + // Shift remaining bytes in tail over. + n := copy(tw.tail, tw.tail[extra:]) + tw.tail = tw.tail[:n] } // If p is less than or equal to 4 bytes, @@ -118,22 +128,32 @@ type slidingWindow struct { buf []byte } -var swPoolMu sync.Mutex +var swPoolMu sync.RWMutex var swPool = map[int]*sync.Pool{} -func (sw *slidingWindow) init(n int) { - if sw.buf != nil { - return +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p } + p = &sync.Pool{} + swPoolMu.Lock() - defer swPoolMu.Unlock() + swPool[n] = p + swPoolMu.Unlock() - p, ok := swPool[n] - if !ok { - p = &sync.Pool{} - swPool[n] = p + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return } + + p := slidingWindowPool(n) buf, ok := p.Get().([]byte) if ok { sw.buf = buf[:0] diff --git a/conn_notjs.go b/conn_notjs.go index 178fcad02a61574b3e1a519ddaaebd39742854df..e6ff7df362d08e400ca26f1225a4bada8a7a56eb 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -39,16 +39,17 @@ type Conn struct { // Read state. readMu *mu - readHeader header + readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error // Write state. - msgWriter *msgWriter - writeFrameMu *mu - writeBuf []byte - writeHeader header + msgWriterState *msgWriterState + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header closed chan struct{} closeMu sync.Mutex @@ -94,14 +95,14 @@ func newConn(cfg connConfig) *Conn { c.msgReader = newMsgReader(c) - c.msgWriter = newMsgWriter(c) + c.msgWriterState = newMsgWriterState(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 256 - if !c.msgWriter.flateContextTakeover() { + if !c.msgWriterState.flateContextTakeover() { c.flateThreshold = 512 } } @@ -142,7 +143,7 @@ func (c *Conn) close(err error) { c.writeFrameMu.Lock(context.Background()) putBufioWriter(c.bw) } - c.msgWriter.close() + c.msgWriterState.close() c.msgReader.close() if c.client { diff --git a/conn_test.go b/conn_test.go index 265156e970cd1cf0c753706580c532f98d8b19f1..398ffd5181e50b5519b977bbd739162ad714d77b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,7 +5,6 @@ package websocket_test import ( "bytes" "context" - "crypto/rand" "fmt" "io" "io/ioutil" @@ -13,6 +12,7 @@ import ( "net/http/httptest" "os" "os/exec" + "strings" "sync" "testing" "time" @@ -379,15 +379,15 @@ func BenchmarkConn(b *testing.B) { mode websocket.CompressionMode }{ { - name: "compressionDisabled", + name: "disabledCompress", mode: websocket.CompressionDisabled, }, { - name: "compression", + name: "compress", mode: websocket.CompressionContextTakeover, }, { - name: "noContextCompression", + name: "compressNoContext", mode: websocket.CompressionNoContextTakeover, }, } @@ -395,44 +395,36 @@ func BenchmarkConn(b *testing.B) { b.Run(bc.name, func(b *testing.B) { bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, - }, nil) + }, &websocket.AcceptOptions{ + CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, + }) defer bb.cleanup() bb.goEchoLoop(c2) - const n = 32768 - writeBuf := make([]byte, n) - readBuf := make([]byte, n) - writes := make(chan websocket.MessageType) + msg := []byte(strings.Repeat("1234", 128)) + readBuf := make([]byte, len(msg)) + writes := make(chan struct{}) defer close(writes) werrs := make(chan error) go func() { - for typ := range writes { - werrs <- c1.Write(bb.ctx, typ, writeBuf) + for range writes { + werrs <- c1.Write(bb.ctx, websocket.MessageText, msg) } }() - b.SetBytes(n) + b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := rand.Reader.Read(writeBuf) - if err != nil { - b.Fatal(err) - } - - expType := websocket.MessageBinary - if writeBuf[0]%2 == 1 { - expType = websocket.MessageText - } - writes <- expType + writes <- struct{}{} typ, r, err := c1.Reader(bb.ctx) if err != nil { b.Fatal(err) } - if expType != typ { - assert.Equal(b, "data type", expType, typ) + if websocket.MessageText != typ { + assert.Equal(b, "data type", websocket.MessageText, typ) } _, err = io.ReadFull(r, readBuf) @@ -448,8 +440,8 @@ func BenchmarkConn(b *testing.B) { assert.Equal(b, "n2", 0, n2) } - if !bytes.Equal(writeBuf, readBuf) { - assert.Equal(b, "msg", writeBuf, readBuf) + if !bytes.Equal(msg, readBuf) { + assert.Equal(b, "msg", msg, readBuf) } err = <-werrs @@ -464,3 +456,8 @@ func BenchmarkConn(b *testing.B) { }) } } + +func TestCompression(t *testing.T) { + t.Parallel() + +} diff --git a/frame.go b/frame.go index 491ae75c33c5bf6ed4f0274191fa8a4b1ab20cff..4acaecf43ff6c9a8eedad12b0e5a52a29af14ae5 100644 --- a/frame.go +++ b/frame.go @@ -3,9 +3,12 @@ package websocket import ( "bufio" "encoding/binary" + "io" "math" "math/bits" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/errd" ) @@ -46,12 +49,12 @@ type header struct { // readFrameHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2. -func readFrameHeader(h *header, r *bufio.Reader) (err error) { +func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { - return err + return header{}, err } h.fin = b&(1<<7) != 0 @@ -63,7 +66,7 @@ func readFrameHeader(h *header, r *bufio.Reader) (err error) { b, err = r.ReadByte() if err != nil { - return err + return header{}, err } h.masked = b&(1<<7) != 0 @@ -73,24 +76,29 @@ func readFrameHeader(h *header, r *bufio.Reader) (err error) { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: - var pl uint16 - err = binary.Read(r, binary.BigEndian, &pl) - h.payloadLength = int64(pl) + _, err = io.ReadFull(r, readBuf[:2]) + h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) case payloadLength == 127: - err = binary.Read(r, binary.BigEndian, &h.payloadLength) + _, err = io.ReadFull(r, readBuf) + h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) } if err != nil { - return err + return header{}, err + } + + if h.payloadLength < 0 { + return header{}, xerrors.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { - err = binary.Read(r, binary.LittleEndian, &h.maskKey) + _, err = io.ReadFull(r, readBuf[:4]) if err != nil { - return err + return header{}, err } + h.maskKey = binary.LittleEndian.Uint32(readBuf) } - return nil + return h, nil } // maxControlPayload is the maximum length of a control frame payload. @@ -99,7 +107,7 @@ const maxControlPayload = 125 // writeFrameHeader writes the bytes of the header to w. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func writeFrameHeader(h header, w *bufio.Writer) (err error) { +func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { defer errd.Wrap(&err, "failed to write frame header") var b byte @@ -143,16 +151,19 @@ func writeFrameHeader(h header, w *bufio.Writer) (err error) { switch { case h.payloadLength > math.MaxUint16: - err = binary.Write(w, binary.BigEndian, h.payloadLength) + binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) + _, err = w.Write(buf) case h.payloadLength > 125: - err = binary.Write(w, binary.BigEndian, uint16(h.payloadLength)) + binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) + _, err = w.Write(buf[:2]) } if err != nil { return err } if h.masked { - err = binary.Write(w, binary.LittleEndian, h.maskKey) + binary.LittleEndian.PutUint32(buf, h.maskKey) + _, err = w.Write(buf[:4]) if err != nil { return err } diff --git a/frame_test.go b/frame_test.go index 38f1599a890cb52bda60a5d194c15e1449e5dbac..76826248d040e4696e0603b4d28f6a0159b7e363 100644 --- a/frame_test.go +++ b/frame_test.go @@ -80,14 +80,13 @@ func testHeader(t *testing.T, h header) { w := bufio.NewWriter(b) r := bufio.NewReader(b) - err := writeFrameHeader(h, w) + err := writeFrameHeader(h, w, make([]byte, 8)) assert.Success(t, err) err = w.Flush() assert.Success(t, err) - var h2 header - err = readFrameHeader(&h2, r) + h2, err := readFrameHeader(r, make([]byte, 8)) assert.Success(t, err) assert.Equal(t, "read header", h, h2) diff --git a/go.mod b/go.mod index cb37239165d515dd37e5d3d08dcc4c31d6e8d44c..a10c7b1e3e53dde31226a820586acdb7f6624fa8 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/golang/protobuf v1.3.3 github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 + github.com/klauspost/compress v1.10.0 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/go.sum b/go.sum index 8cbc66ce14e6244c32295b76f4df68b7306a99fc..e4bbd62d337c4edbd4e71bd740271cf2544a0467 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= +github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/internal/xsync/go.go b/internal/xsync/go.go index d88ac622c5ddd0ebb202a1c5759aa098a14a3c1c..712739aa2fa5b6484629c47a5aa41c1a80f1bd7b 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -6,7 +6,7 @@ import ( // Go allows running a function in another goroutine // and waiting for its error. -func Go(fn func() error) <- chan error { +func Go(fn func() error) <-chan error { errs := make(chan error, 1) go func() { defer func() { diff --git a/read.go b/read.go index bf7fa6d928835a98b50a43e3f2765406e04e1201..bbad30d14e59f43be73cab8176b9c0cdd14f6216 100644 --- a/read.go +++ b/read.go @@ -3,6 +3,7 @@ package websocket import ( + "bufio" "context" "io" "io/ioutil" @@ -81,8 +82,9 @@ func newMsgReader(c *Conn) *msgReader { c: c, fin: true, } + mr.readFunc = mr.read - mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit+1) + mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) return mr } @@ -90,13 +92,16 @@ func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { mr.dict.init(32768) } + if mr.flateBufio == nil { + mr.flateBufio = getBufioReader(mr.readFunc) + } - mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } -func (mr *msgReader) returnFlateReader() { +func (mr *msgReader) putFlateReader() { if mr.flateReader != nil { putFlateReader(mr.flateReader) mr.flateReader = nil @@ -105,9 +110,11 @@ func (mr *msgReader) returnFlateReader() { func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) - mr.returnFlateReader() - + mr.putFlateReader() mr.dict.close() + if mr.flateBufio != nil { + putBufioReader(mr.flateBufio) + } } func (mr *msgReader) flateContextTakeover() bool { @@ -173,7 +180,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- ctx: } - err := readFrameHeader(&c.readHeader, c.br) + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { select { case <-c.closed: @@ -192,7 +199,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- context.Background(): } - return c.readHeader, nil + return h, nil } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { @@ -317,6 +324,7 @@ type msgReader struct { ctx context.Context flate bool flateReader io.Reader + flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader dict slidingWindow @@ -324,12 +332,15 @@ type msgReader struct { fin bool payloadLength int64 maskKey uint32 + + // readerFunc(mr.Read) to avoid continuous allocations. + readFunc readerFunc } func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.flate = h.rsv1 - mr.limitReader.reset(readerFunc(mr.read)) + mr.limitReader.reset(mr.readFunc) if mr.flate { mr.resetFlate() @@ -346,15 +357,15 @@ func (mr *msgReader) setFrame(h header) { func (mr *msgReader) Read(p []byte) (n int, err error) { defer func() { - errd.Wrap(&err, "failed to read") if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { err = io.EOF } if xerrors.Is(err, io.EOF) { err = io.EOF - - mr.returnFlateReader() + mr.putFlateReader() + return } + errd.Wrap(&err, "failed to read") }() err = mr.c.readMu.Lock(mr.ctx) @@ -372,44 +383,46 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } func (mr *msgReader) read(p []byte) (int, error) { - if mr.payloadLength == 0 { - if mr.fin { - if mr.flate { - return mr.flateTail.Read(p) + for { + if mr.payloadLength == 0 { + if mr.fin { + if mr.flate { + return mr.flateTail.Read(p) + } + return 0, io.EOF } - return 0, io.EOF - } - h, err := mr.c.readLoop(mr.ctx) - if err != nil { - return 0, err - } - if h.opcode != opContinuation { - err := xerrors.New("received new data message without finishing the previous message") - mr.c.writeError(StatusProtocolError, err) - return 0, err + h, err := mr.c.readLoop(mr.ctx) + if err != nil { + return 0, err + } + if h.opcode != opContinuation { + err := xerrors.New("received new data message without finishing the previous message") + mr.c.writeError(StatusProtocolError, err) + return 0, err + } + mr.setFrame(h) + + continue } - mr.setFrame(h) - return mr.read(p) - } + if int64(len(p)) > mr.payloadLength { + p = p[:mr.payloadLength] + } - if int64(len(p)) > mr.payloadLength { - p = p[:mr.payloadLength] - } + n, err := mr.c.readFramePayload(mr.ctx, p) + if err != nil { + return n, err + } - n, err := mr.c.readFramePayload(mr.ctx, p) - if err != nil { - return n, err - } + mr.payloadLength -= int64(n) - mr.payloadLength -= int64(n) + if !mr.c.client { + mr.maskKey = mask(mr.maskKey, p) + } - if !mr.c.client { - mr.maskKey = mask(mr.maskKey, p) + return n, nil } - - return n, nil } type limitReader struct { diff --git a/write.go b/write.go index 9d4b670f84985570b2334cd90be5169771ca7289..ec3b7d059d4d31bc0597d0f9a5ca42d3a2d654f0 100644 --- a/write.go +++ b/write.go @@ -4,7 +4,6 @@ package websocket import ( "bufio" - "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -12,6 +11,8 @@ import ( "sync" "time" + "github.com/klauspost/compress/flate" + kflate "github.com/klauspost/compress/flate" "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" @@ -24,8 +25,6 @@ import ( // // Only one writer can be open at a time, multiple calls will block until the previous writer // is closed. -// -// Never close the returned writer twice. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { @@ -49,6 +48,26 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { } type msgWriter struct { + mw *msgWriterState + closed bool +} + +func (mw *msgWriter) Write(p []byte) (int, error) { + if mw.closed { + return 0, xerrors.New("cannot use closed writer") + } + return mw.mw.Write(p) +} + +func (mw *msgWriter) Close() error { + if mw.closed { + return xerrors.New("cannot use closed writer") + } + mw.closed = true + return mw.mw.Close() +} + +type msgWriterState struct { c *Conn mu *mu @@ -56,36 +75,42 @@ type msgWriter struct { ctx context.Context opcode opcode - closed bool flate bool trimWriter *trimLastFourBytesWriter flateWriter *flate.Writer + dict slidingWindow } -func newMsgWriter(c *Conn) *msgWriter { - mw := &msgWriter{ +func newMsgWriterState(c *Conn) *msgWriterState { + mw := &msgWriterState{ c: c, mu: newMu(c), } return mw } -func (mw *msgWriter) ensureFlate() { +const stateless = true + +func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } } - if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) + if stateless { + mw.dict.init(8192) + } else { + if mw.flateWriter == nil { + mw.flateWriter = getFlateWriter(mw.trimWriter) + } } mw.flate = true } -func (mw *msgWriter) flateContextTakeover() bool { +func (mw *msgWriterState) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } @@ -93,11 +118,14 @@ func (mw *msgWriter) flateContextTakeover() bool { } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.msgWriter.reset(ctx, typ) + err := c.msgWriterState.reset(ctx, typ) if err != nil { return nil, err } - return c.msgWriter, nil + return &msgWriter{ + mw: c.msgWriterState, + closed: false, + }, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { @@ -107,8 +135,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error } if !c.flate() { - defer c.msgWriter.mu.Unlock() - return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) + defer c.msgWriterState.mu.Unlock() + return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) } n, err := mw.Write(p) @@ -120,25 +148,22 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return n, err } -func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { +func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { err := mw.mu.Lock(ctx) if err != nil { return err } - mw.closed = false mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false - if mw.trimWriter != nil { - mw.trimWriter.reset() - } + mw.trimWriter.reset() return nil } -func (mw *msgWriter) returnFlateWriter() { +func (mw *msgWriterState) putFlateWriter() { if mw.flateWriter != nil { putFlateWriter(mw.flateWriter) mw.flateWriter = nil @@ -146,16 +171,12 @@ func (mw *msgWriter) returnFlateWriter() { } // Write writes the given bytes to the WebSocket connection. -func (mw *msgWriter) Write(p []byte) (_ int, err error) { +func (mw *msgWriterState) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") mw.writeMu.Lock() defer mw.writeMu.Unlock() - if mw.closed { - return 0, xerrors.New("cannot use closed writer") - } - if mw.c.flate() { // Only enables flate if the length crosses the // threshold on the first frame @@ -165,13 +186,21 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { } if mw.flate { + if stateless { + err = kflate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) + if err != nil { + return 0, err + } + mw.dict.write(p) + return len(p), nil + } return mw.flateWriter.Write(p) } return mw.write(p) } -func (mw *msgWriter) write(p []byte) (int, error) { +func (mw *msgWriterState) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, xerrors.Errorf("failed to write data frame: %w", err) @@ -181,42 +210,36 @@ func (mw *msgWriter) write(p []byte) (int, error) { } // Close flushes the frame to the connection. -func (mw *msgWriter) Close() (err error) { +func (mw *msgWriterState) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") mw.writeMu.Lock() defer mw.writeMu.Unlock() - if mw.closed { - return xerrors.New("cannot use closed writer") - } - - if mw.flate { + if mw.flate && !stateless { err = mw.flateWriter.Flush() if err != nil { - return xerrors.Errorf("failed to flush flate writer: %w", err) + return xerrors.Errorf("failed to flush flate: %w", err) } } - // We set closed after flushing the flate writer to ensure Write - // can succeed. - mw.closed = true - _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { - mw.returnFlateWriter() + mw.dict.close() + mw.putFlateWriter() } mw.mu.Unlock() return nil } -func (mw *msgWriter) close() { +func (mw *msgWriterState) close() { mw.writeMu.Lock() - mw.returnFlateWriter() + mw.putFlateWriter() + mw.dict.close() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { @@ -250,10 +273,11 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if c.client { c.writeHeader.masked = true - err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) + _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) if err != nil { return 0, xerrors.Errorf("failed to generate masking key: %w", err) } + c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) } c.writeHeader.rsv1 = false @@ -261,7 +285,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco c.writeHeader.rsv1 = true } - err = writeFrameHeader(c.writeHeader, c.bw) + err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) if err != nil { return 0, err }