good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit aaf4b458 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Up test coverage of accept.go to 100%

parent 8c87970e
No related merge requests found
...@@ -92,7 +92,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -92,7 +92,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
w.Header().Set("Sec-WebSocket-Protocol", subproto) w.Header().Set("Sec-WebSocket-Protocol", subproto)
} }
copts, err := acceptCompression(r, w, opts.CompressionMode) copts, err := acceptCompression(r, w, opts.CompressionOptions.Mode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -201,7 +201,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi ...@@ -201,7 +201,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
case "server_no_context_takeover": case "server_no_context_takeover":
copts.serverNoContextTakeover = true copts.serverNoContextTakeover = true
continue continue
case "client_max_window_bits", "server-max-window-bits": }
if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") {
continue continue
} }
......
...@@ -3,6 +3,10 @@ ...@@ -3,6 +3,10 @@
package websocket package websocket
import ( import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
...@@ -23,6 +27,38 @@ func TestAccept(t *testing.T) { ...@@ -23,6 +27,38 @@ func TestAccept(t *testing.T) {
assert.ErrorContains(t, "Accept", err, "protocol violation") assert.ErrorContains(t, "Accept", err, "protocol violation")
}) })
t.Run("badOrigin", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Origin", "harhar.com")
_, err := Accept(w, r, nil)
assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host")
})
t.Run("badCompression", func(t *testing.T) {
t.Parallel()
w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
_, err := Accept(w, r, nil)
assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter")
})
t.Run("requireHttpHijacker", func(t *testing.T) { t.Run("requireHttpHijacker", func(t *testing.T) {
t.Parallel() t.Parallel()
...@@ -36,6 +72,26 @@ func TestAccept(t *testing.T) { ...@@ -36,6 +72,26 @@ func TestAccept(t *testing.T) {
_, err := Accept(w, r, nil) _, err := Accept(w, r, nil)
assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker") assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker")
}) })
t.Run("badHijack", func(t *testing.T) {
t.Parallel()
w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
_, err := Accept(w, r, nil)
assert.ErrorContains(t, "Accept", err, "failed to hijack connection")
})
} }
func Test_verifyClientHandshake(t *testing.T) { func Test_verifyClientHandshake(t *testing.T) {
...@@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) { ...@@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) {
} }
func Test_acceptCompression(t *testing.T) { func Test_acceptCompression(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
mode CompressionMode
reqSecWebSocketExtensions string
respSecWebSocketExtensions string
expCopts *compressionOptions
error bool
}{
{
name: "disabled",
mode: CompressionDisabled,
expCopts: nil,
},
{
name: "noClientSupport",
mode: CompressionNoContextTakeover,
expCopts: nil,
},
{
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
},
{
name: "permessage-deflate/error",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; meow",
error: true,
},
{
name: "x-webkit-deflate-frame",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
},
{
name: "x-webkit-deflate/error",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
error: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
w := httptest.NewRecorder()
copts, err := acceptCompression(r, w, tc.mode)
if tc.error {
assert.Error(t, "acceptCompression", err)
return
}
assert.Success(t, "acceptCompression", err)
assert.Equal(t, "compresssionOpts", tc.expCopts, copts)
assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
})
}
}
type mockHijacker struct {
http.ResponseWriter
hijack func() (net.Conn, *bufio.ReadWriter, error)
}
var _ http.Hijacker = mockHijacker{}
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return mj.hijack()
} }
...@@ -9,15 +9,22 @@ import ( ...@@ -9,15 +9,22 @@ import (
"sync" "sync"
) )
// CompressionOptions represents the available deflate extension options.
// See https://tools.ietf.org/html/rfc7692
type CompressionOptions struct { type CompressionOptions struct {
// Mode controls the compression mode. // Mode controls the compression mode.
//
// See docs on CompressionMode.
Mode CompressionMode Mode CompressionMode
// Threshold controls the minimum size of a message before compression is applied. // Threshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 256 bytes
// for CompressionContextTakeover.
Threshold int Threshold int
} }
// CompressionMode controls the modes available RFC 7692's deflate extension. // CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692 // See https://tools.ietf.org/html/rfc7692
// //
// A compatibility layer is implemented for the older deflate-frame extension used // A compatibility layer is implemented for the older deflate-frame extension used
...@@ -31,7 +38,7 @@ const ( ...@@ -31,7 +38,7 @@ const (
// for every message. This applies to both server and client side. // for every message. This applies to both server and client side.
// //
// This means less efficient compression as the sliding window from previous messages // This means less efficient compression as the sliding window from previous messages
// will not be used but the memory overhead will be much lower if the connections // will not be used but the memory overhead will be lower if the connections
// are long lived and seldom used. // are long lived and seldom used.
// //
// The message will only be compressed if greater than 512 bytes. // The message will only be compressed if greater than 512 bytes.
...@@ -40,8 +47,7 @@ const ( ...@@ -40,8 +47,7 @@ const (
// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
// This enables reusing the sliding window from previous messages. // This enables reusing the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient. // As most WebSocket protocols are repetitive, this can be very efficient.
// // It carries an overhead of 64 kB for every connection compared to CompressionNoContextTakeover.
// The message will only be compressed if greater than 128 bytes.
// //
// If the peer negotiates NoContextTakeover on the client or server side, it will be // If the peer negotiates NoContextTakeover on the client or server side, it will be
// used instead as this is required by the RFC. // used instead as this is required by the RFC.
......
...@@ -26,7 +26,9 @@ func TestConn(t *testing.T) { ...@@ -26,7 +26,9 @@ func TestConn(t *testing.T) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"echo"}, Subprotocols: []string{"echo"},
InsecureSkipVerify: true, InsecureSkipVerify: true,
CompressionMode: websocket.CompressionNoContextTakeover, CompressionOptions: websocket.CompressionOptions{
Mode: websocket.CompressionNoContextTakeover,
},
}) })
assert.Success(t, "accept", err) assert.Success(t, "accept", err)
defer c.Close(websocket.StatusInternalError, "") defer c.Close(websocket.StatusInternalError, "")
...@@ -42,8 +44,10 @@ func TestConn(t *testing.T) { ...@@ -42,8 +44,10 @@ func TestConn(t *testing.T) {
defer cancel() defer cancel()
opts := &websocket.DialOptions{ opts := &websocket.DialOptions{
Subprotocols: []string{"echo"}, Subprotocols: []string{"echo"},
CompressionMode: websocket.CompressionNoContextTakeover, CompressionOptions: websocket.CompressionOptions{
Mode: websocket.CompressionNoContextTakeover,
},
} }
opts.HTTPClient = s.Client() opts.HTTPClient = s.Client()
......
...@@ -136,8 +136,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe ...@@ -136,8 +136,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
if len(opts.Subprotocols) > 0 { if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
} }
if opts.CompressionMode != CompressionDisabled { if opts.CompressionOptions.Mode != CompressionDisabled {
copts := opts.CompressionMode.opts() copts := opts.CompressionOptions.Mode.opts()
copts.setHeader(req.Header) copts.setHeader(req.Header)
} }
......
...@@ -64,7 +64,7 @@ func newMsgWriter(c *Conn) *msgWriter { ...@@ -64,7 +64,7 @@ func newMsgWriter(c *Conn) *msgWriter {
func (mw *msgWriter) ensureFlateWriter() { func (mw *msgWriter) ensureFlateWriter() {
if mw.flateWriter == nil { if mw.flateWriter == nil {
mw.flateWriter = getFlateWriter(mw.trimWriter) mw.flateWriter = getFlateWriter(mw.trimWriter, nil)
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment