From 6b765363d1e5ce21e6ca3bdb7bde03ecba1a2a98 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Wed, 29 Jan 2020 22:08:29 -0600 Subject: [PATCH] Up dial coverage to 100% --- .github/ISSUE_TEMPLATE.md | 3 - ci/image/Dockerfile | 2 +- conn.go | 2 +- dial.go | 13 +-- dial_test.go | 165 +++++++++++++++++++++++++++++--------- doc.go | 3 +- internal/bpool/bpool.go | 6 +- ws_js.go | 2 +- wspb/wspb.go | 16 ++-- 9 files changed, 151 insertions(+), 61 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE.md diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 7b58093..0000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,3 +0,0 @@ -<!-- -Please be as descriptive as possible. ---> diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index bfc05fc..070c50e 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -6,7 +6,7 @@ RUN apt-get install -y chromium ENV GOFLAGS="-mod=readonly" ENV PAGER=cat ENV CI=true -ENV MAKEFLAGS="--jobs=8 --output-sync=target" +ENV MAKEFLAGS="--jobs=16 --output-sync=target" RUN npm install -g prettier RUN go get golang.org/x/tools/cmd/stringer diff --git a/conn.go b/conn.go index 5ccf9f9..a017649 100644 --- a/conn.go +++ b/conn.go @@ -22,7 +22,7 @@ type MessageType int const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. + // MessageBinary is for binary messages like protobufs. MessageBinary ) diff --git a/dial.go b/dial.go index af94501..58c0a9c 100644 --- a/dial.go +++ b/dial.go @@ -50,10 +50,10 @@ type DialOptions struct { // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - return dial(ctx, u, opts) + return dial(ctx, u, opts, nil) } -func dial(ctx context.Context, urls string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { +func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") if opts == nil { @@ -67,7 +67,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions) (_ *Conn, _ *http opts.HTTPHeader = http.Header{} } - secWebSocketKey, err := secWebSocketKey() + secWebSocketKey, err := secWebSocketKey(rand) if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } @@ -148,9 +148,12 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe return resp, nil } -func secWebSocketKey() (string, error) { +func secWebSocketKey(rr io.Reader) (string, error) { + if rr == nil { + rr = rand.Reader + } b := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, b) + _, err := io.ReadFull(rr, b) if err != nil { return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } diff --git a/dial_test.go b/dial_test.go index 6286f0f..4314f98 100644 --- a/dial_test.go +++ b/dial_test.go @@ -4,58 +4,117 @@ package websocket import ( "context" + "crypto/rand" + "io" + "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "time" + + "cdr.dev/slog/sloggers/slogtest/assert" ) func TestBadDials(t *testing.T) { t.Parallel() - testCases := []struct { - name string - url string - opts *DialOptions - }{ - { - name: "badURL", - url: "://noscheme", - }, - { - name: "badURLScheme", - url: "ftp://nhooyr.io", - }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, + t.Run("badReq", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts *DialOptions + rand readerFunc + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: &DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, }, }, - }, - { - name: "badTLS", - url: "wss://totallyfake.nhooyr.io", - }, - } + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, + { + name: "badReader", + rand: func(p []byte) (int, error) { + return 0, io.EOF + }, + }, + } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - _, _, err := Dial(ctx, tc.url, tc.opts) - if err == nil { - t.Fatalf("expected non nil error: %+v", err) - } + if tc.rand == nil { + tc.rand = rand.Reader.Read + } + + _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) + assert.Error(t, "dial", err) + }) + } + }) + + t.Run("badResponse", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("hi")), + }, nil + }), }) - } + assert.ErrorContains(t, "dial", err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") + }) + + t.Run("badBody", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + rt := func(r *http.Request) (*http.Response, error) { + h := http.Header{} + h.Set("Connection", "Upgrade") + h.Set("Upgrade", "websocket") + h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: h, + Body: ioutil.NopCloser(strings.NewReader("hi")), + }, nil + } + + _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + HTTPClient: mockHTTPClient(rt), + }) + assert.ErrorContains(t, "dial", err, "response body is not a io.ReadWriteCloser") + }) } func Test_verifyServerHandshake(t *testing.T) { @@ -110,6 +169,26 @@ func Test_verifyServerHandshake(t *testing.T) { }, success: false, }, + { + name: "unsupportedExtension", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Extensions", "meow") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "unsupportedDeflateParam", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, { name: "success", response: func(w http.ResponseWriter) { @@ -131,7 +210,7 @@ func Test_verifyServerHandshake(t *testing.T) { resp := w.Result() r := httptest.NewRequest("GET", "/", nil) - key, err := secWebSocketKey() + key, err := secWebSocketKey(rand.Reader) if err != nil { t.Fatal(err) } @@ -151,3 +230,15 @@ func Test_verifyServerHandshake(t *testing.T) { }) } } + +func mockHTTPClient(fn roundTripperFunc) *http.Client { + return &http.Client{ + Transport: fn, + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} diff --git a/doc.go b/doc.go index 6847d53..c8f5550 100644 --- a/doc.go +++ b/doc.go @@ -12,7 +12,7 @@ // // The examples are the best way to understand how to correctly use the library. // -// The wsjson and wspb subpackages contain helpers for JSON and Protobuf messages. +// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. // // More documentation at https://nhooyr.io/websocket. // @@ -28,5 +28,4 @@ // - Conn.Ping is no-op // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op // - *http.Response from Dial is &http.Response{} on success -// package websocket // import "nhooyr.io/websocket" diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go index e2c5f76..aa826fb 100644 --- a/internal/bpool/bpool.go +++ b/internal/bpool/bpool.go @@ -5,12 +5,12 @@ import ( "sync" ) -var pool sync.Pool +var bpool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { - b := pool.Get() + b := bpool.Get() if b == nil { return &bytes.Buffer{} } @@ -20,5 +20,5 @@ func Get() *bytes.Buffer { // Put returns a buffer into the pool. func Put(b *bytes.Buffer) { b.Reset() - pool.Put(b) + bpool.Put(b) } diff --git a/ws_js.go b/ws_js.go index 950aa01..2aaef73 100644 --- a/ws_js.go +++ b/ws_js.go @@ -23,7 +23,7 @@ type MessageType int const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. + // MessageBinary is for binary messages like protobufs. MessageBinary ) diff --git a/wspb/wspb.go b/wspb/wspb.go index 666c6fa..e43042d 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -13,14 +13,14 @@ import ( "nhooyr.io/websocket/internal/errd" ) -// Read reads a Protobuf message from c into v. +// Read reads a protobuf message from c into v. // It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to read Protobuf message") + defer errd.Wrap(&err, "failed to read protobuf message") typ, r, err := c.Reader(ctx) if err != nil { @@ -29,7 +29,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { if typ != websocket.MessageBinary { c.Close(websocket.StatusUnsupportedData, "expected binary message") - return fmt.Errorf("expected binary message for Protobuf but got: %v", typ) + return fmt.Errorf("expected binary message for protobuf but got: %v", typ) } b := bpool.Get() @@ -42,21 +42,21 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { err = proto.Unmarshal(b.Bytes(), v) if err != nil { - c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal Protobuf") - return fmt.Errorf("failed to unmarshal Protobuf: %w", err) + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") + return fmt.Errorf("failed to unmarshal protobuf: %w", err) } return nil } -// Write writes the Protobuf message v to c. +// Write writes the protobuf message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to write Protobuf message") + defer errd.Wrap(&err, "failed to write protobuf message") b := bpool.Get() pb := proto.NewBuffer(b.Bytes()) @@ -66,7 +66,7 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) err = pb.Marshal(v) if err != nil { - return fmt.Errorf("failed to marshal Protobuf: %w", err) + return fmt.Errorf("failed to marshal protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) -- GitLab