// dial_test.go

// +build !js

package websocket

import (
	"context"
	"crypto/rand"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"nhooyr.io/websocket/internal/test/assert"
)

func TestBadDials(t *testing.T) {
	t.Parallel()

Anmol Sethi's avatar
Anmol Sethi committed
	t.Run("badReq", func(t *testing.T) {
		t.Parallel()

		testCases := []struct {
			name string
			url  string
			opts *DialOptions
			rand readerFunc
		}{
			{
				name: "badURL",
				url:  "://noscheme",
			},
			{
				name: "badURLScheme",
				url:  "ftp://nhooyr.io",
			},
			{
				name: "badTLS",
				url:  "wss://totallyfake.nhooyr.io",
			},
			{
				name: "badReader",
				rand: func(p []byte) (int, error) {
					return 0, io.EOF
				},
			},
		}
Anmol Sethi's avatar
Anmol Sethi committed
		for _, tc := range testCases {
			tc := tc
			t.Run(tc.name, func(t *testing.T) {
				t.Parallel()
Anmol Sethi's avatar
Anmol Sethi committed
				ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
				defer cancel()
Anmol Sethi's avatar
Anmol Sethi committed
				if tc.rand == nil {
					tc.rand = rand.Reader.Read
				}

				_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
Anmol Sethi's avatar
Anmol Sethi committed
				assert.Error(t, err)
Anmol Sethi's avatar
Anmol Sethi committed
			})
		}
	})

	t.Run("badResponse", func(t *testing.T) {
		t.Parallel()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
		defer cancel()

		_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
			HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
				return &http.Response{
					Body: ioutil.NopCloser(strings.NewReader("hi")),
				}, nil
			}),
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
Anmol Sethi's avatar
Anmol Sethi committed
	})

	t.Run("badBody", func(t *testing.T) {
		t.Parallel()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
		defer cancel()

		rt := func(r *http.Request) (*http.Response, error) {
			h := http.Header{}
			h.Set("Connection", "Upgrade")
			h.Set("Upgrade", "websocket")
			h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))

			return &http.Response{
				StatusCode: http.StatusSwitchingProtocols,
				Header:     h,
				Body:       ioutil.NopCloser(strings.NewReader("hi")),
			}, nil
		}

		_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
			HTTPClient: mockHTTPClient(rt),
		})
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
Anmol Sethi's avatar
Anmol Sethi committed
	})
}

func Test_verifyServerHandshake(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name     string
		response func(w http.ResponseWriter)
		success  bool
	}{
		{
			name: "badStatus",
			response: func(w http.ResponseWriter) {
				w.WriteHeader(http.StatusOK)
			},
			success: false,
		},
		{
			name: "badConnection",
			response: func(w http.ResponseWriter) {
				w.Header().Set("Connection", "???")
				w.WriteHeader(http.StatusSwitchingProtocols)
			},
			success: false,
		},
		{
			name: "badUpgrade",
			response: func(w http.ResponseWriter) {
				w.Header().Set("Connection", "Upgrade")
				w.Header().Set("Upgrade", "???")
				w.WriteHeader(http.StatusSwitchingProtocols)
			},
			success: false,
		},
		{
			name: "badSecWebSocketAccept",
			response: func(w http.ResponseWriter) {
				w.Header().Set("Connection", "Upgrade")
				w.Header().Set("Upgrade", "websocket")
				w.Header().Set("Sec-WebSocket-Accept", "xd")
				w.WriteHeader(http.StatusSwitchingProtocols)
			},
			success: false,
		},
		{
			name: "badSecWebSocketProtocol",
			response: func(w http.ResponseWriter) {
				w.Header().Set("Connection", "Upgrade")
				w.Header().Set("Upgrade", "websocket")
				w.Header().Set("Sec-WebSocket-Protocol", "xd")
				w.WriteHeader(http.StatusSwitchingProtocols)
			},
			success: false,
		},
Anmol Sethi's avatar
Anmol Sethi committed
		{
			name: "unsupportedExtension",
			response: func(w http.ResponseWriter) {
				w.Header().Set("Connection", "Upgrade")
				w.Header().Set("Upgrade", "websocket")
				w.Header().Set("Sec-WebSocket-Extensions", "meow")
				w.WriteHeader(http.StatusSwitchingProtocols)
			},
			success: false,
		},
		{
			name: "unsupportedDeflateParam",
			response: func(w http.ResponseWriter) {
				w.Header().Set("Connection", "Upgrade")
				w.Header().Set("Upgrade", "websocket")
				w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
				w.WriteHeader(http.StatusSwitchingProtocols)
			},
			success: false,
		},
		{
			name: "success",
			response: func(w http.ResponseWriter) {
				w.Header().Set("Connection", "Upgrade")
				w.Header().Set("Upgrade", "websocket")
				w.WriteHeader(http.StatusSwitchingProtocols)
			},
			success: true,
		},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			w := httptest.NewRecorder()
			tc.response(w)
			resp := w.Result()

			r := httptest.NewRequest("GET", "/", nil)
Anmol Sethi's avatar
Anmol Sethi committed
			key, err := secWebSocketKey(rand.Reader)
Anmol Sethi's avatar
Anmol Sethi committed
			assert.Success(t, err)
			r.Header.Set("Sec-WebSocket-Key", key)

			if resp.Header.Get("Sec-WebSocket-Accept") == "" {
				resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
			}

Anmol Sethi's avatar
Anmol Sethi committed
			opts := &DialOptions{
				Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
			}
			_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
Anmol Sethi's avatar
Anmol Sethi committed
			if tc.success {
				assert.Success(t, err)
			} else {
				assert.Error(t, err)
Anmol Sethi's avatar
Anmol Sethi committed

// mockHTTPClient returns an *http.Client whose Transport is fn, so every
// request sent through the client is handled by fn.
func mockHTTPClient(fn roundTripperFunc) *http.Client {
	client := new(http.Client)
	client.Transport = fn
	return client
}

// roundTripperFunc adapts a plain function so it can be used as an
// http.RoundTripper.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls the underlying function with req and returns its results
// unchanged.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}