diff --git a/accept.go b/accept.go index 15e14285ba12f5b65debe471637ea953359914a9..19e388ec42c249accb7206206492357136978930 100644 --- a/accept.go +++ b/accept.go @@ -123,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) - if err != nil { - return nil, err + copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) + if ok { + w.Header().Set("Sec-WebSocket-Extensions", copts.String()) } w.WriteHeader(http.StatusSwitchingProtocols) @@ -238,25 +238,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { return "" } -func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { +func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { if mode == CompressionDisabled { - return nil, nil + return nil, false } - - for _, ext := range websocketExtensions(r.Header) { + for _, ext := range extensions { switch ext.name { // We used to implement x-webkit-deflate-fram too but Safari has bugs. // See https://github.com/nhooyr/websocket/issues/218 case "permessage-deflate": - return acceptDeflate(w, ext, mode) + copts, ok := acceptDeflate(ext, mode) + if ok { + return copts, true + } } } - return nil, nil + return nil, false } -func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { +func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { copts := mode.opts() - for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -265,24 +266,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - case "server_max_window_bits=15": + case "client_max_window_bits", + "server_max_window_bits=15": continue } - if strings.HasPrefix(p, "client_max_window_bits") { - // We cannot adjust the read sliding window so cannot make use of this. - // By not responding to it, we tell the client we're ignoring it. + if strings.HasPrefix(p, "client_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } - - err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err + return nil, false } - - copts.setHeader(w.Header()) - - return copts, nil + return copts, true } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { diff --git a/accept_test.go b/accept_test.go index ae17c0b41c999d207a40e4647b3005bf8d3e45f3..513313ecca07d60264fe0090b292ab53cd98d608 100644 --- a/accept_test.go +++ b/accept_test.go @@ -62,20 +62,50 @@ func TestAccept(t *testing.T) { t.Run("badCompression", func(t *testing.T) { t.Parallel() - w := mockHijacker{ - ResponseWriter: httptest.NewRecorder(), + newRequest := func(extensions string) *http.Request { + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Extensions", extensions) + return r + } + errHijack := errors.New("hijack error") + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errHijack + }, + } } - r := httptest.NewRequest("GET", "/", nil) - r.Header.Set("Connection", "Upgrade") - r.Header.Set("Upgrade", "websocket") - r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") - r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") - _, err := Accept(w, r, &AcceptOptions{ - CompressionMode: CompressionContextTakeover, + t.Run("withoutFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "") + }) + t.Run("withFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar, permessage-deflate") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", + w.Header().Get("Sec-WebSocket-Extensions"), + CompressionNoContextTakeover.opts().String(), + ) }) - assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -344,42 +374,53 @@ func Test_authenticateOrigin(t *testing.T) { } } -func Test_acceptCompression(t *testing.T) { +func Test_selectDeflate(t *testing.T) { t.Parallel() testCases := []struct { - name string - mode CompressionMode - reqSecWebSocketExtensions string - respSecWebSocketExtensions string - expCopts *compressionOptions - error bool + name string + mode CompressionMode + header string + expCopts *compressionOptions + expOK bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, + expOK: false, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, + expOK: false, }, { - name: "permessage-deflate", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", - respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; client_max_window_bits", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, + expOK: true, + }, + { + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow", + expOK: false, }, { - name: "permessage-deflate/error", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; meow", - error: true, + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + expOK: true, }, // { // name: "x-webkit-deflate-frame", @@ -404,19 +445,11 @@ func Test_acceptCompression(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) - - w := httptest.NewRecorder() - copts, err := acceptCompression(r, w, tc.mode) - if tc.error { - assert.Error(t, err) - return - } - - assert.Success(t, err) + h := http.Header{} + h.Set("Sec-WebSocket-Extensions", tc.header) + copts, ok := selectDeflate(websocketExtensions(h), tc.mode) + assert.Equal(t, "selected options", tc.expOK, ok) assert.Equal(t, "compression options", tc.expCopts, copts) - assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } diff --git a/compress.go b/compress.go index 61e6e2681e59ab14876858be3ec7502fd3edffd4..ee21e1d121552b10bc1fcfe10e9c5de79913b8b1 100644 --- a/compress.go +++ b/compress.go @@ -6,7 +6,6 @@ package websocket import ( "compress/flate" "io" - "net/http" "sync" ) @@ -65,7 +64,7 @@ type compressionOptions struct { serverNoContextTakeover bool } -func (copts *compressionOptions) setHeader(h http.Header) { +func (copts *compressionOptions) String() string { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" @@ -73,7 +72,7 @@ func (copts *compressionOptions) setHeader(h http.Header) { if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } - h.Set("Sec-WebSocket-Extensions", s) + return s } // These bytes are required to get flate.Reader to return. diff --git a/dial.go b/dial.go index 9acca1336c58b32632cb7647a9ec370762191dbb..e72432e778f1e5997480b81460e8d15ffa5a2c77 100644 --- a/dial.go +++ b/dial.go @@ -185,7 +185,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { - copts.setHeader(req.Header) + req.Header.Set("Sec-WebSocket-Extensions", copts.String()) } resp, err := opts.HTTPClient.Do(req) diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go index 64c938c5b5ef56339e4f27585cb4374b2192ab23..1b90cc9fd0156a309c4d56a906bc40088e05f969 100644 --- a/internal/test/assert/assert.go +++ b/internal/test/assert/assert.go @@ -1,6 +1,7 @@ package assert import ( + "errors" "fmt" "reflect" "strings" @@ -43,3 +44,12 @@ func Contains(t testing.TB, v interface{}, sub string) { t.Fatalf("expected %q to contain %q", s, sub) } } + +// ErrorIs asserts errors.Is(got, exp) +func ErrorIs(t testing.TB, exp, got error) { + t.Helper() + + if !errors.Is(got, exp) { + t.Fatalf("expected %v but got %v", exp, got) + } +}