diff --git a/accept.go b/accept.go index f030e4aa12db7129463a1770d35305b4609ec3f1..d9b4bf909107e7e410ac46330cf0847b6ec0bfad 100644 --- a/accept.go +++ b/accept.go @@ -92,7 +92,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) + copts, err := acceptCompression(r, w, opts.CompressionOptions.Mode) if err != nil { return nil, err } @@ -201,7 +201,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - case "client_max_window_bits", "server-max-window-bits": + } + + if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") { continue } diff --git a/accept_test.go b/accept_test.go index 2a784d19203058195c2dde014f08b0671138bd7d..8a9e919841d575e33b3b0a11b6b430e3f5973894 100644 --- a/accept_test.go +++ b/accept_test.go @@ -3,6 +3,10 @@ package websocket import ( + "bufio" + "errors" + "net" + "net/http" "net/http/httptest" "strings" "testing" @@ -23,6 +27,38 @@ func TestAccept(t *testing.T) { assert.ErrorContains(t, "Accept", err, "protocol violation") }) + t.Run("badOrigin", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Origin", "harhar.com") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host") + }) + + t.Run("badCompression", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + } + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter") + }) + t.Run("requireHttpHijacker", func(t *testing.T) { t.Parallel() @@ -36,6 +72,26 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker") }) + + t.Run("badHijack", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "failed to hijack connection") + }) } func Test_verifyClientHandshake(t *testing.T) { @@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) { } func Test_acceptCompression(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + mode CompressionMode + reqSecWebSocketExtensions string + respSecWebSocketExtensions string + expCopts *compressionOptions + error bool + }{ + { + name: "disabled", + mode: CompressionDisabled, + expCopts: nil, + }, + { + name: "noClientSupport", + mode: CompressionNoContextTakeover, + expCopts: nil, + }, + { + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", + respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + }, + { + name: "permessage-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; meow", + error: true, + }, + { + name: "x-webkit-deflate-frame", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + }, + { + name: "x-webkit-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", + error: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) + + w := httptest.NewRecorder() + copts, err := acceptCompression(r, w, tc.mode) + if tc.error { + assert.Error(t, "acceptCompression", err) + return + } + + assert.Success(t, "acceptCompression", err) + assert.Equal(t, "compresssionOpts", tc.expCopts, copts) + assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) + }) + } +} + +type mockHijacker struct { + http.ResponseWriter + hijack func() (net.Conn, *bufio.ReadWriter, error) +} + +var _ http.Hijacker = mockHijacker{} +func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return mj.hijack() } diff --git a/compress.go b/compress.go index 62cc9cd3ebdeb66cae2fba3d3b40aede449447d4..fd2535cc9943a2c7bb58fc382fc0e9b67e4d1b5b 100644 --- a/compress.go +++ b/compress.go @@ -9,15 +9,22 @@ import ( "sync" ) +// CompressionOptions represents the available deflate extension options. +// See https://tools.ietf.org/html/rfc7692 type CompressionOptions struct { // Mode controls the compression mode. + // + // See docs on CompressionMode. Mode CompressionMode // Threshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 256 bytes + // for CompressionContextTakeover. Threshold int } -// CompressionMode controls the modes available RFC 7692's deflate extension. +// CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // // A compatibility layer is implemented for the older deflate-frame extension used @@ -31,7 +38,7 @@ const ( // for every message. This applies to both server and client side. // // This means less efficient compression as the sliding window from previous messages - // will not be used but the memory overhead will be much lower if the connections + // will not be used but the memory overhead will be lower if the connections // are long lived and seldom used. // // The message will only be compressed if greater than 512 bytes. @@ -40,8 +47,7 @@ const ( // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // This enables reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. - // - // The message will only be compressed if greater than 128 bytes. + // It carries an overhead of 64 kB for every connection compared to CompressionNoContextTakeover. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. diff --git a/conn_test.go b/conn_test.go index 9b311a87f3c23e1cd8cddeed8c3858e900aa7215..c8663b47d0cc2298a2a2f4c072376f639126ff90 100644 --- a/conn_test.go +++ b/conn_test.go @@ -26,7 +26,9 @@ func TestConn(t *testing.T) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, - CompressionMode: websocket.CompressionNoContextTakeover, + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionNoContextTakeover, + }, }) assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") @@ -42,8 +44,10 @@ func TestConn(t *testing.T) { defer cancel() opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - CompressionMode: websocket.CompressionNoContextTakeover, + Subprotocols: []string{"echo"}, + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionNoContextTakeover, + }, } opts.HTTPClient = s.Client() diff --git a/dial.go b/dial.go index 43408f20803fde097c11772a9e5ba8a05b2c0f10..af945011b4fbfbd4201d8a6f4a05d0635dcad840 100644 --- a/dial.go +++ b/dial.go @@ -136,8 +136,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - if opts.CompressionMode != CompressionDisabled { - copts := opts.CompressionMode.opts() + if opts.CompressionOptions.Mode != CompressionDisabled { + copts := opts.CompressionOptions.Mode.opts() copts.setHeader(req.Header) } diff --git a/write.go b/write.go index de20e041ce3d92763cf72df93de94c957468e183..33d20c1d7c98491c24093a5d26786cdc5e461c39 100644 --- a/write.go +++ b/write.go @@ -64,7 +64,7 @@ func newMsgWriter(c *Conn) *msgWriter { func (mw *msgWriter) ensureFlateWriter() { if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) + mw.flateWriter = getFlateWriter(mw.trimWriter, nil) } }