diff --git a/dial.go b/dial.go index 2b25e3517d666f5740c4905c9790c8897dd06eb7..509882e0806dd8f4c3cfdd3b58fbb73e0e07be8b 100644 --- a/dial.go +++ b/dial.go @@ -8,7 +8,6 @@ import ( "context" "crypto/rand" "encoding/base64" - "errors" "fmt" "io" "io/ioutil" @@ -74,7 +73,17 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( opts = &*opts if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient + } else if opts.HTTPClient.Timeout > 0 { + var cancel context.CancelFunc + + ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + defer cancel() + + newClient := *opts.HTTPClient + newClient.Timeout = 0 + opts.HTTPClient = &newClient } + if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } @@ -133,10 +142,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { - if opts.HTTPClient.Timeout > 0 { - return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) diff --git a/dial_test.go b/dial_test.go index 7f13a93413bbdd36e9305da6a5b067d8dab6643c..28c255c652d71eb9002243f2aa9617357d1c69e8 100644 --- a/dial_test.go +++ b/dial_test.go @@ -36,15 +36,6 @@ func TestBadDials(t *testing.T) { name: "badURLScheme", url: "ftp://nhooyr.io", }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, - }, - }, { name: "badTLS", url: "wss://totallyfake.nhooyr.io",