diff --git a/accept.go b/accept.go index 428abba4f1ec96a3977fa79bb323f34fb74fdb5b..6e1f494e32c9027f926d63fa8079ad01b4d713fb 100644 --- a/accept.go +++ b/accept.go @@ -163,13 +163,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } - if !headerContainsToken(r.Header, "Connection", "Upgrade") { + if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } - if !headerContainsToken(r.Header, "Upgrade", "websocket") { + if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) @@ -313,11 +313,9 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com return copts, nil } -func headerContainsToken(h http.Header, key, token string) bool { - token = strings.ToLower(token) - +func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { - if t == token { + if strings.EqualFold(t, token) { return true } } @@ -358,7 +356,6 @@ func headerTokens(h http.Header, key string) []string { for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { - t = strings.ToLower(t) t = strings.TrimSpace(t) tokens = append(tokens, t) } diff --git a/accept_test.go b/accept_test.go index f7bc669356a7f301974d896be5c64959ab51ba26..d19f54e15c230208420b3e083b4943155a4bca6f 100644 --- a/accept_test.go +++ b/accept_test.go @@ -226,6 +226,12 @@ func Test_selectSubprotocol(t *testing.T) { serverProtocols: []string{"echo2", "echo3"}, negotiated: "echo3", }, + { + name: "clientCasePresered", + clientProtocols: []string{"Echo1"}, + serverProtocols: []string{"echo1"}, + negotiated: "Echo1", + }, } for _, tc := range testCases { diff --git a/dial.go b/dial.go index d5d2266e3b887da3214e80d60e80efb08c58b7dc..7c959bff04b105c1c7ad0be0b2a3671d3ad01533 100644 --- a/dial.go +++ b/dial.go @@ -8,7 +8,6 @@ import ( "context" "crypto/rand" "encoding/base64" - "errors" "fmt" "io" "io/ioutil" @@ -47,18 +46,27 @@ type DialOptions struct { CompressionThreshold int } -func (opts *DialOptions) cloneWithDefaults() *DialOptions { +func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { + var cancel context.CancelFunc + var o DialOptions if opts != nil { o = *opts } if o.HTTPClient == nil { o.HTTPClient = http.DefaultClient + } else if opts.HTTPClient.Timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + + newClient := *opts.HTTPClient + newClient.Timeout = 0 + opts.HTTPClient = &newClient } if o.HTTPHeader == nil { o.HTTPHeader = http.Header{} } - return &o + + return ctx, cancel, &o } // Dial performs a WebSocket handshake on url. @@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") - opts = opts.cloneWithDefaults() + var cancel context.CancelFunc + ctx, cancel, opts = opts.cloneWithDefaults(ctx) + if cancel != nil { + defer cancel() + } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { @@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { - if opts.HTTPClient.Timeout > 0 { - return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) @@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } - if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } - if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } diff --git a/dial_test.go b/dial_test.go index 7f13a93413bbdd36e9305da6a5b067d8dab6643c..28c255c652d71eb9002243f2aa9617357d1c69e8 100644 --- a/dial_test.go +++ b/dial_test.go @@ -36,15 +36,6 @@ func TestBadDials(t *testing.T) { name: "badURLScheme", url: "ftp://nhooyr.io", }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, - }, - }, { name: "badTLS", url: "wss://totallyfake.nhooyr.io", diff --git a/examples/echo/server.go b/examples/echo/server.go index 308c4a5e6482941266a6ef166fc7e487ef6cb6f2..e9f70f03a15b8cb5f99733c20de3bf5ed3f9ae65 100644 --- a/examples/echo/server.go +++ b/examples/echo/server.go @@ -16,7 +16,6 @@ import ( // It ensures the client speaks the echo subprotocol and // only allows one message every 100ms with a 10 message burst. type echoServer struct { - // logf controls where logs are sent. logf func(f string, v ...interface{}) }