good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 3673c2cf authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Use basic test assertions

parent 9c5bfabc
Branches
Tags
Loading
package websocket
import (
"net/http"
"golang.org/x/xerrors"
)
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
Subprotocols []string
InsecureSkipVerify bool
CompressionOptions *CompressionOptions
}
// Accept is stubbed out for Wasm.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return nil, xerrors.New("unimplemented")
}
......@@ -12,7 +12,7 @@ import (
"golang.org/x/xerrors"
"nhooyr.io/websocket/internal/test/cmp"
"nhooyr.io/websocket/internal/test/assert"
)
func TestAccept(t *testing.T) {
......@@ -25,9 +25,7 @@ func TestAccept(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
_, err := Accept(w, r, nil)
if !cmp.ErrorContains(err, "protocol violation") {
t.Fatal(err)
}
assert.Contains(t, err, "protocol violation")
})
t.Run("badOrigin", func(t *testing.T) {
......@@ -42,9 +40,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Origin", "harhar.com")
_, err := Accept(w, r, nil)
if !cmp.ErrorContains(err, `request Origin "harhar.com" is not authorized for Host`) {
t.Fatal(err)
}
assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`)
})
t.Run("badCompression", func(t *testing.T) {
......@@ -61,9 +57,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
_, err := Accept(w, r, nil)
if !cmp.ErrorContains(err, `unsupported permessage-deflate parameter`) {
t.Fatal(err)
}
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
})
t.Run("requireHttpHijacker", func(t *testing.T) {
......@@ -77,9 +71,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Sec-WebSocket-Key", "meow123")
_, err := Accept(w, r, nil)
if !cmp.ErrorContains(err, `http.ResponseWriter does not implement http.Hijacker`) {
t.Fatal(err)
}
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
})
t.Run("badHijack", func(t *testing.T) {
......@@ -99,9 +91,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Sec-WebSocket-Key", "meow123")
_, err := Accept(w, r, nil)
if !cmp.ErrorContains(err, `failed to hijack connection`) {
t.Fatal(err)
}
assert.Contains(t, err, `failed to hijack connection`)
})
}
......@@ -193,8 +183,10 @@ func Test_verifyClientHandshake(t *testing.T) {
}
_, err := verifyClientRequest(httptest.NewRecorder(), r)
if tc.success != (err == nil) {
t.Fatalf("unexpected error value: %v", err)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
......@@ -244,9 +236,7 @@ func Test_selectSubprotocol(t *testing.T) {
r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ","))
negotiated := selectSubprotocol(r, tc.serverProtocols)
if !cmp.Equal(tc.negotiated, negotiated) {
t.Fatalf("unexpected negotiated: %v", cmp.Diff(tc.negotiated, negotiated))
}
assert.Equal(t, "negotiated", tc.negotiated, negotiated)
})
}
}
......@@ -300,8 +290,10 @@ func Test_authenticateOrigin(t *testing.T) {
r.Header.Set("Origin", tc.origin)
err := authenticateOrigin(r)
if tc.success != (err == nil) {
t.Fatalf("unexpected error value: %v", err)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
......@@ -373,21 +365,13 @@ func Test_acceptCompression(t *testing.T) {
w := httptest.NewRecorder()
copts, err := acceptCompression(r, w, tc.mode)
if tc.error {
if err == nil {
t.Fatalf("expected error: %v", copts)
}
assert.Error(t, err)
return
}
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(tc.expCopts, copts) {
t.Fatalf("unexpected compression options: %v", cmp.Diff(tc.expCopts, copts))
}
if !cmp.Equal(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) {
t.Fatalf("unexpected respHeader: %v", cmp.Diff(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")))
}
assert.Success(t, err)
assert.Equal(t, "compression options", tc.expCopts, copts)
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
})
}
}
......
......@@ -19,6 +19,7 @@ import (
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest"
)
......@@ -45,32 +46,26 @@ func TestAutobahn(t *testing.T) {
defer cancel()
wstestURL, closeFn, err := wstestClientServer(ctx)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
defer closeFn()
err = waitWS(ctx, wstestURL)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
cases, err := wstestCaseCount(ctx, wstestURL)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
t.Run("cases", func(t *testing.T) {
for i := 1; i <= cases; i++ {
i := i
t.Run("", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
err = wstest.EchoLoop(ctx, c)
t.Logf("echoLoop: %v", err)
})
......@@ -78,9 +73,7 @@ func TestAutobahn(t *testing.T) {
})
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
c.Close(websocket.StatusNormalClosure, "")
checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
......@@ -172,18 +165,14 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
func checkWSTestIndex(t *testing.T, path string) {
wstestOut, err := ioutil.ReadFile(path)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
var indexJSON map[string]map[string]struct {
Behavior string `json:"behavior"`
BehaviorClose string `json:"behaviorClose"`
}
err = json.Unmarshal(wstestOut, &indexJSON)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
for _, tests := range indexJSON {
for test, result := range tests {
......
......@@ -8,7 +8,7 @@ import (
"strings"
"testing"
"nhooyr.io/websocket/internal/test/cmp"
"nhooyr.io/websocket/internal/test/assert"
)
func TestCloseError(t *testing.T) {
......@@ -51,8 +51,10 @@ func TestCloseError(t *testing.T) {
t.Parallel()
_, err := tc.ce.bytesErr()
if tc.success != (err == nil) {
t.Fatalf("unexpected error value (wanted err == nil == %v): %v", tc.success, err)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
......@@ -63,10 +65,7 @@ func TestCloseError(t *testing.T) {
Code: StatusInternalError,
Reason: "meow",
}.Error()
if (act) != exp {
t.Fatal(cmp.Diff(exp, act))
}
assert.Equal(t, "CloseError.Error()", exp, act)
})
}
......@@ -114,14 +113,10 @@ func Test_parseClosePayload(t *testing.T) {
ce, err := parseClosePayload(tc.p)
if tc.success {
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(tc.ce, ce) {
t.Fatalf("expected %v but got %v", tc.ce, ce)
}
} else if err == nil {
t.Errorf("expected error: %v %v", ce, err)
assert.Success(t, err)
assert.Equal(t, "close payload", tc.ce, ce)
} else {
assert.Error(t, err)
}
})
}
......@@ -168,9 +163,7 @@ func Test_validWireCloseCode(t *testing.T) {
t.Parallel()
act := validWireCloseCode(tc.code)
if !cmp.Equal(tc.valid, act) {
t.Fatalf("unexpected valid: %v", cmp.Diff(tc.valid, act))
}
assert.Equal(t, "wire close code", tc.valid, act)
})
}
}
......@@ -208,9 +201,7 @@ func TestCloseStatus(t *testing.T) {
t.Parallel()
act := CloseStatus(tc.in)
if !cmp.Equal(tc.exp, act) {
t.Fatalf("unexpected closeStatus: %v", cmp.Diff(tc.exp, act))
}
assert.Equal(t, "close status", tc.exp, act)
})
}
}
......@@ -6,6 +6,7 @@ import (
"strings"
"testing"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand"
)
......@@ -23,10 +24,7 @@ func Test_slidingWindow(t *testing.T) {
r := newSlidingWindow(windowLength)
r.write([]byte(input))
if cap(r.buf) != windowLength {
t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength)
}
assert.Equal(t, "window length", windowLength, cap(r.buf))
if !strings.HasSuffix(input, string(r.buf)) {
t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf)
}
......
......@@ -20,7 +20,7 @@ import (
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/test/cmp"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest"
"nhooyr.io/websocket/internal/test/xrand"
"nhooyr.io/websocket/internal/xsync"
......@@ -34,26 +34,21 @@ func TestConn(t *testing.T) {
t.Run("fuzzData", func(t *testing.T) {
t.Parallel()
for i := 0; i < 5; i++ {
t.Run("", func(t *testing.T) {
tt := newTest(t)
defer tt.done()
dialCopts := &websocket.CompressionOptions{
copts := func() *websocket.CompressionOptions {
return &websocket.CompressionOptions{
Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)),
Threshold: xrand.Int(9999),
}
acceptCopts := &websocket.CompressionOptions{
Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)),
Threshold: xrand.Int(9999),
}
c1, c2 := tt.pipe(&websocket.DialOptions{
CompressionOptions: dialCopts,
for i := 0; i < 5; i++ {
t.Run("", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
CompressionOptions: copts(),
}, &websocket.AcceptOptions{
CompressionOptions: acceptCopts,
CompressionOptions: copts(),
})
defer tt.done()
tt.goEchoLoop(c2)
......@@ -61,60 +56,53 @@ func TestConn(t *testing.T) {
for i := 0; i < 5; i++ {
err := wstest.Echo(tt.ctx, c1, 131072)
tt.success(err)
assert.Success(t, err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
assert.Success(t, err)
})
}
})
t.Run("badClose", func(t *testing.T) {
tt := newTest(t)
tt, c1, _ := newConnTest(t, nil, nil)
defer tt.done()
c1, _ := tt.pipe(nil, nil)
err := c1.Close(-1, "")
tt.errContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
})
t.Run("ping", func(t *testing.T) {
tt := newTest(t)
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.done()
c1, c2 := tt.pipe(nil, nil)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
for i := 0; i < 10; i++ {
err := c1.Ping(tt.ctx)
tt.success(err)
assert.Success(t, err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
assert.Success(t, err)
})
t.Run("badPing", func(t *testing.T) {
tt := newTest(t)
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.done()
c1, c2 := tt.pipe(nil, nil)
c2.CloseRead(tt.ctx)
err := c1.Ping(tt.ctx)
tt.errContains(err, "failed to wait for pong")
assert.Contains(t, err, "failed to wait for pong")
})
t.Run("concurrentWrite", func(t *testing.T) {
tt := newTest(t)
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.done()
c1, c2 := tt.pipe(nil, nil)
tt.goDiscardLoop(c2)
msg := xrand.Bytes(xrand.Int(9999))
......@@ -129,35 +117,31 @@ func TestConn(t *testing.T) {
for i := 0; i < count; i++ {
err := <-errs
tt.success(err)
assert.Success(t, err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
assert.Success(t, err)
})
t.Run("concurrentWriteError", func(t *testing.T) {
tt := newTest(t)
tt, c1, _ := newConnTest(t, nil, nil)
defer tt.done()
c1, _ := tt.pipe(nil, nil)
_, err := c1.Writer(tt.ctx, websocket.MessageText)
tt.success(err)
assert.Success(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
err = c1.Write(ctx, websocket.MessageText, []byte("x"))
tt.eq(context.DeadlineExceeded, err)
assert.Equal(t, "write error", context.DeadlineExceeded, err)
})
t.Run("netConn", func(t *testing.T) {
tt := newTest(t)
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.done()
c1, c2 := tt.pipe(nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
......@@ -166,9 +150,9 @@ func TestConn(t *testing.T) {
n1.SetDeadline(d)
n1.SetDeadline(time.Time{})
tt.eq(n1.RemoteAddr(), n1.LocalAddr())
tt.eq("websocket/unknown-addr", n1.RemoteAddr().String())
tt.eq("websocket", n1.RemoteAddr().Network())
assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String())
assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network())
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
......@@ -179,23 +163,21 @@ func TestConn(t *testing.T) {
})
b, err := ioutil.ReadAll(n1)
tt.success(err)
assert.Success(t, err)
_, err = n1.Read(nil)
tt.eq(err, io.EOF)
assert.Equal(t, "read error", err, io.EOF)
err = <-errs
tt.success(err)
assert.Success(t, err)
tt.eq([]byte("hello"), b)
assert.Equal(t, "read msg", []byte("hello"), b)
})
t.Run("netConn/BadMsg", func(t *testing.T) {
tt := newTest(t)
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.done()
c1, c2 := tt.pipe(nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
......@@ -208,18 +190,16 @@ func TestConn(t *testing.T) {
})
_, err := ioutil.ReadAll(n1)
tt.errContains(err, `unexpected frame type read (expected MessageBinary): MessageText`)
assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
err = <-errs
tt.success(err)
assert.Success(t, err)
})
t.Run("wsjson", func(t *testing.T) {
tt := newTest(t)
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.done()
c1, c2 := tt.pipe(nil, nil)
tt.goEchoLoop(c2)
c1.SetReadLimit(1 << 30)
......@@ -232,35 +212,33 @@ func TestConn(t *testing.T) {
var act interface{}
err := wsjson.Read(tt.ctx, c1, &act)
tt.success(err)
tt.eq(exp, act)
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
err = <-werr
tt.success(err)
assert.Success(t, err)
err = c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
assert.Success(t, err)
})
t.Run("wspb", func(t *testing.T) {
tt := newTest(t)
tt, c1, c2 := newConnTest(t, nil, nil)
defer tt.done()
c1, c2 := tt.pipe(nil, nil)
tt.goEchoLoop(c2)
exp := ptypes.DurationProto(100)
err := wspb.Write(tt.ctx, c1, exp)
tt.success(err)
assert.Success(t, err)
act := &duration.Duration{}
err = wspb.Read(tt.ctx, c1, act)
tt.success(err)
tt.eq(exp, act)
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
err = c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
assert.Success(t, err)
})
}
......@@ -277,14 +255,17 @@ func TestWasm(t *testing.T) {
InsecureSkipVerify: true,
})
if err != nil {
t.Error(err)
t.Errorf("echo server failed: %v", err)
return
}
defer c.Close(websocket.StatusInternalError, "")
err = wstest.EchoLoop(r.Context(), c)
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
t.Errorf("echoLoop failed: %v", err)
err = assertCloseStatus(websocket.StatusNormalClosure, err)
if err != nil {
t.Errorf("echo server failed: %v", err)
return
}
}))
defer wg.Wait()
......@@ -307,38 +288,47 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error {
return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err)
}
if websocket.CloseStatus(err) != exp {
return xerrors.Errorf("unexpected close status (%v):%v", exp, err)
return xerrors.Errorf("expected close status %v but got ", exp, err)
}
return nil
}
type test struct {
type connTest struct {
t *testing.T
ctx context.Context
doneFuncs []func()
}
func newTest(t *testing.T) *test {
func newConnTest(t *testing.T, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
t.Parallel()
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
tt := &test{t: t, ctx: ctx}
tt = &connTest{t: t, ctx: ctx}
tt.appendDone(cancel)
return tt
c1, c2, err := wstest.Pipe(dialOpts, acceptOpts)
assert.Success(tt.t, err)
tt.appendDone(func() {
c2.Close(websocket.StatusInternalError, "")
c1.Close(websocket.StatusInternalError, "")
})
return tt, c1, c2
}
func (tt *test) appendDone(f func()) {
func (tt *connTest) appendDone(f func()) {
tt.doneFuncs = append(tt.doneFuncs, f)
}
func (tt *test) done() {
func (tt *connTest) done() {
for i := len(tt.doneFuncs) - 1; i >= 0; i-- {
tt.doneFuncs[i]()
}
}
func (tt *test) goEchoLoop(c *websocket.Conn) {
func (tt *connTest) goEchoLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
echoLoopErr := xsync.Go(func() error {
......@@ -354,7 +344,7 @@ func (tt *test) goEchoLoop(c *websocket.Conn) {
})
}
func (tt *test) goDiscardLoop(c *websocket.Conn) {
func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
discardLoopErr := xsync.Go(func() error {
......@@ -376,38 +366,3 @@ func (tt *test) goDiscardLoop(c *websocket.Conn) {
}
})
}
func (tt *test) pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (c1, c2 *websocket.Conn) {
tt.t.Helper()
c1, c2, err := wstest.Pipe(dialOpts, acceptOpts)
if err != nil {
tt.t.Fatal(err)
}
tt.appendDone(func() {
c2.Close(websocket.StatusInternalError, "")
c1.Close(websocket.StatusInternalError, "")
})
return c1, c2
}
func (tt *test) success(err error) {
tt.t.Helper()
if err != nil {
tt.t.Fatal(err)
}
}
func (tt *test) errContains(err error, sub string) {
tt.t.Helper()
if !cmp.ErrorContains(err, sub) {
tt.t.Fatalf("error does not contain %q: %v", sub, err)
}
}
func (tt *test) eq(exp, act interface{}) {
tt.t.Helper()
if !cmp.Equal(exp, act) {
tt.t.Fatalf(cmp.Diff(exp, act))
}
}
......@@ -13,7 +13,7 @@ import (
"testing"
"time"
"nhooyr.io/websocket/internal/test/cmp"
"nhooyr.io/websocket/internal/test/assert"
)
func TestBadDials(t *testing.T) {
......@@ -70,9 +70,7 @@ func TestBadDials(t *testing.T) {
}
_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
if err == nil {
t.Fatalf("expected error")
}
assert.Error(t, err)
})
}
})
......@@ -90,9 +88,7 @@ func TestBadDials(t *testing.T) {
}, nil
}),
})
if !cmp.ErrorContains(err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") {
t.Fatal(err)
}
assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
})
t.Run("badBody", func(t *testing.T) {
......@@ -117,9 +113,7 @@ func TestBadDials(t *testing.T) {
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
HTTPClient: mockHTTPClient(rt),
})
if !cmp.ErrorContains(err, "response body is not a io.ReadWriteCloser") {
t.Fatal(err)
}
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
})
}
......@@ -217,9 +211,7 @@ func Test_verifyServerHandshake(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
key, err := secWebSocketKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
r.Header.Set("Sec-WebSocket-Key", key)
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
......@@ -230,8 +222,10 @@ func Test_verifyServerHandshake(t *testing.T) {
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = verifyServerResponse(opts, key, resp)
if (err == nil) != tc.success {
t.Fatalf("unexpected error: %v", err)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
......
......@@ -16,7 +16,7 @@ import (
"github.com/gobwas/ws"
_ "github.com/gorilla/websocket"
"nhooyr.io/websocket/internal/test/cmp"
"nhooyr.io/websocket/internal/test/assert"
)
func TestHeader(t *testing.T) {
......@@ -81,22 +81,15 @@ func testHeader(t *testing.T, h header) {
r := bufio.NewReader(b)
err := writeFrameHeader(h, w)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
err = w.Flush()
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
h2, err := readFrameHeader(r)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
if !cmp.Equal(h, h2) {
t.Fatal(cmp.Diff(h, h2))
}
assert.Equal(t, "read header", h, h2)
}
func Test_mask(t *testing.T) {
......@@ -108,14 +101,10 @@ func Test_mask(t *testing.T) {
gotKey32 := mask(key32, p)
expP := []byte{0, 0, 0, 0x0d, 0x6}
if !cmp.Equal(expP, p) {
t.Fatal(cmp.Diff(expP, p))
}
assert.Equal(t, "p", expP, p)
expKey32 := bits.RotateLeft32(key32, -8)
if !cmp.Equal(expKey32, gotKey32) {
t.Fatal(cmp.Diff(expKey32, gotKey32))
}
assert.Equal(t, "key32", expKey32, gotKey32)
}
func basicMask(maskKey [4]byte, pos int, b []byte) int {
......
package assert
import (
"fmt"
"strings"
"testing"
"nhooyr.io/websocket/internal/test/cmp"
)
// Equal asserts exp == act.
func Equal(t testing.TB, name string, exp, act interface{}) {
t.Helper()
if diff := cmp.Diff(exp, act); diff != "" {
t.Fatalf("unexpected %v: %v", name, diff)
}
}
// Success asserts err == nil.
func Success(t testing.TB, err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// Error asserts err != nil.
func Error(t testing.TB, err error) {
t.Helper()
if err == nil {
t.Fatal("expected error")
}
}
// Contains asserts the fmt.Sprint(v) contains sub.
func Contains(t testing.TB, v interface{}, sub string) {
t.Helper()
vstr := fmt.Sprint(v)
if !strings.Contains(vstr, sub) {
t.Fatalf("expected %q to contain %q", vstr, sub)
}
}
......@@ -2,31 +2,15 @@ package cmp
import (
"reflect"
"strings"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
// Equal checks if v1 and v2 are equal with go-cmp.
func Equal(v1, v2 interface{}) bool {
return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool {
return true
}), cmp.Comparer(proto.Equal))
}
// Diff returns a human readable diff between v1 and v2
func Diff(v1, v2 interface{}) string {
return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool {
return true
}))
}
// ErrorContains returns whether err.Error() contains sub.
func ErrorContains(err error, sub string) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), sub)
}), cmp.Comparer(proto.Equal))
}
......@@ -3,7 +3,7 @@ package xsync
import (
"testing"
"nhooyr.io/websocket/internal/test/cmp"
"nhooyr.io/websocket/internal/test/assert"
)
func TestGoRecover(t *testing.T) {
......@@ -14,7 +14,5 @@ func TestGoRecover(t *testing.T) {
})
err := <-errs
if !cmp.ErrorContains(err, "anmol") {
t.Fatalf("unexpected err: %v", err)
}
assert.Contains(t, err, "anmol")
}
......@@ -8,7 +8,7 @@ import (
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/test/cmp"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest"
)
......@@ -21,28 +21,18 @@ func TestWasm(t *testing.T) {
c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
defer c.Close(websocket.StatusInternalError, "")
if !cmp.Equal("echo", c.Subprotocol()) {
t.Fatalf("unexpected subprotocol: %v", cmp.Diff("echo", c.Subprotocol()))
}
if !cmp.Equal(http.StatusSwitchingProtocols, resp.StatusCode) {
t.Fatalf("unexpected status code: %v", cmp.Diff(http.StatusSwitchingProtocols, resp.StatusCode))
}
assert.Equal(t, "subprotocol", "echo", c.Subprotocol())
assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode)
c.SetReadLimit(65536)
for i := 0; i < 10; i++ {
err = wstest.Echo(ctx, c, 65536)
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
}
err = c.Close(websocket.StatusNormalClosure, "")
if err != nil {
t.Fatal(err)
}
assert.Success(t, err)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment