package websocket

import (
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"

	"golang.org/x/net/http/httpguts"
	"golang.org/x/xerrors"
)

// DialOption represents a dial option that can be passed to Dial.
//
// Its only method is unexported, so only types declared in this package can
// implement it; use the Dial* constructors to obtain values.
type DialOption interface {
	dialOption()
}

// dialHTTPClient is the DialOption produced by DialHTTPClient. It is a
// distinct named type over http.Client so Dial can pick it out of the option
// list with a type switch.
type dialHTTPClient http.Client

func (dialHTTPClient) dialOption() {}

Anmol Sethi's avatar
Anmol Sethi committed
// DialHTTPClient is the http client used for the handshake.
// Its Transport must use HTTP/1.1 and must return writable bodies
// for WebSocket handshakes.
// http.Transport does this correctly.
Anmol Sethi's avatar
Anmol Sethi committed
func DialHTTPClient(hc *http.Client) DialOption {
	return (*dialHTTPClient)(hc)
Anmol Sethi's avatar
Anmol Sethi committed
}

// dialHeader is the DialOption produced by DialHeader. It carries extra HTTP
// headers for the handshake request.
type dialHeader http.Header

func (dialHeader) dialOption() {}

Anmol Sethi's avatar
Anmol Sethi committed
// DialHeader are the HTTP headers included in the handshake request.
func DialHeader(h http.Header) DialOption {
Anmol Sethi's avatar
Anmol Sethi committed
	return dialHeader(h)
Anmol Sethi's avatar
Anmol Sethi committed
}

// dialSubprotocols is the DialOption produced by DialSubprotocols. It lists
// the subprotocols offered during the handshake.
type dialSubprotocols []string

func (dialSubprotocols) dialOption() {}

Anmol Sethi's avatar
Anmol Sethi committed
// DialSubprotocols accepts a slice of protcols to include in the Sec-WebSocket-Protocol header.
func DialSubprotocols(subprotocols ...string) DialOption {
Anmol Sethi's avatar
Anmol Sethi committed
	return dialSubprotocols(subprotocols)
Anmol Sethi's avatar
Anmol Sethi committed
}

// secWebSocketKey is the Sec-WebSocket-Key header sent with every client
// handshake: the base64 encoding of 16 zero bytes.
//
// We use the same key for all client requests because the Sec-WebSocket-Key
// header is useless for security. See https://stackoverflow.com/a/37074398/4283659.
// Dial correspondingly does not verify Sec-WebSocket-Accept.
// NOTE(review): RFC 6455 section 4.1 says the key MUST be a randomly selected
// nonce; a fixed key is a deliberate deviation from the spec.
// We also use the same mask key for every message as it too does not make a difference.
var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16))

Anmol Sethi's avatar
Anmol Sethi committed
// Dial performs a WebSocket handshake on the given url with the given options.
Anmol Sethi's avatar
Anmol Sethi committed
func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.Response, err error) {
	httpClient := http.DefaultClient
	var subprotocols []string
	header := http.Header{}
	for _, o := range opts {
		switch o := o.(type) {
		case dialSubprotocols:
			subprotocols = o
		case dialHeader:
			header = http.Header(o)
		case *dialHTTPClient:
			httpClient = (*http.Client)(o)
		}
	}

	parsedURL, err := url.Parse(u)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, nil, xerrors.Errorf("failed to parse websocket url: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}

	switch parsedURL.Scheme {
	case "ws", "http":
		parsedURL.Scheme = "http"
	case "wss", "https":
		parsedURL.Scheme = "https"
	default:
		return nil, nil, xerrors.Errorf("unknown scheme in url: %q", parsedURL.Scheme)
	}

	req, _ := http.NewRequest("GET", u, nil)
	req = req.WithContext(ctx)
	req.Header = header
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")
	req.Header.Set("Sec-WebSocket-Version", "13")
	req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
	if len(subprotocols) > 0 {
		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(subprotocols, ","))
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, nil, err
	}
	defer func() {
		respBody := resp.Body
		if err != nil {
			// We read a bit of the body for better debugging.
			r := io.LimitReader(resp.Body, 1024)
			b, _ := ioutil.ReadAll(r)
			resp.Body = ioutil.NopCloser(bytes.NewReader(b))
Anmol Sethi's avatar
Anmol Sethi committed
			respBody.Close()
Anmol Sethi's avatar
Anmol Sethi committed
		}
	}()

	if resp.StatusCode != http.StatusSwitchingProtocols {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, resp, xerrors.Errorf("websocket: expected status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
Anmol Sethi's avatar
Anmol Sethi committed
	}

	if !httpguts.HeaderValuesContainsToken(resp.Header["Connection"], "Upgrade") {
		return nil, resp, xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", resp.Header.Get("Connection"))
	}

	if !httpguts.HeaderValuesContainsToken(resp.Header["Upgrade"], "websocket") {
		return nil, resp, xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", resp.Header.Get("Upgrade"))

	}

	// We do not care about Sec-WebSocket-Accept because it does not matter.
	// See the secWebSocketKey global variable.

	rwc, ok := resp.Body.(io.ReadWriteCloser)
	if !ok {
		return nil, resp, xerrors.Errorf("websocket: body is not a read write closer but should be: %T", rwc)
	}

Anmol Sethi's avatar
Anmol Sethi committed
	c := &Conn{
		subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
		br:          bufio.NewReader(rwc),
		bw:          bufio.NewWriter(rwc),
		closer:      rwc,
		client:      true,
	}
	c.init()

	return c, resp, nil
Anmol Sethi's avatar
Anmol Sethi committed
}