diff --git a/bench_test.go b/bench_test.go index 6efbf484e6b2b3f9b0b4e56116b30fd734fb633a..e9cb4fe4cfd6232d5554e962990f8d767a7ce69a 100644 --- a/bench_test.go +++ b/bench_test.go @@ -12,7 +12,7 @@ import ( "nhooyr.io/websocket" ) -func benchConn(b *testing.B, stream bool) { +func benchConn(b *testing.B, echo, stream bool) { name := "buffered" if stream { name = "stream" @@ -25,12 +25,11 @@ func benchConn(b *testing.B, stream bool) { b.Logf("server handshake failed: %+v", err) return } - if stream { - streamEchoLoop(r.Context(), c) + if echo { + echoLoop(r.Context(), c) } else { - bufferedEchoLoop(r.Context(), c) + discardLoop(r.Context(), c) } - })) defer closeFn() @@ -50,6 +49,7 @@ func benchConn(b *testing.B, stream bool) { buf := make([]byte, len(msg)) b.Run(strconv.Itoa(n), func(b *testing.B) { b.SetBytes(int64(len(msg))) + b.ReportAllocs() for i := 0; i < b.N; i++ { if stream { w, err := c.Writer(ctx, websocket.MessageText) @@ -72,14 +72,17 @@ func benchConn(b *testing.B, stream bool) { b.Fatal(err) } } - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err, b.N) - } - _, err = io.ReadFull(r, buf) - if err != nil { - b.Fatal(err) + if echo { + _, r, err := c.Reader(ctx) + if err != nil { + b.Fatal(err) + } + + _, err = io.ReadFull(r, buf) + if err != nil { + b.Fatal(err) + } } } }) @@ -99,6 +102,11 @@ func benchConn(b *testing.B, stream bool) { } func BenchmarkConn(b *testing.B) { - benchConn(b, false) - benchConn(b, true) + b.Run("write", func(b *testing.B) { + benchConn(b, false, false) + benchConn(b, false, true) + }) + b.Run("echo", func(b *testing.B) { + benchConn(b, true, true) + }) } diff --git a/export_test.go b/export_test.go index 4eae5d63e78a8a5b66352fc29de1afef11f700bb..d180e119cac2fe896220300d3c4f71afcd13f63e 100644 --- a/export_test.go +++ b/export_test.go @@ -8,11 +8,11 @@ import ( // method for when the entire message is in memory and does not need to be streamed // to the peer via Writer. // -// Both paths are zero allocation but Writer always has -// to write an additional fin frame when Close is called on the writer which -// can result in worse performance if the full message exceeds the buffer size -// which is 4096 right now as then two syscalls will be necessary to complete the message. -// TODO this is no good as we cannot write data frame msg in between other ones +// This prevents the allocation of the Writer. +// Furthermore Writer always has to write an additional fin frame when Close is +// called on the writer which can result in worse performance if the full message +// exceeds the buffer size which is 4096 right now as then an extra syscall +// will be necessary to complete the message. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - return c.writeControl(ctx, opcode(typ), p) + return c.writeSingleFrame(ctx, opcode(typ), p) } diff --git a/websocket.go b/websocket.go index 8688509a7facf58f95616d75015d72dc73e39c94..298ac8347de835dc88aae8e3e6b7d0147b104ce1 100644 --- a/websocket.go +++ b/websocket.go @@ -13,7 +13,7 @@ import ( "golang.org/x/xerrors" ) -type control struct { +type frame struct { opcode opcode payload []byte } @@ -42,7 +42,8 @@ type Conn struct { // ping on writeDone. // writeDone will be closed if the data message write errors. write chan MessageType - control chan control + control chan frame + fastWrite chan frame writeBytes chan []byte writeDone chan struct{} writeFlush chan struct{} @@ -86,7 +87,8 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.write = make(chan MessageType) - c.control = make(chan control) + c.control = make(chan frame) + c.fastWrite = make(chan frame) c.writeBytes = make(chan []byte) c.writeDone = make(chan struct{}) c.writeFlush = make(chan struct{}) @@ -103,6 +105,8 @@ func (c *Conn) init() { go c.readLoop() } +// We never mask inside here because our mask key is always 0,0,0,0. +// See comment on secWebSocketKey. func (c *Conn) writeFrame(h header, p []byte) { b2 := marshalHeader(h) _, err := c.bw.Write(b2) @@ -126,14 +130,14 @@ func (c *Conn) writeFrame(h header, p []byte) { } } -func (c *Conn) writeLoopControl(control control) { +func (c *Conn) writeLoopFastWrite(frame frame) { h := header{ fin: true, - opcode: control.opcode, - payloadLength: int64(len(control.payload)), + opcode: frame.opcode, + payloadLength: int64(len(frame.payload)), masked: c.client, } - c.writeFrame(h, control.payload) + c.writeFrame(h, frame.payload) select { case <-c.closed: case c.writeDone <- struct{}{}: @@ -150,7 +154,11 @@ messageLoop: case <-c.closed: return case control := <-c.control: - c.writeLoopControl(control) + c.writeLoopFastWrite(control) + continue + case frame := <-c.fastWrite: + c.writeLoopFastWrite(frame) + continue case dataType = <-c.write: } @@ -160,7 +168,7 @@ messageLoop: case <-c.closed: return case control := <-c.control: - c.writeLoopControl(control) + c.writeLoopFastWrite(control) case b := <-c.writeBytes: h := header{ fin: false, @@ -341,7 +349,7 @@ func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeControl(ctx, opPong, p) + err := c.writeSingleFrame(ctx, opPong, p) return err } @@ -384,7 +392,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeControl(ctx, opClose, p) + err := c.writeSingleFrame(ctx, opClose, p) c.close(cerr) @@ -399,11 +407,15 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { return nil } -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { +func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) error { + ch := c.fastWrite + if opcode.controlOp() { + ch = c.control + } select { case <-c.closed: return c.closeErr - case c.control <- control{ + case ch <- frame{ opcode: opcode, payload: p, }: diff --git a/websocket_test.go b/websocket_test.go index 2df8c946de05beff7fbc2f2e41d4b34d09a2c679..e4fa781b5c2879d0190a384f2f4b2b90322ee0e2 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -448,7 +448,7 @@ func TestAutobahnServer(t *testing.T) { t.Logf("server handshake failed: %+v", err) return } - streamEchoLoop(r.Context(), c) + echoLoop(r.Context(), c) })) defer s.Close() @@ -495,7 +495,7 @@ func TestAutobahnServer(t *testing.T) { checkWSTestIndex(t, "./wstest_reports/server/index.json") } -func streamEchoLoop(ctx context.Context, c *websocket.Conn) { +func echoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") ctx, cancel := context.WithTimeout(ctx, time.Minute) @@ -534,25 +534,24 @@ func streamEchoLoop(ctx context.Context, c *websocket.Conn) { } } -func bufferedEchoLoop(ctx context.Context, c *websocket.Conn) { +func discardLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - b := make([]byte, 131072+2) + b := make([]byte, 32768) echo := func() error { - typ, r, err := c.Reader(ctx) + _, r, err := c.Reader(ctx) if err != nil { return err } - n, err := io.ReadFull(r, b) - if err != io.ErrUnexpectedEOF { + _, err = io.CopyBuffer(ioutil.Discard, r, b) + if err != nil { return err } - - return c.Write(ctx, typ, b[:n]) + return nil } for { @@ -647,7 +646,7 @@ func TestAutobahnClient(t *testing.T) { if err != nil { t.Fatalf("failed to dial: %v", err) } - streamEchoLoop(ctx, c) + echoLoop(ctx, c) }() }