diff --git a/websocket_bench_test.go b/websocket_bench_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4ad8646cbb4cb9fad6463f80e32c3c2925c02cd2 --- /dev/null +++ b/websocket_bench_test.go @@ -0,0 +1,146 @@ +package websocket_test + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "nhooyr.io/websocket" + "strconv" + "strings" + "testing" + "time" +) + + +func BenchmarkConn(b *testing.B) { + sizes := []int{ + 2, + 16, + 32, + 512, + 4096, + 16384, + } + + b.Run("write", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + b.Run("stream", func(b *testing.B) { + benchConn(b, false, true, size) + }) + b.Run("buffer", func(b *testing.B) { + benchConn(b, false, false, size) + }) + }) + } + }) + + b.Run("echo", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + benchConn(b, true, true, size) + }) + } + }) +} + +func benchConn(b *testing.B, echo, stream bool, size int) { + s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + if echo { + echoLoop(r.Context(), c) + } else { + discardLoop(r.Context(), c) + } + return nil + }, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + c, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + b.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + msg := []byte(strings.Repeat("2", size)) + readBuf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if stream { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + b.Fatal(err) + } + + _, err = w.Write(msg) + if err != nil { + b.Fatal(err) + } + + err = w.Close() + if err != nil { + b.Fatal(err) + } + } else { + err = c.Write(ctx, websocket.MessageText, msg) + if err != nil { + b.Fatal(err) + } + } + + if echo { + _, r, err := c.Reader(ctx) + if err != nil { + b.Fatal(err) + } + + _, err = io.ReadFull(r, readBuf) + if err != nil { + b.Fatal(err) + } + } + } + b.StopTimer() + + c.Close(websocket.StatusNormalClosure, "") +} + + +func discardLoop(ctx context.Context, c *websocket.Conn) { + defer c.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 32768) + echo := func() error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + + _, err = io.CopyBuffer(ioutil.Discard, r, b) + if err != nil { + return err + } + return nil + } + + for { + err := echo() + if err != nil { + return + } + } +} diff --git a/websocket_test.go b/websocket_test.go index 906014ca3e5692b87bf2dc1e49eff79f777a667b..732fc94cb6a8922786475fa4e068f11dbbb81a3b 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -1092,6 +1092,8 @@ func TestAutobahn(t *testing.T) { // Section 2. t.Run("pingPong", func(t *testing.T) { + t.Parallel() + run(t, "emptyPayload", func(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) return c.PingWithPayload(ctx, "") @@ -1197,40 +1199,17 @@ func TestAutobahn(t *testing.T) { }) }) -} - -func BenchmarkConn(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 512, - 4096, - 16384, - } - - b.Run("write", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("stream", func(b *testing.B) { - benchConn(b, false, true, size) - }) - b.Run("buffer", func(b *testing.B) { - benchConn(b, false, false, size) - }) - }) - } - }) + // Section 3. + t.Run("reserved", func(t *testing.T) { + t.Parallel() - b.Run("echo", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) - }) - } + run(t, "rsv1", func(ctx context.Context, c *websocket.Conn) error { + c.WriteFrame() + }) }) } + func echoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") @@ -1272,105 +1251,6 @@ func echoLoop(ctx context.Context, c *websocket.Conn) { } } -func discardLoop(ctx context.Context, c *websocket.Conn) { - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - b := make([]byte, 32768) - echo := func() error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - - _, err = io.CopyBuffer(ioutil.Discard, r, b) - if err != nil { - return err - } - return nil - } - - for { - err := echo() - if err != nil { - return - } - } -} - -func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - if echo { - echoLoop(r.Context(), c) - } else { - discardLoop(r.Context(), c) - } - return nil - }, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() - - c, _, err := websocket.Dial(ctx, wsURL, nil) - if err != nil { - b.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - msg := []byte(strings.Repeat("2", size)) - readBuf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } - - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } - - err = w.Close() - if err != nil { - b.Fatal(err) - } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) - } - } - - if echo { - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err) - } - - _, err = io.ReadFull(r, readBuf) - if err != nil { - b.Fatal(err) - } - } - } - b.StopTimer() - - c.Close(websocket.StatusNormalClosure, "") -} - func assertCloseStatus(err error, code websocket.StatusCode) error { var cerr websocket.CloseError if !xerrors.As(err, &cerr) {