diff --git a/bench_test.go b/bench_test.go deleted file mode 100644 index 6efbf484e6b2b3f9b0b4e56116b30fd734fb633a..0000000000000000000000000000000000000000 --- a/bench_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package websocket_test - -import ( - "context" - "io" - "net/http" - "strconv" - "strings" - "testing" - "time" - - "nhooyr.io/websocket" -) - -func benchConn(b *testing.B, stream bool) { - name := "buffered" - if stream { - name = "stream" - } - - b.Run(name, func(b *testing.B) { - s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - b.Logf("server handshake failed: %+v", err) - return - } - if stream { - streamEchoLoop(r.Context(), c) - } else { - bufferedEchoLoop(r.Context(), c) - } - - })) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() - - c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) - if err != nil { - b.Fatalf("failed to dial: %v", err) - } - defer c.Close(websocket.StatusInternalError, "") - - runN := func(n int) { - msg := []byte(strings.Repeat("2", n)) - buf := make([]byte, len(msg)) - b.Run(strconv.Itoa(n), func(b *testing.B) { - b.SetBytes(int64(len(msg))) - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } - - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } - - err = w.Close() - if err != nil { - b.Fatal(err) - } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) - } - } - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err, b.N) - } - - _, err = io.ReadFull(r, buf) - if err != nil { - b.Fatal(err) - } - } - }) - } - - runN(32) - runN(128) - runN(512) - runN(1024) - runN(4096) - runN(16384) - runN(65536) - runN(131072) - - 
c.Close(websocket.StatusNormalClosure, "") - }) -} - -func BenchmarkConn(b *testing.B) { - benchConn(b, false) - benchConn(b, true) -} diff --git a/export_test.go b/export_test.go index 4eae5d63e78a8a5b66352fc29de1afef11f700bb..d180e119cac2fe896220300d3c4f71afcd13f63e 100644 --- a/export_test.go +++ b/export_test.go @@ -8,11 +8,11 @@ import ( // method for when the entire message is in memory and does not need to be streamed // to the peer via Writer. // -// Both paths are zero allocation but Writer always has -// to write an additional fin frame when Close is called on the writer which -// can result in worse performance if the full message exceeds the buffer size -// which is 4096 right now as then two syscalls will be necessary to complete the message. -// TODO this is no good as we cannot write data frame msg in between other ones +// This prevents the allocation of the Writer. +// Furthermore Writer always has to write an additional fin frame when Close is +// called on the writer which can result in worse performance if the full message +// exceeds the buffer size which is 4096 right now as then an extra syscall +// will be necessary to complete the message. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - return c.writeControl(ctx, opcode(typ), p) + return c.writeSingleFrame(ctx, opcode(typ), p) } diff --git a/websocket.go b/websocket.go index 8688509a7facf58f95616d75015d72dc73e39c94..275af9da72d3be94438701401bb17e189148bbd0 100644 --- a/websocket.go +++ b/websocket.go @@ -13,7 +13,7 @@ import ( "golang.org/x/xerrors" ) -type control struct { +type frame struct { opcode opcode payload []byte } @@ -42,7 +42,8 @@ type Conn struct { // ping on writeDone. // writeDone will be closed if the data message write errors. 
write chan MessageType - control chan control + control chan frame + fastWrite chan frame writeBytes chan []byte writeDone chan struct{} writeFlush chan struct{} @@ -86,7 +87,8 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.write = make(chan MessageType) - c.control = make(chan control) + c.control = make(chan frame) + c.fastWrite = make(chan frame) c.writeBytes = make(chan []byte) c.writeDone = make(chan struct{}) c.writeFlush = make(chan struct{}) @@ -103,6 +105,8 @@ func (c *Conn) init() { go c.readLoop() } +// We never mask inside here because our mask key is always 0,0,0,0. +// See comment on secWebSocketKey. func (c *Conn) writeFrame(h header, p []byte) { b2 := marshalHeader(h) _, err := c.bw.Write(b2) @@ -126,14 +130,14 @@ func (c *Conn) writeFrame(h header, p []byte) { } } -func (c *Conn) writeLoopControl(control control) { +func (c *Conn) writeLoopFastWrite(frame frame) { h := header{ fin: true, - opcode: control.opcode, - payloadLength: int64(len(control.payload)), + opcode: frame.opcode, + payloadLength: int64(len(frame.payload)), masked: c.client, } - c.writeFrame(h, control.payload) + c.writeFrame(h, frame.payload) select { case <-c.closed: case c.writeDone <- struct{}{}: @@ -150,7 +154,11 @@ messageLoop: case <-c.closed: return case control := <-c.control: - c.writeLoopControl(control) + c.writeLoopFastWrite(control) + continue + case frame := <-c.fastWrite: + c.writeLoopFastWrite(frame) + continue case dataType = <-c.write: } @@ -160,7 +168,7 @@ messageLoop: case <-c.closed: return case control := <-c.control: - c.writeLoopControl(control) + c.writeLoopFastWrite(control) case b := <-c.writeBytes: h := header{ fin: false, @@ -220,7 +228,7 @@ func (c *Conn) handleControl(h header) { } if h.masked { - xor(h.maskKey, 0, b) + fastXOR(h.maskKey, 0, b) } switch h.opcode { @@ -314,7 +322,7 @@ func (c *Conn) dataReadLoop(h header) (err error) { left -= int64(len(b)) if h.masked { - maskPos = xor(h.maskKey, maskPos, b) + maskPos = 
fastXOR(h.maskKey, maskPos, b) } // Must set this before we signal the read is done. @@ -341,7 +349,7 @@ func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeControl(ctx, opPong, p) + err := c.writeSingleFrame(ctx, opPong, p) return err } @@ -384,7 +392,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeControl(ctx, opClose, p) + err := c.writeSingleFrame(ctx, opClose, p) c.close(cerr) @@ -399,11 +407,15 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { return nil } -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { +func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) error { + ch := c.fastWrite + if opcode.controlOp() { + ch = c.control + } select { case <-c.closed: return c.closeErr - case c.control <- control{ + case ch <- frame{ opcode: opcode, payload: p, }: diff --git a/websocket_test.go b/websocket_test.go index 2df8c946de05beff7fbc2f2e41d4b34d09a2c679..f4073bce64d7fbfba80c06b409720b13c45df76c 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -448,7 +448,7 @@ func TestAutobahnServer(t *testing.T) { t.Logf("server handshake failed: %+v", err) return } - streamEchoLoop(r.Context(), c) + echoLoop(r.Context(), c) })) defer s.Close() @@ -495,7 +495,7 @@ func TestAutobahnServer(t *testing.T) { checkWSTestIndex(t, "./wstest_reports/server/index.json") } -func streamEchoLoop(ctx context.Context, c *websocket.Conn) { +func echoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") ctx, cancel := context.WithTimeout(ctx, time.Minute) @@ -534,25 +534,24 @@ func streamEchoLoop(ctx context.Context, c *websocket.Conn) { } } -func bufferedEchoLoop(ctx context.Context, c *websocket.Conn) { +func discardLoop(ctx context.Context, c *websocket.Conn) { 
defer c.Close(websocket.StatusInternalError, "") ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - b := make([]byte, 131072+2) + b := make([]byte, 32768) echo := func() error { - typ, r, err := c.Reader(ctx) + _, r, err := c.Reader(ctx) if err != nil { return err } - n, err := io.ReadFull(r, b) - if err != io.ErrUnexpectedEOF { + _, err = io.CopyBuffer(ioutil.Discard, r, b) + if err != nil { return err } - - return c.Write(ctx, typ, b[:n]) + return nil } for { @@ -647,7 +646,7 @@ func TestAutobahnClient(t *testing.T) { if err != nil { t.Fatalf("failed to dial: %v", err) } - streamEchoLoop(ctx, c) + echoLoop(ctx, c) }() } @@ -702,3 +701,105 @@ func checkWSTestIndex(t *testing.T, path string) { } } } + +func benchConn(b *testing.B, echo, stream bool, size int) { + s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + b.Logf("server handshake failed: %+v", err) + return + } + if echo { + echoLoop(r.Context(), c) + } else { + discardLoop(r.Context(), c) + } + })) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) + if err != nil { + b.Fatalf("failed to dial: %v", err) + } + defer c.Close(websocket.StatusInternalError, "") + + msg := []byte(strings.Repeat("2", size)) + buf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if stream { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + b.Fatal(err) + } + + _, err = w.Write(msg) + if err != nil { + b.Fatal(err) + } + + err = w.Close() + if err != nil { + b.Fatal(err) + } + } else { + err = c.Write(ctx, websocket.MessageText, msg) + if err != nil { + b.Fatal(err) + } + } + + if echo { + _, r, err := c.Reader(ctx) + if 
err != nil { + b.Fatal(err) + } + + _, err = io.ReadFull(r, buf) + if err != nil { + b.Fatal(err) + } + } + } + b.StopTimer() + + c.Close(websocket.StatusNormalClosure, "") +} + +func BenchmarkConn(b *testing.B) { + sizes := []int{ + 2, + 32, + 512, + 4096, + 16384, + } + + b.Run("write", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + b.Run("stream", func(b *testing.B) { + benchConn(b, false, true, size) + }) + b.Run("buffer", func(b *testing.B) { + benchConn(b, false, false, size) + }) + }) + } + }) + + b.Run("echo", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + benchConn(b, true, true, size) + }) + } + }) +} diff --git a/xor.go b/xor.go index 1422f847d737248b82d8247a932af23f708575fc..5a68e81d990b2782cf886658a93006a5e8087df6 100644 --- a/xor.go +++ b/xor.go @@ -12,7 +12,7 @@ import ( // The returned value is the position of the next byte // to be used for masking in the key. This is so that // unmasking can be performed without the entire frame. -func xor(key [4]byte, keyPos int, b []byte) int { +func fastXOR(key [4]byte, keyPos int, b []byte) int { // If the payload is greater than 16 bytes, then it's worth // masking 8 bytes at a time. 
// Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 diff --git a/xor_test.go b/xor_test.go index f715eda14c3e15d24c5e2992d13e2e9c41ea91ce..c3adaf580a499bb02891ac6da38143d19013ce3e 100644 --- a/xor_test.go +++ b/xor_test.go @@ -1,6 +1,8 @@ package websocket import ( + "crypto/rand" + "strconv" "testing" "github.com/google/go-cmp/cmp" @@ -12,7 +14,7 @@ func Test_xor(t *testing.T) { key := [4]byte{0xa, 0xb, 0xc, 0xff} p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} pos := 0 - pos = xor(key, pos, p) + pos = fastXOR(key, pos, p) if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) @@ -22,3 +24,62 @@ func Test_xor(t *testing.T) { t.Fatalf("unexpected mask pos: %v", cmp.Diff(exp, pos)) } } + +// basicXOR is the straightforward byte-at-a-time masking loop, kept as the +// baseline that BenchmarkXOR compares fastXOR against. It returns the key +// position to continue from. +func basicXOR(maskKey [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= maskKey[pos&3] + pos++ + } + return pos & 3 +} + +// BenchmarkXOR runs basicXOR and fastXOR over the same payload sizes. +func BenchmarkXOR(b *testing.B) { + sizes := []int{ + 2, + 32, + 512, + 4096, + 16384, + } + + fns := []struct { + name string + fn func([4]byte, int, []byte) int + }{ + { + "basic", + basicXOR, + }, + { + "fast", + fastXOR, + }, + } + + var maskKey [4]byte + _, err := rand.Read(maskKey[:]) + if err != nil { + b.Fatalf("failed to populate mask key: %v", err) + } + + for _, size := range sizes { + data := make([]byte, size) + + b.Run(strconv.Itoa(size), func(b *testing.B) { + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + fn.fn(maskKey, 0, data) + } + }) + } + }) + } +}