good morning!!!!

Skip to content
Snippets Groups Projects
dial.go 7.14 KiB
Newer Older
Anmol Sethi's avatar
Anmol Sethi committed
// +build !js

package websocket

import (
	"bufio"
	"bytes"
	"context"
	"crypto/rand"
	"encoding/base64"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"golang.org/x/xerrors"

	"nhooyr.io/websocket/internal/errd"
)

// DialOptions represents Dial's options.
type DialOptions struct {
Anmol Sethi's avatar
Anmol Sethi committed
	// HTTPClient is used for the connection.
	// Its Transport must return writable bodies for WebSocket handshakes.
	// http.Transport does beginning with Go 1.12.
	HTTPClient *http.Client

	// HTTPHeader specifies the HTTP headers included in the handshake request.
	HTTPHeader http.Header

Anmol Sethi's avatar
Anmol Sethi committed
	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
	Subprotocols []string

Anmol Sethi's avatar
Anmol Sethi committed
	// CompressionOptions controls the compression options.
	// See docs on the CompressionOptions type.
Anmol Sethi's avatar
Anmol Sethi committed
	CompressionOptions *CompressionOptions
Anmol Sethi's avatar
Anmol Sethi committed
// Dial performs a WebSocket handshake on url.
//
// The response is the WebSocket handshake response from the server.
Anmol Sethi's avatar
Anmol Sethi committed
// You never need to close resp.Body yourself.
Anmol Sethi's avatar
Anmol Sethi committed
// If an error occurs, the returned response may be non nil.
// However, you can only read the first 1024 bytes of the body.
Anmol Sethi's avatar
Anmol Sethi committed
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	return dial(ctx, u, opts, nil)
Anmol Sethi's avatar
Anmol Sethi committed
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
Anmol Sethi's avatar
Anmol Sethi committed
	defer errd.Wrap(&err, "failed to WebSocket dial")

	if opts == nil {
		opts = &DialOptions{}
	}
Anmol Sethi's avatar
Anmol Sethi committed
	opts = &*opts
	if opts.HTTPClient == nil {
		opts.HTTPClient = http.DefaultClient
	}
	if opts.HTTPHeader == nil {
		opts.HTTPHeader = http.Header{}
	}
Anmol Sethi's avatar
Anmol Sethi committed
	if opts.CompressionOptions == nil {
		opts.CompressionOptions = &CompressionOptions{}
	}
Anmol Sethi's avatar
Anmol Sethi committed
	secWebSocketKey, err := secWebSocketKey(rand)
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, nil, xerrors.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, resp, err
Anmol Sethi's avatar
Anmol Sethi committed
	respBody := resp.Body
	resp.Body = nil
	defer func() {
		if err != nil {
			// We read a bit of the body for easier debugging.
Anmol Sethi's avatar
Anmol Sethi committed
			r := io.LimitReader(respBody, 1024)
			b, _ := ioutil.ReadAll(r)
Anmol Sethi's avatar
Anmol Sethi committed
			respBody.Close()
			resp.Body = ioutil.NopCloser(bytes.NewReader(b))
		}
	}()

Anmol Sethi's avatar
Anmol Sethi committed
	copts, err := verifyServerResponse(opts, secWebSocketKey, resp)
	if err != nil {
		return nil, resp, err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	rwc, ok := respBody.(io.ReadWriteCloser)
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
Anmol Sethi's avatar
Anmol Sethi committed
	return newConn(connConfig{
Anmol Sethi's avatar
Anmol Sethi committed
		subprotocol:    resp.Header.Get("Sec-WebSocket-Protocol"),
		rwc:            rwc,
		client:         true,
		copts:          copts,
		flateThreshold: opts.CompressionOptions.Threshold,
		br:             getBufioReader(rwc),
		bw:             getBufioWriter(rwc),
Anmol Sethi's avatar
Anmol Sethi committed
	}), resp, nil
Anmol Sethi's avatar
Anmol Sethi committed
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
	if opts.HTTPClient.Timeout > 0 {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
Anmol Sethi's avatar
Anmol Sethi committed
	}

	u, err := url.Parse(urls)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("failed to parse url: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}

	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	default:
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("unexpected url scheme: %q", u.Scheme)
Anmol Sethi's avatar
Anmol Sethi committed
	}

	req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
	req.Header = opts.HTTPHeader.Clone()
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")
	req.Header.Set("Sec-WebSocket-Version", "13")
	req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
	if len(opts.Subprotocols) > 0 {
		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
	}
	if opts.CompressionOptions.Mode != CompressionDisabled {
		copts := opts.CompressionOptions.Mode.opts()
Anmol Sethi's avatar
Anmol Sethi committed
		copts.setHeader(req.Header)
	}

	resp, err := opts.HTTPClient.Do(req)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("failed to send handshake request: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	return resp, nil
}

Anmol Sethi's avatar
Anmol Sethi committed
func secWebSocketKey(rr io.Reader) (string, error) {
	if rr == nil {
		rr = rand.Reader
	}
	b := make([]byte, 16)
Anmol Sethi's avatar
Anmol Sethi committed
	_, err := io.ReadFull(rr, b)
Anmol Sethi's avatar
Anmol Sethi committed
		return "", xerrors.Errorf("failed to read random data from rand.Reader: %w", err)
	}
	return base64.StdEncoding.EncodeToString(b), nil
}

Anmol Sethi's avatar
Anmol Sethi committed
func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
	if resp.StatusCode != http.StatusSwitchingProtocols {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
	}

	if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
	}

	if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
Anmol Sethi's avatar
Anmol Sethi committed
	if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
			resp.Header.Get("Sec-WebSocket-Accept"),
Anmol Sethi's avatar
Anmol Sethi committed
			secWebSocketKey,
Anmol Sethi's avatar
Anmol Sethi committed
	err := verifySubprotocol(opts.Subprotocols, resp)
	if err != nil {
		return nil, err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	return verifyServerExtensions(resp.Header)
}

func verifySubprotocol(subprotos []string, resp *http.Response) error {
	proto := resp.Header.Get("Sec-WebSocket-Protocol")
	if proto == "" {
		return nil
	}

	for _, sp2 := range subprotos {
		if strings.EqualFold(sp2, proto) {
			return nil
		}
	}

Anmol Sethi's avatar
Anmol Sethi committed
	return xerrors.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
Anmol Sethi's avatar
Anmol Sethi committed
func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
	exts := websocketExtensions(h)
	if len(exts) == 0 {
		return nil, nil
	}

	ext := exts[0]
Anmol Sethi's avatar
Anmol Sethi committed
	if ext.name != "permessage-deflate" || len(exts) > 1 {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, xerrors.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
Anmol Sethi's avatar
Anmol Sethi committed
	copts := &compressionOptions{}
	for _, p := range ext.params {
		switch p {
		case "client_no_context_takeover":
			copts.clientNoContextTakeover = true
		case "server_no_context_takeover":
			copts.serverNoContextTakeover = true
Anmol Sethi's avatar
Anmol Sethi committed
		default:
Anmol Sethi's avatar
Anmol Sethi committed
			return nil, xerrors.Errorf("unsupported permessage-deflate parameter: %q", p)
Anmol Sethi's avatar
Anmol Sethi committed

// readerPool recycles the *bufio.Reader values handed out by getBufioReader.
var readerPool sync.Pool

// getBufioReader returns a *bufio.Reader that reads from r. A pooled reader
// is reused (after Reset onto r) when one is available; otherwise a new
// reader with the default buffer size is allocated.
func getBufioReader(r io.Reader) *bufio.Reader {
	if br, ok := readerPool.Get().(*bufio.Reader); ok {
		br.Reset(r)
		return br
	}
	return bufio.NewReader(r)
}

// putBufioReader returns br to the pool. The caller must not use br
// afterwards, since a later getBufioReader may hand it out again.
func putBufioReader(br *bufio.Reader) {
	readerPool.Put(br)
}

// writerPool recycles the *bufio.Writer values handed out by getBufioWriter.
var writerPool sync.Pool

// getBufioWriter returns a *bufio.Writer that buffers writes to w. A pooled
// writer is reused (after Reset onto w) when one is available; otherwise a
// new writer with the default buffer size is allocated.
func getBufioWriter(w io.Writer) *bufio.Writer {
	if bw, ok := writerPool.Get().(*bufio.Writer); ok {
		bw.Reset(w)
		return bw
	}
	return bufio.NewWriter(w)
}

// putBufioWriter returns bw to the pool without flushing it. Unflushed data
// is discarded when the writer is next Reset by getBufioWriter.
func putBufioWriter(bw *bufio.Writer) {
	writerPool.Put(bw)
}