// +build !js

// Tests for the server-side WebSocket handshake: Accept and its helpers for
// request verification, subprotocol selection, origin checks and compression.
package websocket

import (
	"bufio"
	"errors"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"nhooyr.io/websocket/internal/test/assert"
)
func TestAccept(t *testing.T) {
	t.Parallel()

	t.Run("badClientHandshake", func(t *testing.T) {
		t.Parallel()

		w := httptest.NewRecorder()
		r := httptest.NewRequest("GET", "/", nil)

		_, err := Accept(w, r, nil)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, "protocol violation")
	t.Run("badOrigin", func(t *testing.T) {
		t.Parallel()

		w := httptest.NewRecorder()
		r := httptest.NewRequest("GET", "/", nil)
		r.Header.Set("Connection", "Upgrade")
		r.Header.Set("Upgrade", "websocket")
		r.Header.Set("Sec-WebSocket-Version", "13")
		r.Header.Set("Sec-WebSocket-Key", "meow123")
		r.Header.Set("Origin", "harhar.com")

		_, err := Accept(w, r, nil)
		assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`)
	})

	// #247
	t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) {
		t.Parallel()

		w := httptest.NewRecorder()
		r := httptest.NewRequest("GET", "/", nil)
		r.Header.Set("Connection", "Upgrade")
		r.Header.Set("Upgrade", "websocket")
		r.Header.Set("Sec-WebSocket-Version", "13")
		r.Header.Set("Sec-WebSocket-Key", "meow123")
		r.Header.Set("Origin", "https://harhar.com")

		_, err := Accept(w, r, nil)
		assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`)
	})

	t.Run("badCompression", func(t *testing.T) {
		t.Parallel()

		w := mockHijacker{
			ResponseWriter: httptest.NewRecorder(),
		}
		r := httptest.NewRequest("GET", "/", nil)
		r.Header.Set("Connection", "Upgrade")
		r.Header.Set("Upgrade", "websocket")
		r.Header.Set("Sec-WebSocket-Version", "13")
		r.Header.Set("Sec-WebSocket-Key", "meow123")
		r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")

		_, err := Accept(w, r, &AcceptOptions{
			CompressionMode: CompressionContextTakeover,
		})
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, `unsupported permessage-deflate parameter`)
	t.Run("requireHttpHijacker", func(t *testing.T) {
		t.Parallel()

		w := httptest.NewRecorder()
		r := httptest.NewRequest("GET", "/", nil)
		r.Header.Set("Connection", "Upgrade")
		r.Header.Set("Upgrade", "websocket")
		r.Header.Set("Sec-WebSocket-Version", "13")
		r.Header.Set("Sec-WebSocket-Key", "meow123")

		_, err := Accept(w, r, nil)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)

	t.Run("badHijack", func(t *testing.T) {
		t.Parallel()

		w := mockHijacker{
			ResponseWriter: httptest.NewRecorder(),
			hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
				return nil, nil, errors.New("haha")
			},
		}

		r := httptest.NewRequest("GET", "/", nil)
		r.Header.Set("Connection", "Upgrade")
		r.Header.Set("Upgrade", "websocket")
		r.Header.Set("Sec-WebSocket-Version", "13")
		r.Header.Set("Sec-WebSocket-Key", "meow123")

		_, err := Accept(w, r, nil)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, `failed to hijack connection`)
Anmol Sethi's avatar
Anmol Sethi committed
func Test_verifyClientHandshake(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name    string
		method  string
		http1   bool
Anmol Sethi's avatar
Anmol Sethi committed
		h       map[string]string
		success bool
	}{
		{
			name: "badConnection",
			h: map[string]string{
				"Connection": "notUpgrade",
			},
		},
		{
			name: "badUpgrade",
			h: map[string]string{
				"Connection": "Upgrade",
				"Upgrade":    "notWebSocket",
			},
		},
		{
			name:   "badMethod",
			method: "POST",
			h: map[string]string{
				"Connection": "Upgrade",
				"Upgrade":    "websocket",
			},
		},
		{
			name: "badWebSocketVersion",
			h: map[string]string{
				"Connection":            "Upgrade",
				"Upgrade":               "websocket",
				"Sec-WebSocket-Version": "14",
			},
		},
		{
			name: "badWebSocketKey",
			h: map[string]string{
				"Connection":            "Upgrade",
				"Upgrade":               "websocket",
				"Sec-WebSocket-Version": "13",
				"Sec-WebSocket-Key":     "",
			},
		},
		{
			name: "badHTTPVersion",
			h: map[string]string{
				"Connection":            "Upgrade",
				"Upgrade":               "websocket",
				"Sec-WebSocket-Version": "13",
				"Sec-WebSocket-Key":     "meow123",
			},
			http1: true,
		},
Anmol Sethi's avatar
Anmol Sethi committed
		{
			name: "success",
			h: map[string]string{
Anmol Sethi's avatar
Anmol Sethi committed
				"Connection":            "keep-alive, Upgrade",
Anmol Sethi's avatar
Anmol Sethi committed
				"Upgrade":               "websocket",
				"Sec-WebSocket-Version": "13",
				"Sec-WebSocket-Key":     "meow123",
			},
			success: true,
		},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			r := httptest.NewRequest(tc.method, "/", nil)

			r.ProtoMajor = 1
			r.ProtoMinor = 1
			if tc.http1 {
				r.ProtoMinor = 0
			}

Anmol Sethi's avatar
Anmol Sethi committed
			for k, v := range tc.h {
				r.Header.Set(k, v)
			}

			_, err := verifyClientRequest(httptest.NewRecorder(), r)
Anmol Sethi's avatar
Anmol Sethi committed
			if tc.success {
				assert.Success(t, err)
			} else {
				assert.Error(t, err)
Anmol Sethi's avatar
Anmol Sethi committed
			}
		})
	}
}

func Test_selectSubprotocol(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name            string
		clientProtocols []string
		serverProtocols []string
		negotiated      string
	}{
		{
			name:            "empty",
			clientProtocols: nil,
			serverProtocols: nil,
			negotiated:      "",
		},
		{
			name:            "basic",
			clientProtocols: []string{"echo", "echo2"},
			serverProtocols: []string{"echo2", "echo"},
			negotiated:      "echo2",
		},
		{
			name:            "none",
			clientProtocols: []string{"echo", "echo3"},
			serverProtocols: []string{"echo2", "echo4"},
			negotiated:      "",
		},
		{
			name:            "fallback",
			clientProtocols: []string{"echo", "echo3"},
			serverProtocols: []string{"echo2", "echo3"},
			negotiated:      "echo3",
		},
		{
			name:            "clientCasePresered",
			clientProtocols: []string{"Echo1"},
			serverProtocols: []string{"echo1"},
			negotiated:      "Echo1",
		},
Anmol Sethi's avatar
Anmol Sethi committed
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			r := httptest.NewRequest("GET", "/", nil)
			r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ","))

			negotiated := selectSubprotocol(r, tc.serverProtocols)
Anmol Sethi's avatar
Anmol Sethi committed
			assert.Equal(t, "negotiated", tc.negotiated, negotiated)
Anmol Sethi's avatar
Anmol Sethi committed
		})
	}
}

func Test_authenticateOrigin(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name           string
		origin         string
		host           string
		originPatterns []string
		success        bool
Anmol Sethi's avatar
Anmol Sethi committed
	}{
		{
			name:    "none",
			success: true,
			host:    "example.com",
Anmol Sethi's avatar
Anmol Sethi committed
		},
		{
			name:    "invalid",
			origin:  "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}",
			host:    "example.com",
Anmol Sethi's avatar
Anmol Sethi committed
			success: false,
		},
		{
			name:    "unauthorized",
			origin:  "https://example.com",
			host:    "example1.com",
			success: false,
			name:    "authorized",
			origin:  "https://example.com",
			host:    "example.com",
			success: true,
			name:    "authorizedCaseInsensitive",
			origin:  "https://examplE.com",
			host:    "example.com",
			success: true,
		{
			name:   "originPatterns",
			origin: "https://two.examplE.com",
			host:   "example.com",
			originPatterns: []string{
				"*.example.com",
				"bar.com",
			},
			success: true,
		},
		{
			name:   "originPatternsUnauthorized",
			origin: "https://two.examplE.com",
			host:   "example.com",
			originPatterns: []string{
				"exam3.com",
				"bar.com",
			},
			success: false,
		},
Anmol Sethi's avatar
Anmol Sethi committed
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
Anmol Sethi's avatar
Anmol Sethi committed
			r.Header.Set("Origin", tc.origin)

			err := authenticateOrigin(r, tc.originPatterns)
Anmol Sethi's avatar
Anmol Sethi committed
			if tc.success {
				assert.Success(t, err)
			} else {
				assert.Error(t, err)
Anmol Sethi's avatar
Anmol Sethi committed

func Test_acceptCompression(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name                       string
		mode                       CompressionMode
		reqSecWebSocketExtensions  string
		respSecWebSocketExtensions string
		expCopts                   *compressionOptions
		error                      bool
	}{
		{
			name:     "disabled",
			mode:     CompressionDisabled,
			expCopts: nil,
		},
		{
			name:     "noClientSupport",
			mode:     CompressionNoContextTakeover,
			expCopts: nil,
		},
		{
			name:                       "permessage-deflate",
			mode:                       CompressionNoContextTakeover,
			reqSecWebSocketExtensions:  "permessage-deflate; client_max_window_bits",
			respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
			expCopts: &compressionOptions{
				clientNoContextTakeover: true,
				serverNoContextTakeover: true,
			},
		},
		{
			name:                      "permessage-deflate/error",
			mode:                      CompressionNoContextTakeover,
			reqSecWebSocketExtensions: "permessage-deflate; meow",
			error:                     true,
		},
Anmol Sethi's avatar
Anmol Sethi committed
		// {
		// 	name:                       "x-webkit-deflate-frame",
		// 	mode:                       CompressionNoContextTakeover,
		// 	reqSecWebSocketExtensions:  "x-webkit-deflate-frame; no_context_takeover",
		// 	respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
		// 	expCopts: &compressionOptions{
		// 		clientNoContextTakeover: true,
		// 		serverNoContextTakeover: true,
		// 	},
		// },
		// {
		// 	name:                      "x-webkit-deflate/error",
		// 	mode:                      CompressionNoContextTakeover,
		// 	reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
		// 	error:                     true,
		// },
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			r := httptest.NewRequest(http.MethodGet, "/", nil)
			r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)

			w := httptest.NewRecorder()
			copts, err := acceptCompression(r, w, tc.mode)
			if tc.error {
Anmol Sethi's avatar
Anmol Sethi committed
				assert.Error(t, err)
Anmol Sethi's avatar
Anmol Sethi committed
			assert.Success(t, err)
			assert.Equal(t, "compression options", tc.expCopts, copts)
			assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
		})
	}
}

// mockHijacker is an http.ResponseWriter test double that also implements
// http.Hijacker. Its Hijack method delegates to the configurable hijack func,
// so tests can choose the hijack outcome.
type mockHijacker struct {
	http.ResponseWriter
	// hijack is called by Hijack. When nil, Hijack returns an error.
	hijack func() (net.Conn, *bufio.ReadWriter, error)
}

var _ http.Hijacker = mockHijacker{}

// Hijack implements http.Hijacker by calling mj.hijack. If no hijack func was
// configured, it returns an error rather than panicking on a nil func call.
// A panic would abort the whole parallel test run instead of failing one test.
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if mj.hijack == nil {
		return nil, nil, errors.New("mockHijacker: hijack func not set")
	}
	return mj.hijack()
}