diff --git a/assert_test.go b/assert_test.go index e67ed539f5717ce579649e0dd2334a80908d3d1a..26fd1d486319f203da775d1c60a246a9c490281b 100644 --- a/assert_test.go +++ b/assert_test.go @@ -2,15 +2,12 @@ package websocket_test import ( "context" - "fmt" "math/rand" - "reflect" "strings" "time" - "github.com/google/go-cmp/cmp" - "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/wsjson" ) @@ -18,59 +15,6 @@ func init() { rand.Seed(time.Now().UnixNano()) } -// https://github.com/google/go-cmp/issues/40#issuecomment-328615283 -func cmpDiff(exp, act interface{}) string { - return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) -} - -func deepAllowUnexported(vs ...interface{}) cmp.Option { - m := make(map[reflect.Type]struct{}) - for _, v := range vs { - structTypes(reflect.ValueOf(v), m) - } - var typs []interface{} - for t := range m { - typs = append(typs, reflect.New(t).Elem().Interface()) - } - return cmp.AllowUnexported(typs...) -} - -func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { - if !v.IsValid() { - return - } - switch v.Kind() { - case reflect.Ptr: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Interface: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Slice, reflect.Array: - for i := 0; i < v.Len(); i++ { - structTypes(v.Index(i), m) - } - case reflect.Map: - for _, k := range v.MapKeys() { - structTypes(v.MapIndex(k), m) - } - case reflect.Struct: - m[v.Type()] = struct{}{} - for i := 0; i < v.NumField(); i++ { - structTypes(v.Field(i), m) - } - } -} - -func assertEqualf(exp, act interface{}, f string, v ...interface{}) error { - if diff := cmpDiff(exp, act); diff != "" { - return fmt.Errorf(f+": %v", append(v, diff)...) - } - return nil -} - func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error { exp := randString(n) err := wsjson.Write(ctx, c, exp) @@ -84,7 +28,7 @@ func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error { return err } - return assertEqualf(exp, act, "unexpected JSON") + return assert.Equalf(exp, act, "unexpected JSON") } func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { @@ -94,7 +38,7 @@ func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) err return err } - return assertEqualf(exp, act, "unexpected JSON") + return assert.Equalf(exp, act, "unexpected JSON") } func randBytes(n int) []byte { @@ -126,13 +70,13 @@ func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageTyp if err != nil { return err } - err = assertEqualf(typ, typ2, "unexpected data type") + err = assert.Equalf(typ, typ2, "unexpected data type") if err != nil { return err } - return assertEqualf(p, p2, "unexpected payload") + return assert.Equalf(p, p2, "unexpected payload") } func assertSubprotocol(c *websocket.Conn, exp string) error { - return assertEqualf(exp, c.Subprotocol(), "unexpected subprotocol") + return assert.Equalf(exp, c.Subprotocol(), "unexpected subprotocol") } diff --git a/conn_test.go b/conn_test.go index 12788c30c0f0847af19f3e11376ce8ebfc8ba15b..6ef77829b904531bccd4cfb14ffd67dd5ad683fe 100644 --- a/conn_test.go +++ b/conn_test.go @@ -32,6 +32,7 @@ import ( "go.uber.org/multierr" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/internal/wsecho" "nhooyr.io/websocket/wsjson" "nhooyr.io/websocket/wspb" @@ -127,7 +128,7 @@ func TestHandshake(t *testing.T) { if err != nil { return fmt.Errorf("request is missing mycookie: %w", err) } - err = assertEqualf("myvalue", cookie.Value, "unexpected cookie value") + err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value") if err != nil { return err } @@ -219,7 +220,7 @@ func TestConn(t *testing.T) { } for h, exp := range headers { value := resp.Header.Get(h) - err := assertEqualf(exp, value, "unexpected value for header %v", h) + err := assert.Equalf(exp, value, "unexpected value for header %v", h) if err != nil { return err } @@ -276,11 +277,11 @@ func TestConn(t *testing.T) { time.Sleep(1) nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - err := assertEqualf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr") + err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr") if err != nil { return err } - err = assertEqualf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr") + err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr") if err != nil { return err } @@ -310,13 +311,13 @@ func TestConn(t *testing.T) { // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. err2 := assertNetConnRead(nc, "hello") - err := assertEqualf(io.EOF, err2, "unexpected error") + err := assert.Equalf(io.EOF, err2, "unexpected error") if err != nil { return err } err2 = assertNetConnRead(nc, "hello") - return assertEqualf(io.EOF, err2, "unexpected error") + return assert.Equalf(io.EOF, err2, "unexpected error") }, }, { @@ -772,7 +773,7 @@ func TestConn(t *testing.T) { if err != nil { return err } - err = assertEqualf("hi", v, "unexpected JSON") + err = assert.Equalf("hi", v, "unexpected JSON") if err != nil { return err } @@ -780,7 +781,7 @@ func TestConn(t *testing.T) { if err != nil { return err } - return assertEqualf("hi", string(b), "unexpected JSON") + return assert.Equalf("hi", string(b), "unexpected JSON") }, client: func(ctx context.Context, c *websocket.Conn) error { err := wsjson.Write(ctx, c, "hi") @@ -1079,11 +1080,11 @@ func TestAutobahn(t *testing.T) { if err != nil { return err } - err = assertEqualf(typ, actTyp, "unexpected message type") + err = assert.Equalf(typ, actTyp, "unexpected message type") if err != nil { return err } - return assertEqualf(p, p2, "unexpected message") + return assert.Equalf(p, p2, "unexpected message") }) } } @@ -1859,7 +1860,7 @@ func assertCloseStatus(err error, code websocket.StatusCode) error { if !errors.As(err, &cerr) { return fmt.Errorf("no websocket close error in error chain: %+v", err) } - return assertEqualf(code, cerr.Code, "unexpected status code") + return assert.Equalf(code, cerr.Code, "unexpected status code") } func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { @@ -1871,7 +1872,7 @@ func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) return err } - return assertEqualf(exp, act, "unexpected protobuf") + return assert.Equalf(exp, act, "unexpected protobuf") } func assertNetConnRead(r io.Reader, exp string) error { @@ -1880,7 +1881,7 @@ func assertNetConnRead(r io.Reader, exp string) error { if err != nil { return err } - return assertEqualf(exp, string(act), "unexpected net conn read") + return assert.Equalf(exp, string(act), "unexpected net conn read") } func assertErrorContains(err error, exp string) error { @@ -1902,11 +1903,11 @@ func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.Op if err != nil { return err } - err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP) + err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP) if err != nil { return err } - return assertEqualf(p, actP, "unexpected frame %v payload", opcode) + return assert.Equalf(p, actP, "unexpected frame %v payload", opcode) } func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error { @@ -1914,7 +1915,7 @@ func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket if err != nil { return err } - err = assertEqualf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP) + err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP) if err != nil { return err } @@ -1922,7 +1923,7 @@ func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket if err != nil { return fmt.Errorf("failed to parse close frame payload: %w", err) } - return assertEqualf(ce.Code, code, "unexpected frame close frame code with payload %q", actP) + return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP) } func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket.StatusCode, reason string) error { @@ -1960,11 +1961,11 @@ func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.Mes if err != nil { return err } - err = assertEqualf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP) + err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP) if err != nil { return err } - return assertEqualf(p, actP, "unexpected frame %v payload", actTyp) + return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp) } func BenchmarkConn(b *testing.B) { diff --git a/frame.go b/frame.go index 4b170c5f7ec830c0c93ff2d3163829ba8546c76a..796c1c85d57afe5ae1f68905b190897d6f53a41d 100644 --- a/frame.go +++ b/frame.go @@ -2,6 +2,7 @@ package websocket import ( "encoding/binary" + "errors" "fmt" "io" "math" @@ -252,6 +253,17 @@ func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } +// CloseStatus is a convenience wrapper around xerrors.As to grab +// the status code from a *CloseError. If the passed error is nil +// or not a *CloseError, the returned StatusCode will be -1. +func CloseStatus(err error) StatusCode { + var ce *CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ diff --git a/frame_test.go b/frame_test.go index 7d2a571958633a4ae9e2b2fb9bfd4149bcc66b64..a4fead4937f576d37cfc6f2e1c13c9448b304671 100644 --- a/frame_test.go +++ b/frame_test.go @@ -13,6 +13,8 @@ import ( "time" "github.com/google/go-cmp/cmp" + + "nhooyr.io/websocket/internal/assert" ) func init() { @@ -376,3 +378,43 @@ func BenchmarkXOR(b *testing.B) { }) } } + +func TestCloseStatus(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + in error + exp StatusCode + }{ + { + name: "nil", + in: nil, + exp: -1, + }, + { + name: "io.EOF", + in: io.EOF, + exp: -1, + }, + { + name: "StatusInternalError", + in: &CloseError{ + Code: StatusInternalError, + }, + exp: StatusInternalError, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := assert.Equalf(tc.exp, CloseStatus(tc.in), "unexpected close status") + if err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/internal/assert/assert.go b/internal/assert/assert.go new file mode 100644 index 0000000000000000000000000000000000000000..e57abfd9e1fe37b7a597c2a0c53dfa2b8c148e6f --- /dev/null +++ b/internal/assert/assert.go @@ -0,0 +1,63 @@ +package assert + +import ( + "fmt" + "reflect" + + "github.com/google/go-cmp/cmp" +) + +// https://github.com/google/go-cmp/issues/40#issuecomment-328615283 +func cmpDiff(exp, act interface{}) string { + return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) +} + +func deepAllowUnexported(vs ...interface{}) cmp.Option { + m := make(map[reflect.Type]struct{}) + for _, v := range vs { + structTypes(reflect.ValueOf(v), m) + } + var typs []interface{} + for t := range m { + typs = append(typs, reflect.New(t).Elem().Interface()) + } + return cmp.AllowUnexported(typs...) +} + +func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { + if !v.IsValid() { + return + } + switch v.Kind() { + case reflect.Ptr: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Interface: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + structTypes(v.Index(i), m) + } + case reflect.Map: + for _, k := range v.MapKeys() { + structTypes(v.MapIndex(k), m) + } + case reflect.Struct: + m[v.Type()] = struct{}{} + for i := 0; i < v.NumField(); i++ { + structTypes(v.Field(i), m) + } + } +} + +// Equalf compares exp to act and if they are not equal, returns +// an error describing an error. +func Equalf(exp, act interface{}, f string, v ...interface{}) error { + if diff := cmpDiff(exp, act); diff != "" { + return fmt.Errorf(f+": %v", append(v, diff)...) + } + return nil +} diff --git a/websocket_js_test.go b/websocket_js_test.go index a3bb7639362edead3bc0dbf18e1fcec07147deb8..9808e708cc1fd485ee4a7b389a98fff594d87832 100644 --- a/websocket_js_test.go +++ b/websocket_js_test.go @@ -8,6 +8,7 @@ import ( "time" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/assert" ) func TestConn(t *testing.T) { @@ -29,7 +30,7 @@ func TestConn(t *testing.T) { t.Fatal(err) } - err = assertEqualf(&http.Response{}, resp, "unexpected http response") + err = assert.Equalf(&http.Response{}, resp, "unexpected http response") if err != nil { t.Fatal(err) }