good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 53c1aea0 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Implement compression extension negotiation

parent e142e08c
No related branches found
No related tags found
No related merge requests found
...@@ -42,6 +42,7 @@ type Conn struct { ...@@ -42,6 +42,7 @@ type Conn struct {
writeBuf []byte writeBuf []byte
closer io.Closer closer io.Closer
client bool client bool
copts *CompressionOptions
closeOnce sync.Once closeOnce sync.Once
closeErrOnce sync.Once closeErrOnce sync.Once
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
// - Accept and AcceptOptions // - Accept and AcceptOptions
// - Conn.Ping // - Conn.Ping
// - HTTPClient and HTTPHeader fields in DialOptions // - HTTPClient and HTTPHeader fields in DialOptions
// - CompressionOptions
// //
// The *http.Response returned by Dial will always either be nil or &http.Response{} as // The *http.Response returned by Dial will always either be nil or &http.Response{} as
// we do not have access to the handshake response in the browser. // we do not have access to the handshake response in the browser.
......
...@@ -59,13 +59,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { ...@@ -59,13 +59,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { if !headerContainsToken(r.Header, "Connection", "Upgrade") {
err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return err return err
} }
if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") { if !headerContainsToken(r.Header, "Upgrade", "WebSocket") {
err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return err return err
...@@ -144,6 +144,18 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, ...@@ -144,6 +144,18 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
w.Header().Set("Sec-WebSocket-Protocol", subproto) w.Header().Set("Sec-WebSocket-Protocol", subproto)
} }
var copts *CompressionOptions
if opts.Compression != nil {
copts, err = negotiateCompression(r.Header, opts.Compression)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
if copts != nil {
copts.setHeader(w.Header())
}
}
w.WriteHeader(http.StatusSwitchingProtocols) w.WriteHeader(http.StatusSwitchingProtocols)
netConn, brw, err := hj.Hijack() netConn, brw, err := hj.Hijack()
...@@ -162,17 +174,23 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, ...@@ -162,17 +174,23 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
br: brw.Reader, br: brw.Reader,
bw: brw.Writer, bw: brw.Writer,
closer: netConn, closer: netConn,
copts: copts,
} }
c.init() c.init()
return c, nil return c, nil
} }
func headerValuesContainsToken(h http.Header, key, token string) bool { func headerContainsToken(h http.Header, key, token string) bool {
key = textproto.CanonicalMIMEHeaderKey(key) key = textproto.CanonicalMIMEHeaderKey(key)
for _, val2 := range h[key] { token = strings.ToLower(token)
if headerValueContainsToken(val2, token) { match := func(t string) bool {
return t == token
}
for _, v := range h[key] {
if searchHeaderTokens(v, match) != "" {
return true return true
} }
} }
...@@ -180,22 +198,41 @@ func headerValuesContainsToken(h http.Header, key, token string) bool { ...@@ -180,22 +198,41 @@ func headerValuesContainsToken(h http.Header, key, token string) bool {
return false return false
} }
func headerValueContainsToken(val2, token string) bool { func headerTokenHasPrefix(h http.Header, key, prefix string) string {
val2 = strings.TrimSpace(val2) key = textproto.CanonicalMIMEHeaderKey(key)
for _, val2 := range strings.Split(val2, ",") { prefix = strings.ToLower(prefix)
val2 = strings.TrimSpace(val2) match := func(t string) bool {
if strings.EqualFold(val2, token) { return strings.HasPrefix(t, prefix)
return true }
for _, v := range h[key] {
found := searchHeaderTokens(v, match)
if found != "" {
return found
} }
} }
return false return ""
}
func searchHeaderTokens(v string, match func(val string) bool) string {
v = strings.TrimSpace(v)
for _, v2 := range strings.Split(v, ",") {
v2 = strings.TrimSpace(v2)
v2 = strings.ToLower(v2)
if match(v2) {
return v2
}
}
return ""
} }
func selectSubprotocol(r *http.Request, subprotocols []string) string { func selectSubprotocol(r *http.Request, subprotocols []string) string {
for _, sp := range subprotocols { for _, sp := range subprotocols {
if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) {
return sp return sp
} }
} }
...@@ -268,36 +305,32 @@ type DialOptions struct { ...@@ -268,36 +305,32 @@ type DialOptions struct {
// //
// See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. // See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression.
// //
// Enabling compression will increase memory and CPU usage. // Enabling compression will increase memory and CPU usage and should
// Thus it is not ideal for every use case and disabled by default. // be profiled before enabling in production.
// See https://github.com/gorilla/websocket/issues/203 // See https://github.com/gorilla/websocket/issues/203
// Profile before enabling in production.
// //
// This API is experimental and subject to change. // This API is experimental and subject to change.
type CompressionOptions struct { type CompressionOptions struct {
// ServerNoContextTakeover controls whether the server should use context takeover.
// See docs on CompressionOptions for discussion regarding context takeover.
//
// If set by the client, will guarantee that the server does not use context takeover.
ServerNoContextTakeover bool
// ClientNoContextTakeover controls whether the client should use context takeover. // ClientNoContextTakeover controls whether the client should use context takeover.
// See docs on CompressionOptions for discussion regarding context takeover. // See docs on CompressionOptions for discussion regarding context takeover.
// //
// If set by the server, will guarantee that the client does not use context takeover. // If set by the server, will guarantee that the client does not use context takeover.
ClientNoContextTakeover bool ClientNoContextTakeover bool
// ServerNoContextTakeover controls whether the server should use context takeover.
// See docs on CompressionOptions for discussion regarding context takeover.
//
// If set by the client, will guarantee that the server does not use context takeover.
ServerNoContextTakeover bool
// Level controls the compression level used. // Level controls the compression level used.
// Defaults to flate.BestSpeed. // Defaults to flate.BestSpeed.
Level int Level int
// Threshold controls the minimum message size in bytes before compression is used. // Threshold controls the minimum message size in bytes before compression is used.
// In the case of ContextTakeover == false, a flate.Writer will not be grabbed
// from the pool until the message exceeds this threshold.
//
// Must not be greater than 4096 as that is the write buffer's size. // Must not be greater than 4096 as that is the write buffer's size.
// //
// Defaults to 512. // Defaults to 256.
Threshold int Threshold int
} }
...@@ -319,25 +352,32 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon ...@@ -319,25 +352,32 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
return c, r, nil return c, r, nil
} }
func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { func (opts *DialOptions) ensure() (*DialOptions, error) {
if opts == nil { if opts == nil {
opts = &DialOptions{} opts = &DialOptions{}
} else {
opts = &*opts
} }
// Shallow copy to ensure defaults do not affect user passed options.
opts2 := *opts
opts = &opts2
if opts.HTTPClient == nil { if opts.HTTPClient == nil {
opts.HTTPClient = http.DefaultClient opts.HTTPClient = http.DefaultClient
} }
if opts.HTTPClient.Timeout > 0 { if opts.HTTPClient.Timeout > 0 {
return nil, nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
} }
if opts.HTTPHeader == nil { if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{} opts.HTTPHeader = http.Header{}
} }
return opts, nil
}
func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) {
opts, err = opts.ensure()
if err != nil {
return nil, nil, err
}
parsedURL, err := url.Parse(u) parsedURL, err := url.Parse(u)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to parse url: %w", err) return nil, nil, fmt.Errorf("failed to parse url: %w", err)
...@@ -367,7 +407,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re ...@@ -367,7 +407,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
} }
if opts.Compression != nil { if opts.Compression != nil {
req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") opts.Compression.setHeader(req.Header)
} }
resp, err := opts.HTTPClient.Do(req) resp, err := opts.HTTPClient.Do(req)
...@@ -384,7 +424,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re ...@@ -384,7 +424,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
} }
}() }()
err = verifyServerResponse(req, resp) copts, err := verifyServerResponse(req, resp, opts)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
...@@ -400,6 +440,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re ...@@ -400,6 +440,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
bw: getBufioWriter(rwc), bw: getBufioWriter(rwc),
closer: rwc, closer: rwc,
client: true, client: true,
copts: copts,
} }
c.extractBufioWriterBuf(rwc) c.extractBufioWriterBuf(rwc)
c.init() c.init()
...@@ -407,31 +448,40 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re ...@@ -407,31 +448,40 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
return c, resp, nil return c, resp, nil
} }
func verifyServerResponse(r *http.Request, resp *http.Response) error { func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*CompressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols { if resp.StatusCode != http.StatusSwitchingProtocols {
return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
} }
if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
return fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
} }
if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
return fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
} }
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) {
return fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"), resp.Header.Get("Sec-WebSocket-Accept"),
r.Header.Get("Sec-WebSocket-Key"), r.Header.Get("Sec-WebSocket-Key"),
) )
} }
if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) {
return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
} }
return nil var copts *CompressionOptions
if opts.Compression != nil {
var err error
copts, err = negotiateCompression(resp.Header, opts.Compression)
if err != nil {
return nil, err
}
}
return copts, nil
} }
// The below pools can only be used by the client because http.Hijacker will always // The below pools can only be used by the client because http.Hijacker will always
...@@ -477,3 +527,55 @@ func makeSecWebSocketKey() (string, error) { ...@@ -477,3 +527,55 @@ func makeSecWebSocketKey() (string, error) {
} }
return base64.StdEncoding.EncodeToString(b), nil return base64.StdEncoding.EncodeToString(b), nil
} }
func negotiateCompression(h http.Header, copts *CompressionOptions) (*CompressionOptions, error) {
deflate := headerTokenHasPrefix(h, "Sec-WebSocket-Extensions", "permessage-deflate")
if deflate == "" {
return nil, nil
}
// Ensures our changes do not modify the real compression options.
copts = &*copts
params := strings.Split(deflate, ";")
for i := range params {
params[i] = strings.TrimSpace(params[i])
}
if params[0] != "permessage-deflate" {
return nil, fmt.Errorf("unexpected header format for permessage-deflate extension: %q", deflate)
}
for _, p := range params[1:] {
switch p {
case "client_no_context_takeover":
copts.ClientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.ServerNoContextTakeover = true
continue
case "client_max_window_bits", "server-max-window-bits":
server := h.Get("Sec-WebSocket-Key") != ""
if server {
// If we are the server, we are allowed to ignore these parameters.
// However, if we are the client, we must obey them but because of
// https://github.com/golang/go/issues/3155 we cannot.
continue
}
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter %q in header: %q", p, deflate)
}
return copts, nil
}
func (copts *CompressionOptions) setHeader(h http.Header) {
s := "permessage-deflate"
if copts.ClientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.ServerNoContextTakeover {
s += "; server_no_context_takeover"
}
h.Set("Sec-WebSocket-Extensions", s)
}
...@@ -377,7 +377,7 @@ func Test_verifyServerHandshake(t *testing.T) { ...@@ -377,7 +377,7 @@ func Test_verifyServerHandshake(t *testing.T) {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
} }
err = verifyServerResponse(r, resp) _, err = verifyServerResponse(r, resp, &DialOptions{})
if (err == nil) != tc.success { if (err == nil) != tc.success {
t.Fatalf("unexpected error: %+v", err) t.Fatalf("unexpected error: %+v", err)
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment