diff --git a/conn.go b/conn.go index 26906c7907b8be425ffde089cfad9dd75749e98d..14d93cf6c9e84bef8821230d5206b69fdda38638 100644 --- a/conn.go +++ b/conn.go @@ -42,6 +42,7 @@ type Conn struct { writeBuf []byte closer io.Closer client bool + copts *CompressionOptions closeOnce sync.Once closeErrOnce sync.Once diff --git a/doc.go b/doc.go index b29d2cdd0cf7442735334323d78696ef687e5b39..804665fbda56b07ffb3225e8a48fe9e9a3852756 100644 --- a/doc.go +++ b/doc.go @@ -31,6 +31,7 @@ // - Accept and AcceptOptions // - Conn.Ping // - HTTPClient and HTTPHeader fields in DialOptions +// - CompressionOptions // // The *http.Response returned by Dial will always either be nil or &http.Response{} as // we do not have access to the handshake response in the browser. diff --git a/handshake.go b/handshake.go index 2cde6ae28553646f821a04ed70fa0a02e2c13e1c..787fee2cb704a0e9f897d50425281a68d7f2a529 100644 --- a/handshake.go +++ b/handshake.go @@ -59,13 +59,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { return err } - if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { + if !headerContainsToken(r.Header, "Connection", "Upgrade") { err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) http.Error(w, err.Error(), http.StatusBadRequest) return err } - if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") { + if !headerContainsToken(r.Header, "Upgrade", "WebSocket") { err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) http.Error(w, err.Error(), http.StatusBadRequest) return err @@ -144,6 +144,18 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, w.Header().Set("Sec-WebSocket-Protocol", subproto) } + var copts *CompressionOptions + if opts.Compression != nil { + copts, err = negotiateCompression(r.Header, opts.Compression) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + if copts != nil { + copts.setHeader(w.Header()) + } + } + w.WriteHeader(http.StatusSwitchingProtocols) netConn, brw, err := hj.Hijack() @@ -162,17 +174,23 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, br: brw.Reader, bw: brw.Writer, closer: netConn, + copts: copts, } c.init() return c, nil } -func headerValuesContainsToken(h http.Header, key, token string) bool { +func headerContainsToken(h http.Header, key, token string) bool { key = textproto.CanonicalMIMEHeaderKey(key) - for _, val2 := range h[key] { - if headerValueContainsToken(val2, token) { + token = strings.ToLower(token) + match := func(t string) bool { + return t == token + } + + for _, v := range h[key] { + if searchHeaderTokens(v, match) != "" { return true } } @@ -180,22 +198,41 @@ func headerValuesContainsToken(h http.Header, key, token string) bool { return false } -func headerValueContainsToken(val2, token string) bool { - val2 = strings.TrimSpace(val2) +func headerTokenHasPrefix(h http.Header, key, prefix string) string { + key = textproto.CanonicalMIMEHeaderKey(key) - for _, val2 := range strings.Split(val2, ",") { - val2 = strings.TrimSpace(val2) - if strings.EqualFold(val2, token) { - return true + prefix = strings.ToLower(prefix) + match := func(t string) bool { + return strings.HasPrefix(t, prefix) + } + + for _, v := range h[key] { + found := searchHeaderTokens(v, match) + if found != "" { + return found } } - return false + return "" +} + +func searchHeaderTokens(v string, match func(val string) bool) string { + v = strings.TrimSpace(v) + + for _, v2 := range strings.Split(v, ",") { + v2 = strings.TrimSpace(v2) + v2 = strings.ToLower(v2) + if match(v2) { + return v2 + } + } + + return "" } func selectSubprotocol(r *http.Request, subprotocols []string) string { for _, sp := range subprotocols { - if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { + if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { return sp } } @@ -268,36 +305,32 @@ type DialOptions struct { // // See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. // -// Enabling compression will increase memory and CPU usage. -// Thus it is not ideal for every use case and disabled by default. +// Enabling compression will increase memory and CPU usage and should +// be profiled before enabling in production. // See https://github.com/gorilla/websocket/issues/203 -// Profile before enabling in production. // // This API is experimental and subject to change. type CompressionOptions struct { - // ServerNoContextTakeover controls whether the server should use context takeover. - // See docs on CompressionOptions for discussion regarding context takeover. - // - // If set by the client, will guarantee that the server does not use context takeover. - ServerNoContextTakeover bool - // ClientNoContextTakeover controls whether the client should use context takeover. // See docs on CompressionOptions for discussion regarding context takeover. // // If set by the server, will guarantee that the client does not use context takeover. ClientNoContextTakeover bool + // ServerNoContextTakeover controls whether the server should use context takeover. + // See docs on CompressionOptions for discussion regarding context takeover. + // + // If set by the client, will guarantee that the server does not use context takeover. + ServerNoContextTakeover bool + // Level controls the compression level used. // Defaults to flate.BestSpeed. Level int // Threshold controls the minimum message size in bytes before compression is used. - // In the case of ContextTakeover == false, a flate.Writer will not be grabbed - // from the pool until the message exceeds this threshold. - // // Must not be greater than 4096 as that is the write buffer's size. // - // Defaults to 512. + // Defaults to 256. Threshold int } @@ -319,25 +352,32 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon return c, r, nil } -func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { +func (opts *DialOptions) ensure() (*DialOptions, error) { if opts == nil { opts = &DialOptions{} + } else { + opts = &*opts } - // Shallow copy to ensure defaults do not affect user passed options. - opts2 := *opts - opts = &opts2 - if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } if opts.HTTPClient.Timeout > 0 { - return nil, nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } + return opts, nil +} + +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + opts, err = opts.ensure() + if err != nil { + return nil, nil, err + } + parsedURL, err := url.Parse(u) if err != nil { return nil, nil, fmt.Errorf("failed to parse url: %w", err) @@ -367,7 +407,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if opts.Compression != nil { - req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + opts.Compression.setHeader(req.Header) } resp, err := opts.HTTPClient.Do(req) @@ -384,7 +424,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re } }() - err = verifyServerResponse(req, resp) + copts, err := verifyServerResponse(req, resp, opts) if err != nil { return nil, resp, err } @@ -400,6 +440,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re bw: getBufioWriter(rwc), closer: rwc, client: true, + copts: copts, } c.extractBufioWriterBuf(rwc) c.init() @@ -407,31 +448,40 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re return c, resp, nil } -func verifyServerResponse(r *http.Request, resp *http.Response) error { +func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*CompressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { - return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } - if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { - return fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } - if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { - return fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { - return fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), r.Header.Get("Sec-WebSocket-Key"), ) } - if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { + return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } - return nil + var copts *CompressionOptions + if opts.Compression != nil { + var err error + copts, err = negotiateCompression(resp.Header, opts.Compression) + if err != nil { + return nil, err + } + } + + return copts, nil } // The below pools can only be used by the client because http.Hijacker will always @@ -477,3 +527,55 @@ func makeSecWebSocketKey() (string, error) { } return base64.StdEncoding.EncodeToString(b), nil } + +func negotiateCompression(h http.Header, copts *CompressionOptions) (*CompressionOptions, error) { + deflate := headerTokenHasPrefix(h, "Sec-WebSocket-Extensions", "permessage-deflate") + if deflate == "" { + return nil, nil + } + + // Ensures our changes do not modify the real compression options. + copts = &*copts + + params := strings.Split(deflate, ";") + for i := range params { + params[i] = strings.TrimSpace(params[i]) + } + + if params[0] != "permessage-deflate" { + return nil, fmt.Errorf("unexpected header format for permessage-deflate extension: %q", deflate) + } + + for _, p := range params[1:] { + switch p { + case "client_no_context_takeover": + copts.ClientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.ServerNoContextTakeover = true + continue + case "client_max_window_bits", "server-max-window-bits": + server := h.Get("Sec-WebSocket-Key") != "" + if server { + // If we are the server, we are allowed to ignore these parameters. + // However, if we are the client, we must obey them but because of + // https://github.com/golang/go/issues/3155 we cannot. + continue + } + } + return nil, fmt.Errorf("unsupported permessage-deflate parameter %q in header: %q", p, deflate) + } + + return copts, nil +} + +func (copts *CompressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.ClientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.ServerNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} diff --git a/handshake_test.go b/handshake_test.go index cb09353f65d3928630aad627684b046f0b5ddd92..82f958e052cd1dbbd7e782678896511eeefeadcc 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -377,7 +377,7 @@ func Test_verifyServerHandshake(t *testing.T) { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } - err = verifyServerResponse(r, resp) + _, err = verifyServerResponse(r, resp, &DialOptions{}) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) }