diff --git a/websocket_test.go b/websocket_test.go index 94f61029ebee4d1babb453d2e599bc24720fb63a..f4073bce64d7fbfba80c06b409720b13c45df76c 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -702,99 +702,104 @@ func checkWSTestIndex(t *testing.T, path string) { } } -func benchConn(b *testing.B, echo, stream bool) { - name := "buffered" - if stream { - name = "stream" +func benchConn(b *testing.B, echo, stream bool, size int) { + s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + b.Logf("server handshake failed: %+v", err) + return + } + if echo { + echoLoop(r.Context(), c) + } else { + discardLoop(r.Context(), c) + } + })) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) + if err != nil { + b.Fatalf("failed to dial: %v", err) } + defer c.Close(websocket.StatusInternalError, "") - b.Run(name, func(b *testing.B) { - s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + msg := []byte(strings.Repeat("2", size)) + buf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if stream { + w, err := c.Writer(ctx, websocket.MessageText) if err != nil { - b.Logf("server handshake failed: %+v", err) - return + b.Fatal(err) } - if echo { - echoLoop(r.Context(), c) - } else { - discardLoop(r.Context(), c) + + _, err = w.Write(msg) + if err != nil { + b.Fatal(err) } - })) - defer closeFn() - wsURL := strings.Replace(s.URL, "http", "ws", 1) + err = w.Close() + if err != nil { + b.Fatal(err) + } + } else { + err = c.Write(ctx, websocket.MessageText, msg) + if err != nil { + b.Fatal(err) + } + } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() + if echo { + _, r, err := c.Reader(ctx) + if err != nil { + b.Fatal(err) + } - c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) - if err != nil { - b.Fatalf("failed to dial: %v", err) + _, err = io.ReadFull(r, buf) + if err != nil { + b.Fatal(err) + } } - defer c.Close(websocket.StatusInternalError, "") + } + b.StopTimer() - sizes := []int{ - 2, - 512, - 4096, - 16384, - } + c.Close(websocket.StatusNormalClosure, "") +} +func BenchmarkConn(b *testing.B) { + sizes := []int{ + 2, + 32, + 512, + 4096, + 16384, + } + + b.Run("write", func(b *testing.B) { for _, size := range sizes { - msg := []byte(strings.Repeat("2", size)) - buf := make([]byte, len(msg)) b.Run(strconv.Itoa(size), func(b *testing.B) { - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } - - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } - - err = w.Close() - if err != nil { - b.Fatal(err) - } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) - } - } - - if echo { - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err) - } - - _, err = io.ReadFull(r, buf) - if err != nil { - b.Fatal(err) - } - } - } + b.Run("stream", func(b *testing.B) { + benchConn(b, false, true, size) + }) + b.Run("buffer", func(b *testing.B) { + benchConn(b, false, false, size) + }) }) } - - c.Close(websocket.StatusNormalClosure, "") }) -} -func BenchmarkConn(b *testing.B) { - b.Run("write", func(b *testing.B) { - benchConn(b, false, false) - benchConn(b, false, true) - }) b.Run("echo", func(b *testing.B) { - benchConn(b, true, true) + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + benchConn(b, true, true, size) + }) + } }) }