From 537b26b9c25f621a1e6299b8397ed9684838c12a Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Thu, 29 Aug 2019 17:07:20 -0500 Subject: [PATCH] Change options to be pointer structures Closes #122 --- README.md | 4 ++-- accept.go | 8 ++++++-- accept_test.go | 4 ++-- dial.go | 12 ++++++++++-- dial_test.go | 4 ++-- example_echo_test.go | 4 ++-- example_test.go | 6 +++--- websocket_test.go | 46 ++++++++++++++++++++++---------------------- 8 files changed, 50 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index cf20b87..d53046c 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ For a production quality example that shows off the full API, see the [echo exam ```go http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { // ... } @@ -64,7 +64,7 @@ in net/http](https://github.com/golang/go/issues/26937#issuecomment-415855861) t ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() -c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) +c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } diff --git a/accept.go b/accept.go index 7b727d1..afad1be 100644 --- a/accept.go +++ b/accept.go @@ -84,7 +84,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // // If an error occurs, Accept will always write an appropriate response so you do not // have to. -func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) if err != nil { return nil, xerrors.Errorf("failed to accept websocket connection: %w", err) @@ -92,7 +92,11 @@ func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, return c, nil } -func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + if opts == nil { + opts = &AcceptOptions{} + } + err := verifyClientRequest(w, r) if err != nil { return nil, err diff --git a/accept_test.go b/accept_test.go index 8634066..6602a8d 100644 --- a/accept_test.go +++ b/accept_test.go @@ -15,7 +15,7 @@ func TestAccept(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) - _, err := Accept(w, r, AcceptOptions{}) + _, err := Accept(w, r, nil) if err == nil { t.Fatalf("unexpected error value: %v", err) } @@ -32,7 +32,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Key", "meow123") - _, err := Accept(w, r, AcceptOptions{}) + _, err := Accept(w, r, nil) if err == nil || !strings.Contains(err.Error(), "http.Hijacker") { t.Fatalf("unexpected error value: %v", err) } diff --git a/dial.go b/dial.go index ac632c1..461817f 100644 --- a/dial.go +++ b/dial.go @@ -41,7 +41,7 @@ type DialOptions struct { // This function requires at least Go 1.12 to succeed as it uses a new feature // in net/http to perform WebSocket handshakes and get a writable body // from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 -func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) { +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { c, r, err := dial(ctx, u, opts) if err != nil { return nil, r, xerrors.Errorf("failed to websocket dial: %w", err) @@ -49,7 +49,15 @@ func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Respons return c, r, nil } -func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) { +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + if opts == nil { + opts = &DialOptions{} + } + + // Shallow copy to ensure defaults do not affect user passed options. + opts2 := *opts + opts = &opts2 + if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } diff --git a/dial_test.go b/dial_test.go index 4607493..96537bd 100644 --- a/dial_test.go +++ b/dial_test.go @@ -14,7 +14,7 @@ func TestBadDials(t *testing.T) { testCases := []struct { name string url string - opts DialOptions + opts *DialOptions }{ { name: "badURL", @@ -27,7 +27,7 @@ func TestBadDials(t *testing.T) { { name: "badHTTPClient", url: "ws://nhooyr.io", - opts: DialOptions{ + opts: &DialOptions{ HTTPClient: &http.Client{ Timeout: time.Minute, }, diff --git a/example_echo_test.go b/example_echo_test.go index 6923bc0..3e7e7f9 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -68,7 +68,7 @@ func Example_echo() { func echoServer(w http.ResponseWriter, r *http.Request) error { log.Printf("serving %v", r.RemoteAddr) - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { @@ -128,7 +128,7 @@ func client(url string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, url, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) if err != nil { diff --git a/example_test.go b/example_test.go index 0b59e6a..22c3120 100644 --- a/example_test.go +++ b/example_test.go @@ -14,7 +14,7 @@ import ( // message from the client and then closes the connection. func ExampleAccept() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return @@ -46,7 +46,7 @@ func ExampleDial() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { log.Fatal(err) } @@ -64,7 +64,7 @@ func ExampleDial() { // on which you will only write and do not expect to read data messages. func Example_writeOnly() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { log.Println(err) return diff --git a/websocket_test.go b/websocket_test.go index b45f024..1f1b524 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -44,7 +44,7 @@ func TestHandshake(t *testing.T) { { name: "handshake", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"myproto"}, }) if err != nil { @@ -54,7 +54,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, resp, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, resp, err := websocket.Dial(ctx, u, &websocket.DialOptions{ Subprotocols: []string{"myproto"}, }) if err != nil { @@ -81,7 +81,7 @@ func TestHandshake(t *testing.T) { { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -93,7 +93,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ Subprotocols: []string{"meow"}, }) if err != nil { @@ -110,7 +110,7 @@ func TestHandshake(t *testing.T) { { name: "subprotocol", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo", "lar"}, }) if err != nil { @@ -124,7 +124,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ Subprotocols: []string{"poof", "echo"}, }) if err != nil { @@ -141,7 +141,7 @@ func TestHandshake(t *testing.T) { { name: "badOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err == nil { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected error regarding bad origin") @@ -151,7 +151,7 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", "http://unauthorized.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPHeader: h, }) if err == nil { @@ -164,7 +164,7 @@ func TestHandshake(t *testing.T) { { name: "acceptSecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -174,7 +174,7 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", u) - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPHeader: h, }) if err != nil { @@ -187,7 +187,7 @@ func TestHandshake(t *testing.T) { { name: "acceptInsecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ InsecureSkipVerify: true, }) if err != nil { @@ -199,7 +199,7 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", "https://example.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPHeader: h, }) if err != nil { @@ -219,7 +219,7 @@ func TestHandshake(t *testing.T) { if cookie.Value != "myvalue" { return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) } - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -245,7 +245,7 @@ func TestHandshake(t *testing.T) { hc := &http.Client{ Jar: jar, } - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ HTTPClient: hc, }) if err != nil { @@ -801,7 +801,7 @@ func TestConn(t *testing.T) { }, client: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil || strings.Contains(err.Error(), "opcode") { + if err == nil || !strings.Contains(err.Error(), "opcode") { return xerrors.Errorf("expected error that contains opcode: %+v", err) } return nil @@ -839,7 +839,7 @@ func TestConn(t *testing.T) { tls := rand.Intn(2) == 1 s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -854,7 +854,7 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - opts := websocket.DialOptions{} + opts := &websocket.DialOptions{} if tls { opts.HTTPClient = s.Client() } @@ -920,7 +920,7 @@ func TestAutobahnServer(t *testing.T) { } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) if err != nil { @@ -1120,7 +1120,7 @@ func TestAutobahnClient(t *testing.T) { var cases int func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) if err != nil { t.Fatal(err) } @@ -1147,7 +1147,7 @@ func TestAutobahnClient(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second*45) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) if err != nil { t.Fatal(err) } @@ -1155,7 +1155,7 @@ func TestAutobahnClient(t *testing.T) { }() } - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) if err != nil { t.Fatal(err) } @@ -1207,7 +1207,7 @@ func checkWSTestIndex(t *testing.T, path string) { func benchConn(b *testing.B, echo, stream bool, size int) { s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + c, err := websocket.Accept(w, r, nil) if err != nil { return err } @@ -1225,7 +1225,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) + c, _, err := websocket.Dial(ctx, wsURL, nil) if err != nil { b.Fatal(err) } -- GitLab