From aaf4b458c6a66df98da8375425cb54ec47e9540b Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sat, 25 Jan 2020 20:58:09 -0600
Subject: [PATCH] Up test coverage of accept.go to 100%

---
 accept.go      |   6 ++-
 accept_test.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++++
 compress.go    |  14 +++--
 conn_test.go   |  10 ++--
 dial.go        |   4 +-
 write.go       |   2 +-
 6 files changed, 164 insertions(+), 12 deletions(-)

diff --git a/accept.go b/accept.go
index f030e4a..d9b4bf9 100644
--- a/accept.go
+++ b/accept.go
@@ -92,7 +92,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
 		w.Header().Set("Sec-WebSocket-Protocol", subproto)
 	}
 
-	copts, err := acceptCompression(r, w, opts.CompressionMode)
+	copts, err := acceptCompression(r, w, opts.CompressionOptions.Mode)
 	if err != nil {
 		return nil, err
 	}
@@ -201,7 +201,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
 		case "server_no_context_takeover":
 			copts.serverNoContextTakeover = true
 			continue
-		case "client_max_window_bits", "server-max-window-bits":
+		}
+
+		if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") {
 			continue
 		}
 
diff --git a/accept_test.go b/accept_test.go
index 2a784d1..8a9e919 100644
--- a/accept_test.go
+++ b/accept_test.go
@@ -3,6 +3,10 @@
 package websocket
 
 import (
+	"bufio"
+	"errors"
+	"net"
+	"net/http"
 	"net/http/httptest"
 	"strings"
 	"testing"
@@ -23,6 +27,38 @@ func TestAccept(t *testing.T) {
 		assert.ErrorContains(t, "Accept", err, "protocol violation")
 	})
 
+	t.Run("badOrigin", func(t *testing.T) {
+		t.Parallel()
+
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest("GET", "/", nil)
+		r.Header.Set("Connection", "Upgrade")
+		r.Header.Set("Upgrade", "websocket")
+		r.Header.Set("Sec-WebSocket-Version", "13")
+		r.Header.Set("Sec-WebSocket-Key", "meow123")
+		r.Header.Set("Origin", "harhar.com")
+
+		_, err := Accept(w, r, nil)
+		assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host")
+	})
+
+	t.Run("badCompression", func(t *testing.T) {
+		t.Parallel()
+
+		w := mockHijacker{
+			ResponseWriter: httptest.NewRecorder(),
+		}
+		r := httptest.NewRequest("GET", "/", nil)
+		r.Header.Set("Connection", "Upgrade")
+		r.Header.Set("Upgrade", "websocket")
+		r.Header.Set("Sec-WebSocket-Version", "13")
+		r.Header.Set("Sec-WebSocket-Key", "meow123")
+		r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
+
+		_, err := Accept(w, r, nil)
+		assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter")
+	})
+
 	t.Run("requireHttpHijacker", func(t *testing.T) {
 		t.Parallel()
 
@@ -36,6 +72,26 @@ func TestAccept(t *testing.T) {
 		_, err := Accept(w, r, nil)
 		assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker")
 	})
+
+	t.Run("badHijack", func(t *testing.T) {
+		t.Parallel()
+
+		w := mockHijacker{
+			ResponseWriter: httptest.NewRecorder(),
+			hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
+				return nil, nil, errors.New("haha")
+			},
+		}
+
+		r := httptest.NewRequest("GET", "/", nil)
+		r.Header.Set("Connection", "Upgrade")
+		r.Header.Set("Upgrade", "websocket")
+		r.Header.Set("Sec-WebSocket-Version", "13")
+		r.Header.Set("Sec-WebSocket-Key", "meow123")
+
+		_, err := Accept(w, r, nil)
+		assert.ErrorContains(t, "Accept", err, "failed to hijack connection")
+	})
 }
 
 func Test_verifyClientHandshake(t *testing.T) {
@@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) {
 }
 
 func Test_acceptCompression(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name                       string
+		mode                       CompressionMode
+		reqSecWebSocketExtensions  string
+		respSecWebSocketExtensions string
+		expCopts                   *compressionOptions
+		error                      bool
+	}{
+		{
+			name:     "disabled",
+			mode:     CompressionDisabled,
+			expCopts: nil,
+		},
+		{
+			name:     "noClientSupport",
+			mode:     CompressionNoContextTakeover,
+			expCopts: nil,
+		},
+		{
+			name:                       "permessage-deflate",
+			mode:                       CompressionNoContextTakeover,
+			reqSecWebSocketExtensions:  "permessage-deflate; client_max_window_bits",
+			respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
+			expCopts: &compressionOptions{
+				clientNoContextTakeover: true,
+				serverNoContextTakeover: true,
+			},
+		},
+		{
+			name:                      "permessage-deflate/error",
+			mode:                      CompressionNoContextTakeover,
+			reqSecWebSocketExtensions: "permessage-deflate; meow",
+			error:                     true,
+		},
+		{
+			name:                       "x-webkit-deflate-frame",
+			mode:                       CompressionNoContextTakeover,
+			reqSecWebSocketExtensions:  "x-webkit-deflate-frame; no_context_takeover",
+			respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
+			expCopts: &compressionOptions{
+				clientNoContextTakeover: true,
+				serverNoContextTakeover: true,
+			},
+		},
+		{
+			name:                      "x-webkit-deflate/error",
+			mode:                      CompressionNoContextTakeover,
+			reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
+			error:                     true,
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			r := httptest.NewRequest(http.MethodGet, "/", nil)
+			r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
+
+			w := httptest.NewRecorder()
+			copts, err := acceptCompression(r, w, tc.mode)
+			if tc.error {
+				assert.Error(t, "acceptCompression", err)
+				return
+			}
+
+			assert.Success(t, "acceptCompression", err)
+			assert.Equal(t, "compresssionOpts", tc.expCopts, copts)
+			assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
+		})
+	}
+}
+
+type mockHijacker struct {
+	http.ResponseWriter
+	hijack func() (net.Conn, *bufio.ReadWriter, error)
+}
+
+var _ http.Hijacker = mockHijacker{}
 
+func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	return mj.hijack()
 }
diff --git a/compress.go b/compress.go
index 62cc9cd..fd2535c 100644
--- a/compress.go
+++ b/compress.go
@@ -9,15 +9,22 @@ import (
 	"sync"
 )
 
+// CompressionOptions represents the available deflate extension options.
+// See https://tools.ietf.org/html/rfc7692
 type CompressionOptions struct {
 	// Mode controls the compression mode.
+	//
+	// See docs on CompressionMode.
 	Mode CompressionMode
 
 	// Threshold controls the minimum size of a message before compression is applied.
+	//
+	// Defaults to 512 bytes for CompressionNoContextTakeover and 256 bytes
+	// for CompressionContextTakeover.
 	Threshold int
 }
 
-// CompressionMode controls the modes available RFC 7692's deflate extension.
+// CompressionMode represents the modes available to the deflate extension.
 // See https://tools.ietf.org/html/rfc7692
 //
 // A compatibility layer is implemented for the older deflate-frame extension used
@@ -31,7 +38,7 @@ const (
 	// for every message. This applies to both server and client side.
 	//
 	// This means less efficient compression as the sliding window from previous messages
-	// will not be used but the memory overhead will be much lower if the connections
+	// will not be used but the memory overhead will be lower if the connections
 	// are long lived and seldom used.
 	//
 	// The message will only be compressed if greater than 512 bytes.
@@ -40,8 +47,7 @@ const (
 	// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
 	// This enables reusing the sliding window from previous messages.
 	// As most WebSocket protocols are repetitive, this can be very efficient.
-	//
-	// The message will only be compressed if greater than 128 bytes.
+	// It carries an overhead of 64 kB for every connection compared to CompressionNoContextTakeover.
 	//
 	// If the peer negotiates NoContextTakeover on the client or server side, it will be
 	// used instead as this is required by the RFC.
diff --git a/conn_test.go b/conn_test.go
index 9b311a8..c8663b4 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -26,7 +26,9 @@ func TestConn(t *testing.T) {
 			c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 				Subprotocols:       []string{"echo"},
 				InsecureSkipVerify: true,
-				CompressionMode:    websocket.CompressionNoContextTakeover,
+				CompressionOptions: websocket.CompressionOptions{
+					Mode: websocket.CompressionNoContextTakeover,
+				},
 			})
 			assert.Success(t, "accept", err)
 			defer c.Close(websocket.StatusInternalError, "")
@@ -42,8 +44,10 @@ func TestConn(t *testing.T) {
 		defer cancel()
 
 		opts := &websocket.DialOptions{
-			Subprotocols:    []string{"echo"},
-			CompressionMode: websocket.CompressionNoContextTakeover,
+			Subprotocols: []string{"echo"},
+			CompressionOptions: websocket.CompressionOptions{
+				Mode: websocket.CompressionNoContextTakeover,
+			},
 		}
 		opts.HTTPClient = s.Client()
 
diff --git a/dial.go b/dial.go
index 43408f2..af94501 100644
--- a/dial.go
+++ b/dial.go
@@ -136,8 +136,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
 	if len(opts.Subprotocols) > 0 {
 		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
 	}
-	if opts.CompressionMode != CompressionDisabled {
-		copts := opts.CompressionMode.opts()
+	if opts.CompressionOptions.Mode != CompressionDisabled {
+		copts := opts.CompressionOptions.Mode.opts()
 		copts.setHeader(req.Header)
 	}
 
diff --git a/write.go b/write.go
index de20e04..33d20c1 100644
--- a/write.go
+++ b/write.go
@@ -64,7 +64,7 @@ func newMsgWriter(c *Conn) *msgWriter {
 
 func (mw *msgWriter) ensureFlateWriter() {
 	if mw.flateWriter == nil {
-		mw.flateWriter = getFlateWriter(mw.trimWriter)
+		mw.flateWriter = getFlateWriter(mw.trimWriter, nil)
 	}
 }
 
-- 
GitLab