From aaf4b458c6a66df98da8375425cb54ec47e9540b Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sat, 25 Jan 2020 20:58:09 -0600 Subject: [PATCH] Up test coverage of accept.go to 100% --- accept.go | 6 ++- accept_test.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++++ compress.go | 14 +++-- conn_test.go | 10 ++-- dial.go | 4 +- write.go | 2 +- 6 files changed, 164 insertions(+), 12 deletions(-) diff --git a/accept.go b/accept.go index f030e4a..d9b4bf9 100644 --- a/accept.go +++ b/accept.go @@ -92,7 +92,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) + copts, err := acceptCompression(r, w, opts.CompressionOptions.Mode) if err != nil { return nil, err } @@ -201,7 +201,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - case "client_max_window_bits", "server-max-window-bits": + } + + if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") { continue } diff --git a/accept_test.go b/accept_test.go index 2a784d1..8a9e919 100644 --- a/accept_test.go +++ b/accept_test.go @@ -3,6 +3,10 @@ package websocket import ( + "bufio" + "errors" + "net" + "net/http" "net/http/httptest" "strings" "testing" @@ -23,6 +27,38 @@ func TestAccept(t *testing.T) { assert.ErrorContains(t, "Accept", err, "protocol violation") }) + t.Run("badOrigin", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Origin", "harhar.com") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host") + }) + + t.Run("badCompression", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + } + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter") + }) + t.Run("requireHttpHijacker", func(t *testing.T) { t.Parallel() @@ -36,6 +72,26 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker") }) + + t.Run("badHijack", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "failed to hijack connection") + }) } func Test_verifyClientHandshake(t *testing.T) { @@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) { } func Test_acceptCompression(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + mode CompressionMode + reqSecWebSocketExtensions string + respSecWebSocketExtensions string + expCopts *compressionOptions + error bool + }{ + { + name: "disabled", + mode: CompressionDisabled, + expCopts: nil, + }, + { + name: "noClientSupport", + mode: CompressionNoContextTakeover, + expCopts: nil, + }, + { + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", + respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + }, + { + name: "permessage-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; meow", + error: true, + }, + { + name: "x-webkit-deflate-frame", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + }, + { + name: "x-webkit-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", + error: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) + + w := httptest.NewRecorder() + copts, err := acceptCompression(r, w, tc.mode) + if tc.error { + assert.Error(t, "acceptCompression", err) + return + } + + assert.Success(t, "acceptCompression", err) + assert.Equal(t, "compresssionOpts", tc.expCopts, copts) + assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) + }) + } +} + +type mockHijacker struct { + http.ResponseWriter + hijack func() (net.Conn, *bufio.ReadWriter, error) +} + +var _ http.Hijacker = mockHijacker{} +func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return mj.hijack() } diff --git a/compress.go b/compress.go index 62cc9cd..fd2535c 100644 --- a/compress.go +++ b/compress.go @@ -9,15 +9,22 @@ import ( "sync" ) +// CompressionOptions represents the available deflate extension options. +// See https://tools.ietf.org/html/rfc7692 type CompressionOptions struct { // Mode controls the compression mode. + // + // See docs on CompressionMode. Mode CompressionMode // Threshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 256 bytes + // for CompressionContextTakeover. Threshold int } -// CompressionMode controls the modes available RFC 7692's deflate extension. +// CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // // A compatibility layer is implemented for the older deflate-frame extension used @@ -31,7 +38,7 @@ const ( // for every message. This applies to both server and client side. // // This means less efficient compression as the sliding window from previous messages - // will not be used but the memory overhead will be much lower if the connections + // will not be used but the memory overhead will be lower if the connections // are long lived and seldom used. // // The message will only be compressed if greater than 512 bytes. @@ -40,8 +47,7 @@ const ( // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // This enables reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. - // - // The message will only be compressed if greater than 128 bytes. + // It carries an overhead of 64 kB for every connection compared to CompressionNoContextTakeover. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. diff --git a/conn_test.go b/conn_test.go index 9b311a8..c8663b4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -26,7 +26,9 @@ func TestConn(t *testing.T) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, - CompressionMode: websocket.CompressionNoContextTakeover, + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionNoContextTakeover, + }, }) assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") @@ -42,8 +44,10 @@ func TestConn(t *testing.T) { defer cancel() opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - CompressionMode: websocket.CompressionNoContextTakeover, + Subprotocols: []string{"echo"}, + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionNoContextTakeover, + }, } opts.HTTPClient = s.Client() diff --git a/dial.go b/dial.go index 43408f2..af94501 100644 --- a/dial.go +++ b/dial.go @@ -136,8 +136,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - if opts.CompressionMode != CompressionDisabled { - copts := opts.CompressionMode.opts() + if opts.CompressionOptions.Mode != CompressionDisabled { + copts := opts.CompressionOptions.Mode.opts() copts.setHeader(req.Header) } diff --git a/write.go b/write.go index de20e04..33d20c1 100644 --- a/write.go +++ b/write.go @@ -64,7 +64,7 @@ func newMsgWriter(c *Conn) *msgWriter { func (mw *msgWriter) ensureFlateWriter() { if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) + mw.flateWriter = getFlateWriter(mw.trimWriter, nil) } } -- GitLab