diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6961e5c894a84fe051b8aa56f11e25a80de13f2f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +websocket.test diff --git a/autobahn_test.go b/autobahn_test.go index 0763bc9713e50046061605bd53b6493c18b6788d..fb24a06bb9e4058db4103d20a72100b39c1d265d 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -59,8 +59,6 @@ func TestAutobahn(t *testing.T) { for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() diff --git a/compress_notjs.go b/compress_notjs.go index 2076136241250f3d8328669f91bce0e00269b5d3..7c6b2fc013041efec8b7353de38001c2b8c59cf6 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -118,6 +118,7 @@ type slidingWindow struct { buf []byte } +var swPoolMu sync.Mutex var swPool = map[int]*sync.Pool{} func (sw *slidingWindow) init(n int) { @@ -125,6 +126,9 @@ func (sw *slidingWindow) init(n int) { return } + swPoolMu.Lock() + defer swPoolMu.Unlock() + p, ok := swPool[n] if !ok { p = &sync.Pool{} @@ -143,6 +147,9 @@ func (sw *slidingWindow) close() { return } + swPoolMu.Lock() + defer swPoolMu.Unlock() + swPool[cap(sw.buf)].Put(sw.buf) sw.buf = nil } diff --git a/conn_notjs.go b/conn_notjs.go index 4d8762bfb7240a5c9bb208caceace344cad43ac6..178fcad02a61574b3e1a519ddaaebd39742854df 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -39,6 +39,7 @@ type Conn struct { // Read state. readMu *mu + readHeader header readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error diff --git a/conn_test.go b/conn_test.go index 25b0809d3d0fa35a4b18d13009b85fd0b678def8..265156e970cd1cf0c753706580c532f98d8b19f1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,9 @@ package websocket_test import ( + "bytes" "context" + "crypto/rand" "fmt" "io" "io/ioutil" @@ -48,7 +50,7 @@ func TestConn(t *testing.T) { }, &websocket.AcceptOptions{ CompressionOptions: copts(), }) - defer tt.done() + defer tt.cleanup() tt.goEchoLoop(c2) @@ -67,7 +69,7 @@ func TestConn(t *testing.T) { t.Run("badClose", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") @@ -75,7 +77,7 @@ func TestConn(t *testing.T) { t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) @@ -91,7 +93,7 @@ func TestConn(t *testing.T) { t.Run("badPing", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() c2.CloseRead(tt.ctx) @@ -104,7 +106,7 @@ func TestConn(t *testing.T) { t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() tt.goDiscardLoop(c2) @@ -129,7 +131,7 @@ func TestConn(t *testing.T) { t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) @@ -143,7 +145,7 @@ func TestConn(t *testing.T) { t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) @@ -179,7 +181,7 @@ func TestConn(t *testing.T) { t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) @@ -201,7 +203,7 @@ func TestConn(t *testing.T) { t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() tt.goEchoLoop(c2) @@ -227,7 +229,7 @@ func TestConn(t *testing.T) { t.Run("wspb", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() tt.goEchoLoop(c2) @@ -297,14 +299,16 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { } type connTest struct { - t *testing.T + t testing.TB ctx context.Context doneFuncs []func() } -func newConnTest(t *testing.T, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { - t.Parallel() +func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { + if t, ok := t.(*testing.T); ok { + t.Parallel() + } t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) @@ -325,7 +329,7 @@ func (tt *connTest) appendDone(f func()) { tt.doneFuncs = append(tt.doneFuncs, f) } -func (tt *connTest) done() { +func (tt *connTest) cleanup() { for i := len(tt.doneFuncs) - 1; i >= 0; i-- { tt.doneFuncs[i]() } @@ -368,3 +372,95 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } }) } + +func BenchmarkConn(b *testing.B) { + var benchCases = []struct { + name string + mode websocket.CompressionMode + }{ + { + name: "compressionDisabled", + mode: websocket.CompressionDisabled, + }, + { + name: "compression", + mode: websocket.CompressionContextTakeover, + }, + { + name: "noContextCompression", + mode: websocket.CompressionNoContextTakeover, + }, + } + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ + CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, + }, nil) + defer bb.cleanup() + + bb.goEchoLoop(c2) + + const n = 32768 + writeBuf := make([]byte, n) + readBuf := make([]byte, n) + writes := make(chan websocket.MessageType) + defer close(writes) + werrs := make(chan error) + + go func() { + for typ := range writes { + werrs <- c1.Write(bb.ctx, typ, writeBuf) + } + }() + b.SetBytes(n) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rand.Reader.Read(writeBuf) + if err != nil { + b.Fatal(err) + } + + expType := websocket.MessageBinary + if writeBuf[0]%2 == 1 { + expType = websocket.MessageText + } + writes <- expType + + typ, r, err := c1.Reader(bb.ctx) + if err != nil { + b.Fatal(err) + } + if expType != typ { + assert.Equal(b, "data type", expType, typ) + } + + _, err = io.ReadFull(r, readBuf) + if err != nil { + b.Fatal(err) + } + + n2, err := r.Read(readBuf) + if err != io.EOF { + assert.Equal(b, "read err", io.EOF, err) + } + if n2 != 0 { + assert.Equal(b, "n2", 0, n2) + } + + if !bytes.Equal(writeBuf, readBuf) { + assert.Equal(b, "msg", writeBuf, readBuf) + } + + err = <-werrs + if err != nil { + b.Fatal(err) + } + } + b.StopTimer() + + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(b, err) + }) + } +} diff --git a/frame.go b/frame.go index 0257835e3bc959df178fd309fc40881285d7df55..491ae75c33c5bf6ed4f0274191fa8a4b1ab20cff 100644 --- a/frame.go +++ b/frame.go @@ -46,15 +46,14 @@ type header struct { // readFrameHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2. -func readFrameHeader(r *bufio.Reader) (_ header, err error) { +func readFrameHeader(h *header, r *bufio.Reader) (err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { - return header{}, err + return err } - var h header h.fin = b&(1<<7) != 0 h.rsv1 = b&(1<<6) != 0 h.rsv2 = b&(1<<5) != 0 @@ -64,7 +63,7 @@ func readFrameHeader(r *bufio.Reader) (_ header, err error) { b, err = r.ReadByte() if err != nil { - return header{}, err + return err } h.masked = b&(1<<7) != 0 @@ -81,17 +80,17 @@ func readFrameHeader(r *bufio.Reader) (_ header, err error) { err = binary.Read(r, binary.BigEndian, &h.payloadLength) } if err != nil { - return header{}, err + return err } if h.masked { err = binary.Read(r, binary.LittleEndian, &h.maskKey) if err != nil { - return header{}, err + return err } } - return h, nil + return nil } // maxControlPayload is the maximum length of a control frame payload. diff --git a/frame_test.go b/frame_test.go index 8745da0ba208cb4414a7af151e53d17df813112e..38f1599a890cb52bda60a5d194c15e1449e5dbac 100644 --- a/frame_test.go +++ b/frame_test.go @@ -86,7 +86,8 @@ func testHeader(t *testing.T, h header) { err = w.Flush() assert.Success(t, err) - h2, err := readFrameHeader(r) + var h2 header + err = readFrameHeader(&h2, r) assert.Success(t, err) assert.Equal(t, "read header", h, h2) diff --git a/internal/xsync/go.go b/internal/xsync/go.go index 96cf81039715c8cc1ff54d5259ca46031d91cbc5..d88ac622c5ddd0ebb202a1c5759aa098a14a3c1c 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -6,7 +6,7 @@ import ( // Go allows running a function in another goroutine // and waiting for its error. -func Go(fn func() error) chan error { +func Go(fn func() error) <- chan error { errs := make(chan error, 1) go func() { defer func() { diff --git a/read.go b/read.go index dd73ac92530f768e688a78f63cc1b68505746614..bf7fa6d928835a98b50a43e3f2765406e04e1201 100644 --- a/read.go +++ b/read.go @@ -173,7 +173,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- ctx: } - h, err := readFrameHeader(c.br) + err := readFrameHeader(&c.readHeader, c.br) if err != nil { select { case <-c.closed: @@ -192,7 +192,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- context.Background(): } - return h, nil + return c.readHeader, nil } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { @@ -390,6 +390,8 @@ func (mr *msgReader) read(p []byte) (int, error) { return 0, err } mr.setFrame(h) + + return mr.read(p) } if int64(len(p)) > mr.payloadLength {