From 6b765363d1e5ce21e6ca3bdb7bde03ecba1a2a98 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Wed, 29 Jan 2020 22:08:29 -0600
Subject: [PATCH] Up dial coverage to 100%

---
 .github/ISSUE_TEMPLATE.md |   3 -
 ci/image/Dockerfile       |   2 +-
 conn.go                   |   2 +-
 dial.go                   |  13 +--
 dial_test.go              | 165 +++++++++++++++++++++++++++++---------
 doc.go                    |   3 +-
 internal/bpool/bpool.go   |   6 +-
 ws_js.go                  |   2 +-
 wspb/wspb.go              |  16 ++--
 9 files changed, 151 insertions(+), 61 deletions(-)
 delete mode 100644 .github/ISSUE_TEMPLATE.md

diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
deleted file mode 100644
index 7b58093..0000000
--- a/.github/ISSUE_TEMPLATE.md
+++ /dev/null
@@ -1,3 +0,0 @@
-<!--
-Please be as descriptive as possible.
--->
diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile
index bfc05fc..070c50e 100644
--- a/ci/image/Dockerfile
+++ b/ci/image/Dockerfile
@@ -6,7 +6,7 @@ RUN apt-get install -y chromium
 ENV GOFLAGS="-mod=readonly"
 ENV PAGER=cat
 ENV CI=true
-ENV MAKEFLAGS="--jobs=8 --output-sync=target"
+ENV MAKEFLAGS="--jobs=16 --output-sync=target"
 
 RUN npm install -g prettier
 RUN go get golang.org/x/tools/cmd/stringer
diff --git a/conn.go b/conn.go
index 5ccf9f9..a017649 100644
--- a/conn.go
+++ b/conn.go
@@ -22,7 +22,7 @@ type MessageType int
 const (
 	// MessageText is for UTF-8 encoded text messages like JSON.
 	MessageText MessageType = iota + 1
-	// MessageBinary is for binary messages like Protobufs.
+	// MessageBinary is for binary messages like protobufs.
 	MessageBinary
 )
 
diff --git a/dial.go b/dial.go
index af94501..58c0a9c 100644
--- a/dial.go
+++ b/dial.go
@@ -50,10 +50,10 @@ type DialOptions struct {
 // in net/http to perform WebSocket handshakes.
 // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
 func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
-	return dial(ctx, u, opts)
+	return dial(ctx, u, opts, nil)
 }
 
-func dial(ctx context.Context, urls string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) {
+func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
 	defer errd.Wrap(&err, "failed to WebSocket dial")
 
 	if opts == nil {
@@ -67,7 +67,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions) (_ *Conn, _ *http
 		opts.HTTPHeader = http.Header{}
 	}
 
-	secWebSocketKey, err := secWebSocketKey()
+	secWebSocketKey, err := secWebSocketKey(rand)
 	if err != nil {
 		return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
 	}
@@ -148,9 +148,12 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
 	return resp, nil
 }
 
-func secWebSocketKey() (string, error) {
+func secWebSocketKey(rr io.Reader) (string, error) {
+	if rr == nil {
+		rr = rand.Reader
+	}
 	b := make([]byte, 16)
-	_, err := io.ReadFull(rand.Reader, b)
+	_, err := io.ReadFull(rr, b)
 	if err != nil {
 		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
 	}
diff --git a/dial_test.go b/dial_test.go
index 6286f0f..4314f98 100644
--- a/dial_test.go
+++ b/dial_test.go
@@ -4,58 +4,117 @@ package websocket
 
 import (
 	"context"
+	"crypto/rand"
+	"io"
+	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"strings"
 	"testing"
 	"time"
+
+	"cdr.dev/slog/sloggers/slogtest/assert"
 )
 
 func TestBadDials(t *testing.T) {
 	t.Parallel()
 
-	testCases := []struct {
-		name string
-		url  string
-		opts *DialOptions
-	}{
-		{
-			name: "badURL",
-			url:  "://noscheme",
-		},
-		{
-			name: "badURLScheme",
-			url:  "ftp://nhooyr.io",
-		},
-		{
-			name: "badHTTPClient",
-			url:  "ws://nhooyr.io",
-			opts: &DialOptions{
-				HTTPClient: &http.Client{
-					Timeout: time.Minute,
+	t.Run("badReq", func(t *testing.T) {
+		t.Parallel()
+
+		testCases := []struct {
+			name string
+			url  string
+			opts *DialOptions
+			rand readerFunc
+		}{
+			{
+				name: "badURL",
+				url:  "://noscheme",
+			},
+			{
+				name: "badURLScheme",
+				url:  "ftp://nhooyr.io",
+			},
+			{
+				name: "badHTTPClient",
+				url:  "ws://nhooyr.io",
+				opts: &DialOptions{
+					HTTPClient: &http.Client{
+						Timeout: time.Minute,
+					},
 				},
 			},
-		},
-		{
-			name: "badTLS",
-			url:  "wss://totallyfake.nhooyr.io",
-		},
-	}
+			{
+				name: "badTLS",
+				url:  "wss://totallyfake.nhooyr.io",
+			},
+			{
+				name: "badReader",
+				rand: func(p []byte) (int, error) {
+					return 0, io.EOF
+				},
+			},
+		}
 
-	for _, tc := range testCases {
-		tc := tc
-		t.Run(tc.name, func(t *testing.T) {
-			t.Parallel()
+		for _, tc := range testCases {
+			tc := tc
+			t.Run(tc.name, func(t *testing.T) {
+				t.Parallel()
 
-			ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
-			defer cancel()
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+				defer cancel()
 
-			_, _, err := Dial(ctx, tc.url, tc.opts)
-			if err == nil {
-				t.Fatalf("expected non nil error: %+v", err)
-			}
+				if tc.rand == nil {
+					tc.rand = rand.Reader.Read
+				}
+
+				_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
+				assert.Error(t, "dial", err)
+			})
+		}
+	})
+
+	t.Run("badResponse", func(t *testing.T) {
+		t.Parallel()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+		defer cancel()
+
+		_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
+			HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
+				return &http.Response{
+					Body: ioutil.NopCloser(strings.NewReader("hi")),
+				}, nil
+			}),
 		})
-	}
+		assert.ErrorContains(t, "dial", err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
+	})
+
+	t.Run("badBody", func(t *testing.T) {
+		t.Parallel()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+		defer cancel()
+
+		rt := func(r *http.Request) (*http.Response, error) {
+			h := http.Header{}
+			h.Set("Connection", "Upgrade")
+			h.Set("Upgrade", "websocket")
+			h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
+
+			return &http.Response{
+				StatusCode: http.StatusSwitchingProtocols,
+				Header:     h,
+				Body:       ioutil.NopCloser(strings.NewReader("hi")),
+			}, nil
+		}
+
+		_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
+			HTTPClient: mockHTTPClient(rt),
+		})
+		assert.ErrorContains(t, "dial", err, "response body is not a io.ReadWriteCloser")
+	})
 }
 
 func Test_verifyServerHandshake(t *testing.T) {
@@ -110,6 +169,26 @@ func Test_verifyServerHandshake(t *testing.T) {
 			},
 			success: false,
 		},
+		{
+			name: "unsupportedExtension",
+			response: func(w http.ResponseWriter) {
+				w.Header().Set("Connection", "Upgrade")
+				w.Header().Set("Upgrade", "websocket")
+				w.Header().Set("Sec-WebSocket-Extensions", "meow")
+				w.WriteHeader(http.StatusSwitchingProtocols)
+			},
+			success: false,
+		},
+		{
+			name: "unsupportedDeflateParam",
+			response: func(w http.ResponseWriter) {
+				w.Header().Set("Connection", "Upgrade")
+				w.Header().Set("Upgrade", "websocket")
+				w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
+				w.WriteHeader(http.StatusSwitchingProtocols)
+			},
+			success: false,
+		},
 		{
 			name: "success",
 			response: func(w http.ResponseWriter) {
@@ -131,7 +210,7 @@ func Test_verifyServerHandshake(t *testing.T) {
 			resp := w.Result()
 
 			r := httptest.NewRequest("GET", "/", nil)
-			key, err := secWebSocketKey()
+			key, err := secWebSocketKey(rand.Reader)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -151,3 +230,15 @@ func Test_verifyServerHandshake(t *testing.T) {
 		})
 	}
 }
+
+func mockHTTPClient(fn roundTripperFunc) *http.Client {
+	return &http.Client{
+		Transport: fn,
+	}
+}
+
+type roundTripperFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
+	return f(r)
+}
diff --git a/doc.go b/doc.go
index 6847d53..c8f5550 100644
--- a/doc.go
+++ b/doc.go
@@ -12,7 +12,7 @@
 //
 // The examples are the best way to understand how to correctly use the library.
 //
-// The wsjson and wspb subpackages contain helpers for JSON and Protobuf messages.
+// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages.
 //
 // More documentation at https://nhooyr.io/websocket.
 //
@@ -28,5 +28,4 @@
 //  - Conn.Ping is no-op
 //  - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op
 //  - *http.Response from Dial is &http.Response{} on success
-//
 package websocket // import "nhooyr.io/websocket"
diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go
index e2c5f76..aa826fb 100644
--- a/internal/bpool/bpool.go
+++ b/internal/bpool/bpool.go
@@ -5,12 +5,12 @@ import (
 	"sync"
 )
 
-var pool sync.Pool
+var bpool sync.Pool
 
 // Get returns a buffer from the pool or creates a new one if
 // the pool is empty.
 func Get() *bytes.Buffer {
-	b := pool.Get()
+	b := bpool.Get()
 	if b == nil {
 		return &bytes.Buffer{}
 	}
@@ -20,5 +20,5 @@ func Get() *bytes.Buffer {
 // Put returns a buffer into the pool.
 func Put(b *bytes.Buffer) {
 	b.Reset()
-	pool.Put(b)
+	bpool.Put(b)
 }
diff --git a/ws_js.go b/ws_js.go
index 950aa01..2aaef73 100644
--- a/ws_js.go
+++ b/ws_js.go
@@ -23,7 +23,7 @@ type MessageType int
 const (
 	// MessageText is for UTF-8 encoded text messages like JSON.
 	MessageText MessageType = iota + 1
-	// MessageBinary is for binary messages like Protobufs.
+	// MessageBinary is for binary messages like protobufs.
 	MessageBinary
 )
 
diff --git a/wspb/wspb.go b/wspb/wspb.go
index 666c6fa..e43042d 100644
--- a/wspb/wspb.go
+++ b/wspb/wspb.go
@@ -13,14 +13,14 @@ import (
 	"nhooyr.io/websocket/internal/errd"
 )
 
-// Read reads a Protobuf message from c into v.
+// Read reads a protobuf message from c into v.
 // It will reuse buffers in between calls to avoid allocations.
 func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
 	return read(ctx, c, v)
 }
 
 func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) {
-	defer errd.Wrap(&err, "failed to read Protobuf message")
+	defer errd.Wrap(&err, "failed to read protobuf message")
 
 	typ, r, err := c.Reader(ctx)
 	if err != nil {
@@ -29,7 +29,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) {
 
 	if typ != websocket.MessageBinary {
 		c.Close(websocket.StatusUnsupportedData, "expected binary message")
-		return fmt.Errorf("expected binary message for Protobuf but got: %v", typ)
+		return fmt.Errorf("expected binary message for protobuf but got: %v", typ)
 	}
 
 	b := bpool.Get()
@@ -42,21 +42,21 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) {
 
 	err = proto.Unmarshal(b.Bytes(), v)
 	if err != nil {
-		c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal Protobuf")
-		return fmt.Errorf("failed to unmarshal Protobuf: %w", err)
+		c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf")
+		return fmt.Errorf("failed to unmarshal protobuf: %w", err)
 	}
 
 	return nil
 }
 
-// Write writes the Protobuf message v to c.
+// Write writes the protobuf message v to c.
 // It will reuse buffers in between calls to avoid allocations.
 func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
 	return write(ctx, c, v)
 }
 
 func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) {
-	defer errd.Wrap(&err, "failed to write Protobuf message")
+	defer errd.Wrap(&err, "failed to write protobuf message")
 
 	b := bpool.Get()
 	pb := proto.NewBuffer(b.Bytes())
@@ -66,7 +66,7 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error)
 
 	err = pb.Marshal(v)
 	if err != nil {
-		return fmt.Errorf("failed to marshal Protobuf: %w", err)
+		return fmt.Errorf("failed to marshal protobuf: %w", err)
 	}
 
 	return c.Write(ctx, websocket.MessageBinary, pb.Bytes())
-- 
GitLab