diff --git a/dial.go b/dial.go index 7a7787ff71fafb89321fadbb4f4b33ccc1fa93c6..0ae0d570c6ecf077ec6d4eaa94879db6ae1b9f13 100644 --- a/dial.go +++ b/dial.go @@ -157,7 +157,10 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } - req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to build HTTP request: %w", err) + } req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") diff --git a/dial_test.go b/dial_test.go index 28c255c652d71eb9002243f2aa9617357d1c69e8..80ba9a3d718cfff926520cc312d2a16d74bfdbba 100644 --- a/dial_test.go +++ b/dial_test.go @@ -23,10 +23,11 @@ func TestBadDials(t *testing.T) { t.Parallel() testCases := []struct { - name string - url string - opts *DialOptions - rand readerFunc + name string + url string + opts *DialOptions + rand readerFunc + nilCtx bool }{ { name: "badURL", @@ -46,6 +47,11 @@ func TestBadDials(t *testing.T) { return 0, io.EOF }, }, + { + name: "nilContext", + url: "http://localhost", + nilCtx: true, + }, } for _, tc := range testCases { @@ -53,8 +59,12 @@ func TestBadDials(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() + var ctx context.Context + var cancel func() + if !tc.nilCtx { + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + } if tc.rand == nil { tc.rand = rand.Reader.Read