good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 95bfb8f5 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Fix negotation of flate parameters

parent 94f9b715
Branches
Tags
No related merge requests found
...@@ -209,7 +209,6 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM ...@@ -209,7 +209,6 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts() copts := mode.opts()
copts.serverMaxWindowBits = 8
for _, p := range ext.params { for _, p := range ext.params {
switch p { switch p {
...@@ -222,26 +221,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi ...@@ -222,26 +221,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
} }
if strings.HasPrefix(p, "client_max_window_bits") { if strings.HasPrefix(p, "client_max_window_bits") {
continue // We cannot adjust the read sliding window so cannot make use of this.
// bits, ok := parseExtensionParameter(p, 15)
// if !ok || bits < 8 || bits > 16 {
// err := fmt.Errorf("invalid client_max_window_bits: %q", p)
// http.Error(w, err.Error(), http.StatusBadRequest)
// return nil, err
// }
// copts.clientMaxWindowBits = bits
// continue
}
if false && strings.HasPrefix(p, "server_max_window_bits") {
// We always send back 8 but make sure to validate.
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
continue continue
} }
...@@ -256,14 +236,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi ...@@ -256,14 +236,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
} }
// parseExtensionParameter parses the value in the extension parameter p. // parseExtensionParameter parses the value in the extension parameter p.
// It falls back to defaultVal if there is no value. func parseExtensionParameter(p string) (int, bool) {
// If defaultVal == 0, then ok == false if there is no value.
func parseExtensionParameter(p string, defaultVal int) (int, bool) {
ps := strings.Split(p, "=") ps := strings.Split(p, "=")
if len(ps) == 1 { if len(ps) == 1 {
if defaultVal > 0 {
return defaultVal, true
}
return 0, false return 0, false
} }
i, e := strconv.Atoi(strings.Trim(ps[1], `"`)) i, e := strconv.Atoi(strings.Trim(ps[1], `"`))
......
...@@ -327,7 +327,6 @@ func Test_acceptCompression(t *testing.T) { ...@@ -327,7 +327,6 @@ func Test_acceptCompression(t *testing.T) {
expCopts: &compressionOptions{ expCopts: &compressionOptions{
clientNoContextTakeover: true, clientNoContextTakeover: true,
serverNoContextTakeover: true, serverNoContextTakeover: true,
serverMaxWindowBits: 8,
}, },
}, },
{ {
......
...@@ -28,6 +28,7 @@ var excludedAutobahnCases = []string{ ...@@ -28,6 +28,7 @@ var excludedAutobahnCases = []string{
// We skip the tests related to requestMaxWindowBits as that is unimplemented due // We skip the tests related to requestMaxWindowBits as that is unimplemented due
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155 // to limitations in compress/flate. See https://github.com/golang/go/issues/3155
// Same with klauspost/compress which doesn't allow adjusting the sliding window size.
"13.3.*", "13.4.*", "13.5.*", "13.6.*", "13.3.*", "13.4.*", "13.5.*", "13.6.*",
} }
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
package websocket package websocket
import ( import (
"fmt"
"io" "io"
"net/http" "net/http"
"sync" "sync"
...@@ -20,10 +19,7 @@ func (m CompressionMode) opts() *compressionOptions { ...@@ -20,10 +19,7 @@ func (m CompressionMode) opts() *compressionOptions {
type compressionOptions struct { type compressionOptions struct {
clientNoContextTakeover bool clientNoContextTakeover bool
clientMaxWindowBits int
serverNoContextTakeover bool serverNoContextTakeover bool
serverMaxWindowBits int
} }
func (copts *compressionOptions) setHeader(h http.Header) { func (copts *compressionOptions) setHeader(h http.Header) {
...@@ -34,12 +30,6 @@ func (copts *compressionOptions) setHeader(h http.Header) { ...@@ -34,12 +30,6 @@ func (copts *compressionOptions) setHeader(h http.Header) {
if copts.serverNoContextTakeover { if copts.serverNoContextTakeover {
s += "; server_no_context_takeover" s += "; server_no_context_takeover"
} }
if false && copts.serverMaxWindowBits > 0 {
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
}
if false && copts.clientMaxWindowBits > 0 {
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
}
h.Set("Sec-WebSocket-Extensions", s) h.Set("Sec-WebSocket-Extensions", s)
} }
...@@ -147,6 +137,10 @@ func (sw *slidingWindow) init(n int) { ...@@ -147,6 +137,10 @@ func (sw *slidingWindow) init(n int) {
return return
} }
if n == 0 {
n = 32768
}
p := slidingWindowPool(n) p := slidingWindowPool(n)
buf, ok := p.Get().([]byte) buf, ok := p.Get().([]byte)
if ok { if ok {
......
...@@ -82,7 +82,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( ...@@ -82,7 +82,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
} }
resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey) var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
}
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
...@@ -104,7 +109,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( ...@@ -104,7 +109,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
} }
}() }()
copts, err := verifyServerResponse(opts, secWebSocketKey, resp) copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
...@@ -125,7 +130,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( ...@@ -125,7 +130,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}), resp, nil }), resp, nil
} }
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) { func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
if opts.HTTPClient.Timeout > 0 { if opts.HTTPClient.Timeout > 0 {
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
} }
...@@ -153,9 +158,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe ...@@ -153,9 +158,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
if len(opts.Subprotocols) > 0 { if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
} }
if opts.CompressionMode != CompressionDisabled { if copts != nil {
copts := opts.CompressionMode.opts()
copts.clientMaxWindowBits = 8
copts.setHeader(req.Header) copts.setHeader(req.Header)
} }
...@@ -178,7 +181,7 @@ func secWebSocketKey(rr io.Reader) (string, error) { ...@@ -178,7 +181,7 @@ func secWebSocketKey(rr io.Reader) (string, error) {
return base64.StdEncoding.EncodeToString(b), nil return base64.StdEncoding.EncodeToString(b), nil
} }
func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols { if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
} }
...@@ -203,7 +206,7 @@ func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http. ...@@ -203,7 +206,7 @@ func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.
return nil, err return nil, err
} }
return verifyServerExtensions(resp.Header) return verifyServerExtensions(copts, resp.Header)
} }
func verifySubprotocol(subprotos []string, resp *http.Response) error { func verifySubprotocol(subprotos []string, resp *http.Response) error {
...@@ -221,19 +224,19 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error { ...@@ -221,19 +224,19 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
} }
func verifyServerExtensions(h http.Header) (*compressionOptions, error) { func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h) exts := websocketExtensions(h)
if len(exts) == 0 { if len(exts) == 0 {
return nil, nil return nil, nil
} }
ext := exts[0] ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 { if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
} }
copts := &compressionOptions{} copts = &*copts
copts.clientMaxWindowBits = 8
for _, p := range ext.params { for _, p := range ext.params {
switch p { switch p {
case "client_no_context_takeover": case "client_no_context_takeover":
...@@ -244,24 +247,6 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) { ...@@ -244,24 +247,6 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
continue continue
} }
if false && strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid server_max_window_bits: %q", p)
}
copts.serverMaxWindowBits = bits
continue
}
if false && strings.HasPrefix(p, "client_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid client_max_window_bits: %q", p)
}
copts.clientMaxWindowBits = 8
continue
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
} }
......
...@@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) { ...@@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) {
opts := &DialOptions{ opts := &DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
} }
_, err = verifyServerResponse(opts, key, resp) _, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
if tc.success { if tc.success {
assert.Success(t, err) assert.Success(t, err)
} else { } else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment