diff --git a/accept.go b/accept.go index 428abba4f1ec96a3977fa79bb323f34fb74fdb5b..6e1f494e32c9027f926d63fa8079ad01b4d713fb 100644 --- a/accept.go +++ b/accept.go @@ -163,13 +163,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } - if !headerContainsToken(r.Header, "Connection", "Upgrade") { + if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } - if !headerContainsToken(r.Header, "Upgrade", "websocket") { + if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) @@ -313,11 +313,9 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com return copts, nil } -func headerContainsToken(h http.Header, key, token string) bool { - token = strings.ToLower(token) - +func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { for _, t := range headerTokens(h, key) { - if t == token { + if strings.EqualFold(t, token) { return true } } @@ -358,7 +356,6 @@ func headerTokens(h http.Header, key string) []string { for _, v := range h[key] { v = strings.TrimSpace(v) for _, t := range strings.Split(v, ",") { - t = strings.ToLower(t) t = strings.TrimSpace(t) tokens = append(tokens, t) } diff --git a/accept_test.go b/accept_test.go index f7bc669356a7f301974d896be5c64959ab51ba26..d19f54e15c230208420b3e083b4943155a4bca6f 100644 --- a/accept_test.go +++ b/accept_test.go @@ -226,6 +226,12 @@ func Test_selectSubprotocol(t *testing.T) { serverProtocols: []string{"echo2", "echo3"}, negotiated: "echo3", }, + { + name: "clientCasePresered", + clientProtocols: []string{"Echo1"}, + serverProtocols: []string{"echo1"}, + negotiated: "Echo1", + }, } for _, tc := range testCases { diff --git a/ci/test.sh b/ci/test.sh index 95ef710172ea8fab3ecf2a0b64a088c6d0ef3789..bd68b80eb08e235358b60166695da6b9f562efff 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -5,9 +5,9 @@ main() { cd "$(dirname "$0")/.." go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... "$@" ./... - sed -i '/stringer\.go/d' ci/out/coverage.prof - sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof - sed -i '/examples/d' ci/out/coverage.prof + sed -i.bak '/stringer\.go/d' ci/out/coverage.prof + sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof + sed -i.bak '/examples/d' ci/out/coverage.prof # Last line is the total coverage. go tool cover -func ci/out/coverage.prof | tail -n1 diff --git a/dial.go b/dial.go index d5d2266e3b887da3214e80d60e80efb08c58b7dc..a79b55e6f8c963e07c7131a0c103f18d7fc18e88 100644 --- a/dial.go +++ b/dial.go @@ -8,7 +8,6 @@ import ( "context" "crypto/rand" "encoding/base64" - "errors" "fmt" "io" "io/ioutil" @@ -47,18 +46,27 @@ type DialOptions struct { CompressionThreshold int } -func (opts *DialOptions) cloneWithDefaults() *DialOptions { +func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { + var cancel context.CancelFunc + var o DialOptions if opts != nil { o = *opts } if o.HTTPClient == nil { o.HTTPClient = http.DefaultClient + } else if opts.HTTPClient.Timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + + newClient := *opts.HTTPClient + newClient.Timeout = 0 + opts.HTTPClient = &newClient } if o.HTTPHeader == nil { o.HTTPHeader = http.Header{} } - return &o + + return ctx, cancel, &o } // Dial performs a WebSocket handshake on url. @@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") - opts = opts.cloneWithDefaults() + var cancel context.CancelFunc + ctx, cancel, opts = opts.cloneWithDefaults(ctx) + if cancel != nil { + defer cancel() + } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { @@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( } func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { - if opts.HTTPClient.Timeout > 0 { - return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - u, err := url.Parse(urls) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) @@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } - if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } - if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } @@ -242,7 +250,8 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } - copts = &*copts + _copts := *copts + copts = &_copts for _, p := range ext.params { switch p { diff --git a/dial_test.go b/dial_test.go index 7f13a93413bbdd36e9305da6a5b067d8dab6643c..28c255c652d71eb9002243f2aa9617357d1c69e8 100644 --- a/dial_test.go +++ b/dial_test.go @@ -36,15 +36,6 @@ func TestBadDials(t *testing.T) { name: "badURLScheme", url: "ftp://nhooyr.io", }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, - }, - }, { name: "badTLS", url: "wss://totallyfake.nhooyr.io", diff --git a/examples/echo/server.go b/examples/echo/server.go index 308c4a5e6482941266a6ef166fc7e487ef6cb6f2..e9f70f03a15b8cb5f99733c20de3bf5ed3f9ae65 100644 --- a/examples/echo/server.go +++ b/examples/echo/server.go @@ -16,7 +16,6 @@ import ( // It ensures the client speaks the echo subprotocol and // only allows one message every 100ms with a 10 message burst. type echoServer struct { - // logf controls where logs are sent. logf func(f string, v ...interface{}) } diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 1534f3168153134fa1cb3d178cdb964174e95441..f3d4c517f4a673e3b9350b54af687361f7444e84 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -24,7 +24,8 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) if dialOpts == nil { dialOpts = &websocket.DialOptions{} } - dialOpts = &*dialOpts + _dialOpts := *dialOpts + dialOpts = &_dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, }