diff --git a/conn_test.go b/conn_test.go index 5abc9f46ec85f2f23ce8a2b5dcf8937a4cafcf6f..b2a35af822b4d1543eb7fa395427c7c4f017cd76 100644 --- a/conn_test.go +++ b/conn_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/duration" "golang.org/x/xerrors" @@ -37,153 +36,86 @@ func TestConn(t *testing.T) { for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - copts := &websocket.CompressionOptions{ + dialCopts := &websocket.CompressionOptions{ Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), Threshold: xrand.Int(9999), } - c1, c2, err := wstest.Pipe(&websocket.DialOptions{ - CompressionOptions: copts, - }, &websocket.AcceptOptions{ - CompressionOptions: copts, - }) - if err != nil { - t.Fatal(err) + acceptCopts := &websocket.CompressionOptions{ + Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), + Threshold: xrand.Int(9999), } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - echoLoopErr := xsync.Go(func() error { - err := wstest.EchoLoop(ctx, c2) - return assertCloseStatus(websocket.StatusNormalClosure, err) + c1, c2 := tt.pipe(&websocket.DialOptions{ + CompressionOptions: dialCopts, + }, &websocket.AcceptOptions{ + CompressionOptions: acceptCopts, }) - defer func() { - err := <-echoLoopErr - if err != nil { - t.Errorf("echo loop error: %v", err) - } - }() - defer cancel() + + tt.goEchoLoop(c2) c1.SetReadLimit(131072) for i := 0; i < 5; i++ { - err := wstest.Echo(ctx, c1, 131072) - if err != nil { - t.Fatal(err) - } + err := wstest.Echo(tt.ctx, c1, 131072) + tt.success(err) } - err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(websocket.StatusNormalClosure, "") + tt.success(err) }) } }) t.Run("badClose", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c1.Close(websocket.StatusInternalError, "") - defer c2.Close(websocket.StatusInternalError, "") + c1, _ := tt.pipe(nil, nil) - err = c1.Close(-1, "") - if !cmp.ErrorContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(-1, "") + tt.errContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") }) t.Run("ping", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c1.Close(websocket.StatusInternalError, "") - defer c2.Close(websocket.StatusInternalError, "") + c1, c2 := tt.pipe(nil, nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() - - c2.CloseRead(ctx) - c1.CloseRead(ctx) + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) for i := 0; i < 10; i++ { - err = c1.Ping(ctx) - if err != nil { - t.Fatal(err) - } + err := c1.Ping(tt.ctx) + tt.success(err) } - err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(websocket.StatusNormalClosure, "") + tt.success(err) }) t.Run("badPing", func(t *testing.T) { - t.Parallel() - - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c1.Close(websocket.StatusInternalError, "") - defer c2.Close(websocket.StatusInternalError, "") + tt := newTest(t) + defer tt.done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - c2.CloseRead(ctx) + c2.CloseRead(tt.ctx) - err = c1.Ping(ctx) - if !cmp.ErrorContains(err, "failed to wait for pong") { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Ping(tt.ctx) + tt.errContains(err, "failed to wait for pong") }) t.Run("concurrentWrite", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() - - discardLoopErr := xsync.Go(func() error { - for { - _, _, err := c2.Read(ctx) - if websocket.CloseStatus(err) == websocket.StatusNormalClosure { - return nil - } - if err != nil { - return err - } - } - }) - defer func() { - err := <-discardLoopErr - if err != nil { - t.Errorf("discard loop error: %v", err) - } - }() - defer cancel() + c1, c2 := tt.pipe(nil, nil) + tt.goDiscardLoop(c2) msg := xrand.Bytes(xrand.Int(9999)) const count = 100 @@ -191,74 +123,52 @@ func TestConn(t *testing.T) { for i := 0; i < count; i++ { go func() { - errs <- c1.Write(ctx, websocket.MessageBinary, msg) + errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg) }() } for i := 0; i < count; i++ { err := <-errs - if err != nil { - t.Fatal(err) - } + tt.success(err) } - err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(websocket.StatusNormalClosure, "") + tt.success(err) }) t.Run("concurrentWriteError", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") + c1, _ := tt.pipe(nil, nil) - _, err = c1.Writer(context.Background(), websocket.MessageText) - if err != nil { - t.Fatal(err) - } + _, err := c1.Writer(tt.ctx, websocket.MessageText) + tt.success(err) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) - if !xerrors.Is(err, context.DeadlineExceeded) { - t.Fatal(err) - } + tt.eq(context.DeadlineExceeded, err) }) t.Run("netConn", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) - n2 := websocket.NetConn(ctx, c2, websocket.MessageBinary) + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) // Does not give any confidence but at least ensures no crashes. - d, _ := ctx.Deadline() + d, _ := tt.ctx.Deadline() n1.SetDeadline(d) n1.SetDeadline(time.Time{}) - if n1.RemoteAddr() != n1.LocalAddr() { - t.Fatal() - } - if n1.RemoteAddr().String() != "websocket/unknown-addr" || n1.RemoteAddr().Network() != "websocket" { - t.Fatal(n1.RemoteAddr()) - } + tt.eq(n1.RemoteAddr(), n1.LocalAddr()) + tt.eq("websocket/unknown-addr", n1.RemoteAddr().String()) + tt.eq("websocket", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) @@ -269,40 +179,25 @@ func TestConn(t *testing.T) { }) b, err := ioutil.ReadAll(n1) - if err != nil { - t.Fatal(err) - } + tt.success(err) _, err = n1.Read(nil) - if err != io.EOF { - t.Fatalf("expected EOF: %v", err) - } + tt.eq(err, io.EOF) err = <-errs - if err != nil { - t.Fatal(err) - } + tt.success(err) - if !cmp.Equal([]byte("hello"), b) { - t.Fatalf("unexpected msg: %v", cmp.Diff([]byte("hello"), b)) - } + tt.eq([]byte("hello"), b) }) t.Run("netConn/BadMsg", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) - n2 := websocket.NetConn(ctx, c2, websocket.MessageText) + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) @@ -312,114 +207,60 @@ func TestConn(t *testing.T) { return nil }) - _, err = ioutil.ReadAll(n1) - if !cmp.ErrorContains(err, `unexpected frame type read (expected MessageBinary): MessageText`) { - t.Fatal(err) - } + _, err := ioutil.ReadAll(n1) + tt.errContains(err, `unexpected frame type read (expected MessageBinary): MessageText`) err = <-errs - if err != nil { - t.Fatal(err) - } + tt.success(err) }) t.Run("wsjson", func(t *testing.T) { - t.Parallel() - - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") + tt := newTest(t) + defer tt.done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - echoLoopErr := xsync.Go(func() error { - err := wstest.EchoLoop(ctx, c2) - return assertCloseStatus(websocket.StatusNormalClosure, err) - }) - defer func() { - err := <-echoLoopErr - if err != nil { - t.Errorf("echo loop error: %v", err) - } - }() - defer cancel() + tt.goEchoLoop(c2) c1.SetReadLimit(1 << 30) exp := xrand.String(xrand.Int(131072)) werr := xsync.Go(func() error { - return wsjson.Write(ctx, c1, exp) + return wsjson.Write(tt.ctx, c1, exp) }) var act interface{} - err = wsjson.Read(ctx, c1, &act) - if err != nil { - t.Fatal(err) - } - if exp != act { - t.Fatal(cmp.Diff(exp, act)) - } + err := wsjson.Read(tt.ctx, c1, &act) + tt.success(err) + tt.eq(exp, act) err = <-werr - if err != nil { - t.Fatal(err) - } + tt.success(err) err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + tt.success(err) }) t.Run("wspb", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - echoLoopErr := xsync.Go(func() error { - err := wstest.EchoLoop(ctx, c2) - return assertCloseStatus(websocket.StatusNormalClosure, err) - }) - defer func() { - err := <-echoLoopErr - if err != nil { - t.Errorf("echo loop error: %v", err) - } - }() - defer cancel() + tt.goEchoLoop(c2) exp := ptypes.DurationProto(100) - err = wspb.Write(ctx, c1, exp) - if err != nil { - t.Fatal(err) - } + err := wspb.Write(tt.ctx, c1, exp) + tt.success(err) act := &duration.Duration{} - err = wspb.Read(ctx, c1, act) - if err != nil { - t.Fatal(err) - } - if !proto.Equal(exp, act) { - t.Fatal(cmp.Diff(exp, act)) - } + err = wspb.Read(tt.ctx, c1, act) + tt.success(err) + tt.eq(exp, act) err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + tt.success(err) }) } @@ -443,7 +284,7 @@ func TestWasm(t *testing.T) { err = wstest.EchoLoop(r.Context(), c) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - t.Errorf("echoLoop: %v", err) + t.Errorf("echoLoop failed: %v", err) } })) defer wg.Wait() @@ -470,3 +311,103 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { } return nil } + +type test struct { + t *testing.T + ctx context.Context + + doneFuncs []func() +} + +func newTest(t *testing.T) *test { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + tt := &test{t: t, ctx: ctx} + tt.appendDone(cancel) + return tt +} + +func (tt *test) appendDone(f func()) { + tt.doneFuncs = append(tt.doneFuncs, f) +} + +func (tt *test) done() { + for i := len(tt.doneFuncs) - 1; i >= 0; i-- { + tt.doneFuncs[i]() + } +} + +func (tt *test) goEchoLoop(c *websocket.Conn) { + ctx, cancel := context.WithCancel(tt.ctx) + + echoLoopErr := xsync.Go(func() error { + err := wstest.EchoLoop(ctx, c) + return assertCloseStatus(websocket.StatusNormalClosure, err) + }) + tt.appendDone(func() { + cancel() + err := <-echoLoopErr + if err != nil { + tt.t.Errorf("echo loop error: %v", err) + } + }) +} + +func (tt *test) goDiscardLoop(c *websocket.Conn) { + ctx, cancel := context.WithCancel(tt.ctx) + + discardLoopErr := xsync.Go(func() error { + for { + _, _, err := c.Read(ctx) + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + return nil + } + if err != nil { + return err + } + } + }) + tt.appendDone(func() { + cancel() + err := <-discardLoopErr + if err != nil { + tt.t.Errorf("discard loop error: %v", err) + } + }) +} + +func (tt *test) pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (c1, c2 *websocket.Conn) { + tt.t.Helper() + + c1, c2, err := wstest.Pipe(dialOpts, acceptOpts) + if err != nil { + tt.t.Fatal(err) + } + tt.appendDone(func() { + c2.Close(websocket.StatusInternalError, "") + c1.Close(websocket.StatusInternalError, "") + }) + return c1, c2 +} + +func (tt *test) success(err error) { + tt.t.Helper() + if err != nil { + tt.t.Fatal(err) + } +} + +func (tt *test) errContains(err error, sub string) { + tt.t.Helper() + if !cmp.ErrorContains(err, sub) { + tt.t.Fatalf("error does not contain %q: %v", sub, err) + } +} + +func (tt *test) eq(exp, act interface{}) { + tt.t.Helper() + if !cmp.Equal(exp, act) { + tt.t.Fatalf(cmp.Diff(exp, act)) + } +} diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go index cdbadf70a72a5133cec2a7232b7bb1f14e8519f8..6f3dd70675fe5d4ea1ae24e4a9924cecc5d7eb86 100644 --- a/internal/test/cmp/cmp.go +++ b/internal/test/cmp/cmp.go @@ -4,6 +4,7 @@ import ( "reflect" "strings" + "github.com/golang/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) @@ -12,7 +13,7 @@ import ( func Equal(v1, v2 interface{}) bool { return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { return true - })) + }), cmp.Comparer(proto.Equal)) } // Diff returns a human readable diff between v1 and v2