diff --git a/ci/test.mk b/ci/test.mk index 3fc34bbf6ce0ff0b5d44d492ae542d66acac916b..3d1f0ed12ac7029d250b0c93953f4a8088ea540d 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -1,4 +1,4 @@ -test: gotest ci/out/coverage.html +test: ci/out/coverage.html ifdef CI test: coveralls endif diff --git a/compress_notjs.go b/compress_notjs.go index a61b7ba472dc85dd98b9a36ccecb2886945e1068..a69110567bcf568e6503ac44e65e0216aba55daf 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -108,22 +108,6 @@ func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } -var flateWriterPool sync.Pool - -func getFlateWriter(w io.Writer) *flate.Writer { - fw, ok := flateWriterPool.Get().(*flate.Writer) - if !ok { - fw, _ = flate.NewWriter(w, flate.BestSpeed) - return fw - } - fw.Reset(w) - return fw -} - -func putFlateWriter(w *flate.Writer) { - flateWriterPool.Put(w) -} - type slidingWindow struct { buf []byte } diff --git a/conn_test.go b/conn_test.go index 398ffd5181e50b5519b977bbd739162ad714d77b..3b7fcdb5110b8375302e18860d2610e586551c6a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -402,6 +402,9 @@ func BenchmarkConn(b *testing.B) { bb.goEchoLoop(c2) + bytesWritten := c1.RecordBytesWritten() + bytesRead := c1.RecordBytesRead() + msg := []byte(strings.Repeat("1234", 128)) readBuf := make([]byte, len(msg)) writes := make(chan struct{}) @@ -451,6 +454,9 @@ func BenchmarkConn(b *testing.B) { } b.StopTimer() + b.ReportMetric(float64(*bytesWritten/b.N), "written/op") + b.ReportMetric(float64(*bytesRead/b.N), "read/op") + err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(b, err) }) diff --git a/export_test.go b/export_test.go new file mode 100644 index 0000000000000000000000000000000000000000..88b82c9f3eaae099ee43398c279b1778387e5e55 --- /dev/null +++ b/export_test.go @@ -0,0 +1,22 @@ +// +build !js + +package websocket + +func (c *Conn) RecordBytesWritten() *int { + var bytesWritten int + c.bw.Reset(writerFunc(func(p []byte) (int, error) { + bytesWritten += len(p) + return c.rwc.Write(p) + })) + return &bytesWritten +} + +func (c *Conn) RecordBytesRead() *int { + var bytesRead int + c.br.Reset(readerFunc(func(p []byte) (int, error) { + n, err := c.rwc.Read(p) + bytesRead += n + return n, err + })) + return &bytesRead +} diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go index 2bc01dbac9039234b78b8626637e26d745f85d18..602b887e87d574a02987ada1c130f8ef288148a2 100644 --- a/internal/test/assert/assert.go +++ b/internal/test/assert/assert.go @@ -39,8 +39,8 @@ func Error(t testing.TB, err error) { func Contains(t testing.TB, v interface{}, sub string) { t.Helper() - vstr := fmt.Sprint(v) - if !strings.Contains(vstr, sub) { - t.Fatalf("expected %q to contain %q", vstr, sub) + s := fmt.Sprint(v) + if !strings.Contains(s, sub) { + t.Fatalf("expected %q to contain %q", s, sub) } } diff --git a/write.go b/write.go index ec3b7d059d4d31bc0597d0f9a5ca42d3a2d654f0..b560b44cf25f79f999528b76a01b065ca50379e1 100644 --- a/write.go +++ b/write.go @@ -12,7 +12,6 @@ import ( "time" "github.com/klauspost/compress/flate" - kflate "github.com/klauspost/compress/flate" "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" @@ -77,9 +76,8 @@ type msgWriterState struct { opcode opcode flate bool - trimWriter *trimLastFourBytesWriter - flateWriter *flate.Writer - dict slidingWindow + trimWriter *trimLastFourBytesWriter + dict slidingWindow } func newMsgWriterState(c *Conn) *msgWriterState { @@ -90,8 +88,6 @@ func newMsgWriterState(c *Conn) *msgWriterState { return mw } -const stateless = true - func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ @@ -99,14 +95,7 @@ func (mw *msgWriterState) ensureFlate() { } } - if stateless { - mw.dict.init(8192) - } else { - if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) - } - } - + mw.dict.init(8192) mw.flate = true } @@ -163,13 +152,6 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { return nil } -func (mw *msgWriterState) putFlateWriter() { - if mw.flateWriter != nil { - putFlateWriter(mw.flateWriter) - mw.flateWriter = nil - } -} - // Write writes the given bytes to the WebSocket connection. func (mw *msgWriterState) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") @@ -186,15 +168,12 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { } if mw.flate { - if stateless { - err = kflate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) - if err != nil { - return 0, err - } - mw.dict.write(p) - return len(p), nil + err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) + if err != nil { + return 0, err } - return mw.flateWriter.Write(p) + mw.dict.write(p) + return len(p), nil } return mw.write(p) @@ -216,13 +195,6 @@ func (mw *msgWriterState) Close() (err error) { mw.writeMu.Lock() defer mw.writeMu.Unlock() - if mw.flate && !stateless { - err = mw.flateWriter.Flush() - if err != nil { - return xerrors.Errorf("failed to flush flate: %w", err) - } - } - _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) @@ -230,7 +202,6 @@ func (mw *msgWriterState) Close() (err error) { if mw.flate && !mw.flateContextTakeover() { mw.dict.close() - mw.putFlateWriter() } mw.mu.Unlock() return nil @@ -238,7 +209,6 @@ func (mw *msgWriterState) Close() (err error) { func (mw *msgWriterState) close() { mw.writeMu.Lock() - mw.putFlateWriter() mw.dict.close() } @@ -311,14 +281,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return n, nil } -func (c *Conn) writeFramePayload(p []byte) (_ int, err error) { +func (c *Conn) writeFramePayload(p []byte) (n int, err error) { defer errd.Wrap(&err, "failed to write frame payload") if !c.writeHeader.masked { return c.bw.Write(p) } - var n int maskKey := c.writeHeader.maskKey for len(p) > 0 { // If the buffer is full, we need to flush.