From 537b26b9c25f621a1e6299b8397ed9684838c12a Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Thu, 29 Aug 2019 17:07:20 -0500
Subject: [PATCH] Change options to be pointer structures

Closes #122
---
 README.md            |  4 ++--
 accept.go            |  8 ++++++--
 accept_test.go       |  4 ++--
 dial.go              | 12 ++++++++++--
 dial_test.go         |  4 ++--
 example_echo_test.go |  4 ++--
 example_test.go      |  6 +++---
 websocket_test.go    | 46 ++++++++++++++++++++++----------------------
 8 files changed, 50 insertions(+), 38 deletions(-)

diff --git a/README.md b/README.md
index cf20b87..d53046c 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ For a production quality example that shows off the full API, see the [echo exam
 
 ```go
 http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
-	c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+	c, err := websocket.Accept(w, r, nil)
 	if err != nil {
 		// ...
 	}
@@ -64,7 +64,7 @@ in net/http](https://github.com/golang/go/issues/26937#issuecomment-415855861) t
 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 defer cancel()
 
-c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{})
+c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
 if err != nil {
 	// ...
 }
diff --git a/accept.go b/accept.go
index 7b727d1..afad1be 100644
--- a/accept.go
+++ b/accept.go
@@ -84,7 +84,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
 //
 // If an error occurs, Accept will always write an appropriate response so you do not
 // have to.
-func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) {
+func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
 	c, err := accept(w, r, opts)
 	if err != nil {
 		return nil, xerrors.Errorf("failed to accept websocket connection: %w", err)
@@ -92,7 +92,11 @@ func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn,
 	return c, nil
 }
 
-func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) {
+func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
+	if opts == nil {
+		opts = &AcceptOptions{}
+	}
+
 	err := verifyClientRequest(w, r)
 	if err != nil {
 		return nil, err
diff --git a/accept_test.go b/accept_test.go
index 8634066..6602a8d 100644
--- a/accept_test.go
+++ b/accept_test.go
@@ -15,7 +15,7 @@ func TestAccept(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("GET", "/", nil)
 
-		_, err := Accept(w, r, AcceptOptions{})
+		_, err := Accept(w, r, nil)
 		if err == nil {
 			t.Fatalf("unexpected error value: %v", err)
 		}
@@ -32,7 +32,7 @@ func TestAccept(t *testing.T) {
 		r.Header.Set("Sec-WebSocket-Version", "13")
 		r.Header.Set("Sec-WebSocket-Key", "meow123")
 
-		_, err := Accept(w, r, AcceptOptions{})
+		_, err := Accept(w, r, nil)
 		if err == nil || !strings.Contains(err.Error(), "http.Hijacker") {
 			t.Fatalf("unexpected error value: %v", err)
 		}
diff --git a/dial.go b/dial.go
index ac632c1..461817f 100644
--- a/dial.go
+++ b/dial.go
@@ -41,7 +41,7 @@ type DialOptions struct {
 // This function requires at least Go 1.12 to succeed as it uses a new feature
 // in net/http to perform WebSocket handshakes and get a writable body
 // from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861
-func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) {
+func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
 	c, r, err := dial(ctx, u, opts)
 	if err != nil {
 		return nil, r, xerrors.Errorf("failed to websocket dial: %w", err)
@@ -49,7 +49,15 @@ func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Respons
 	return c, r, nil
 }
 
-func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) {
+func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) {
+	if opts == nil {
+		opts = &DialOptions{}
+	}
+
+	// Shallow copy to ensure defaults do not affect user passed options.
+	opts2 := *opts
+	opts = &opts2
+
 	if opts.HTTPClient == nil {
 		opts.HTTPClient = http.DefaultClient
 	}
diff --git a/dial_test.go b/dial_test.go
index 4607493..96537bd 100644
--- a/dial_test.go
+++ b/dial_test.go
@@ -14,7 +14,7 @@ func TestBadDials(t *testing.T) {
 	testCases := []struct {
 		name string
 		url  string
-		opts DialOptions
+		opts *DialOptions
 	}{
 		{
 			name: "badURL",
@@ -27,7 +27,7 @@ func TestBadDials(t *testing.T) {
 		{
 			name: "badHTTPClient",
 			url:  "ws://nhooyr.io",
-			opts: DialOptions{
+			opts: &DialOptions{
 				HTTPClient: &http.Client{
 					Timeout: time.Minute,
 				},
diff --git a/example_echo_test.go b/example_echo_test.go
index 6923bc0..3e7e7f9 100644
--- a/example_echo_test.go
+++ b/example_echo_test.go
@@ -68,7 +68,7 @@ func Example_echo() {
 func echoServer(w http.ResponseWriter, r *http.Request) error {
 	log.Printf("serving %v", r.RemoteAddr)
 
-	c, err := websocket.Accept(w, r, websocket.AcceptOptions{
+	c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 		Subprotocols: []string{"echo"},
 	})
 	if err != nil {
@@ -128,7 +128,7 @@ func client(url string) error {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 	defer cancel()
 
-	c, _, err := websocket.Dial(ctx, url, websocket.DialOptions{
+	c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
 		Subprotocols: []string{"echo"},
 	})
 	if err != nil {
diff --git a/example_test.go b/example_test.go
index 0b59e6a..22c3120 100644
--- a/example_test.go
+++ b/example_test.go
@@ -14,7 +14,7 @@ import (
 // message from the client and then closes the connection.
 func ExampleAccept() {
 	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+		c, err := websocket.Accept(w, r, nil)
 		if err != nil {
 			log.Println(err)
 			return
@@ -46,7 +46,7 @@ func ExampleDial() {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 	defer cancel()
 
-	c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{})
+	c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -64,7 +64,7 @@ func ExampleDial() {
 // on which you will only write and do not expect to read data messages.
 func Example_writeOnly() {
 	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+		c, err := websocket.Accept(w, r, nil)
 		if err != nil {
 			log.Println(err)
 			return
diff --git a/websocket_test.go b/websocket_test.go
index b45f024..1f1b524 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -44,7 +44,7 @@ func TestHandshake(t *testing.T) {
 		{
 			name: "handshake",
 			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{
+				c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 					Subprotocols: []string{"myproto"},
 				})
 				if err != nil {
@@ -54,7 +54,7 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 			client: func(ctx context.Context, u string) error {
-				c, resp, err := websocket.Dial(ctx, u, websocket.DialOptions{
+				c, resp, err := websocket.Dial(ctx, u, &websocket.DialOptions{
 					Subprotocols: []string{"myproto"},
 				})
 				if err != nil {
@@ -81,7 +81,7 @@ func TestHandshake(t *testing.T) {
 		{
 			name: "defaultSubprotocol",
 			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				c, err := websocket.Accept(w, r, nil)
 				if err != nil {
 					return err
 				}
@@ -93,7 +93,7 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
 					Subprotocols: []string{"meow"},
 				})
 				if err != nil {
@@ -110,7 +110,7 @@ func TestHandshake(t *testing.T) {
 		{
 			name: "subprotocol",
 			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{
+				c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 					Subprotocols: []string{"echo", "lar"},
 				})
 				if err != nil {
@@ -124,7 +124,7 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 			client: func(ctx context.Context, u string) error {
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
 					Subprotocols: []string{"poof", "echo"},
 				})
 				if err != nil {
@@ -141,7 +141,7 @@ func TestHandshake(t *testing.T) {
 		{
 			name: "badOrigin",
 			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				c, err := websocket.Accept(w, r, nil)
 				if err == nil {
 					c.Close(websocket.StatusInternalError, "")
 					return xerrors.New("expected error regarding bad origin")
@@ -151,7 +151,7 @@ func TestHandshake(t *testing.T) {
 			client: func(ctx context.Context, u string) error {
 				h := http.Header{}
 				h.Set("Origin", "http://unauthorized.com")
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
 					HTTPHeader: h,
 				})
 				if err == nil {
@@ -164,7 +164,7 @@ func TestHandshake(t *testing.T) {
 		{
 			name: "acceptSecureOrigin",
 			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				c, err := websocket.Accept(w, r, nil)
 				if err != nil {
 					return err
 				}
@@ -174,7 +174,7 @@ func TestHandshake(t *testing.T) {
 			client: func(ctx context.Context, u string) error {
 				h := http.Header{}
 				h.Set("Origin", u)
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
 					HTTPHeader: h,
 				})
 				if err != nil {
@@ -187,7 +187,7 @@ func TestHandshake(t *testing.T) {
 		{
 			name: "acceptInsecureOrigin",
 			server: func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{
+				c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 					InsecureSkipVerify: true,
 				})
 				if err != nil {
@@ -199,7 +199,7 @@ func TestHandshake(t *testing.T) {
 			client: func(ctx context.Context, u string) error {
 				h := http.Header{}
 				h.Set("Origin", "https://example.com")
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
 					HTTPHeader: h,
 				})
 				if err != nil {
@@ -219,7 +219,7 @@ func TestHandshake(t *testing.T) {
 				if cookie.Value != "myvalue" {
 					return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value)
 				}
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				c, err := websocket.Accept(w, r, nil)
 				if err != nil {
 					return err
 				}
@@ -245,7 +245,7 @@ func TestHandshake(t *testing.T) {
 				hc := &http.Client{
 					Jar: jar,
 				}
-				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+				c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{
 					HTTPClient: hc,
 				})
 				if err != nil {
@@ -801,7 +801,7 @@ func TestConn(t *testing.T) {
 			},
 			client: func(ctx context.Context, c *websocket.Conn) error {
 				_, _, err := c.Read(ctx)
-				if err == nil || strings.Contains(err.Error(), "opcode") {
+				if err == nil || !strings.Contains(err.Error(), "opcode") {
 					return xerrors.Errorf("expected error that contains opcode: %+v", err)
 				}
 				return nil
@@ -839,7 +839,7 @@ func TestConn(t *testing.T) {
 			tls := rand.Intn(2) == 1
 
 			s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error {
-				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				c, err := websocket.Accept(w, r, nil)
 				if err != nil {
 					return err
 				}
@@ -854,7 +854,7 @@ func TestConn(t *testing.T) {
 			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 			defer cancel()
 
-			opts := websocket.DialOptions{}
+			opts := &websocket.DialOptions{}
 			if tls {
 				opts.HTTPClient = s.Client()
 			}
@@ -920,7 +920,7 @@ func TestAutobahnServer(t *testing.T) {
 	}
 
 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		c, err := websocket.Accept(w, r, websocket.AcceptOptions{
+		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 			Subprotocols: []string{"echo"},
 		})
 		if err != nil {
@@ -1120,7 +1120,7 @@ func TestAutobahnClient(t *testing.T) {
 
 	var cases int
 	func() {
-		c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", websocket.DialOptions{})
+		c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -1147,7 +1147,7 @@ func TestAutobahnClient(t *testing.T) {
 			ctx, cancel := context.WithTimeout(ctx, time.Second*45)
 			defer cancel()
 
-			c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), websocket.DialOptions{})
+			c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -1155,7 +1155,7 @@ func TestAutobahnClient(t *testing.T) {
 		}()
 	}
 
-	c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), websocket.DialOptions{})
+	c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -1207,7 +1207,7 @@ func checkWSTestIndex(t *testing.T, path string) {
 
 func benchConn(b *testing.B, echo, stream bool, size int) {
 	s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error {
-		c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+		c, err := websocket.Accept(w, r, nil)
 		if err != nil {
 			return err
 		}
@@ -1225,7 +1225,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
 	defer cancel()
 
-	c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{})
+	c, _, err := websocket.Dial(ctx, wsURL, nil)
 	if err != nil {
 		b.Fatal(err)
 	}
-- 
GitLab