Newer
Older
// DialOptions represents the options available to pass to Dial.
// The zero value is usable: nil fields fall back to defaults.
type DialOptions struct {
	// HTTPClient is the http client used for the handshake.
	// Its Transport must use HTTP/1.1 and must return writable bodies
	// for WebSocket handshakes. This was introduced in Go 1.12.
	// http.Transport does this correctly.
	//
	// If nil, http.DefaultClient is used.
	HTTPClient *http.Client

	// Header specifies the HTTP headers included in the handshake request.
	// If nil, no extra headers are sent. The dialer always sets Connection,
	// Upgrade, Sec-WebSocket-Version, Sec-WebSocket-Key and, when
	// Subprotocols is non-empty, Sec-WebSocket-Protocol itself; values given
	// here for those keys are overwritten.
	// TODO rename to HTTPHeader
	Header http.Header

	// Subprotocols lists the subprotocols to negotiate with the server.
	// They are sent comma-joined in a single Sec-WebSocket-Protocol header.
	Subprotocols []string
}
// secWebSocketKey is the Sec-WebSocket-Key value sent with every client
// handshake: sixteen zero bytes, base64-encoded.
//
// A fixed key is used on purpose. The key only lets a client confirm that the
// server understood the WebSocket handshake, which buys nothing in practice,
// so the server's Sec-WebSocket-Accept reply is not checked either.
// See https://stackoverflow.com/a/37074398/4283659.
// By the same reasoning, a single mask key is reused for every message.
//
// NOTE(review): RFC 6455 section 4.1 says this value MUST be a randomly
// selected nonce; revisit if strict peers or intermediaries reject it.
var secWebSocketKey = base64.StdEncoding.EncodeToString(new([16]byte)[:])
// Dial performs a WebSocket handshake on the given url with the given options.
//
// On success it returns the established connection and the server's
// handshake response. On failure the connection is nil, the error is wrapped
// with context, and whatever response the handshake produced (possibly nil)
// is still returned so callers can inspect it.
func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) {
	conn, resp, dialErr := dial(ctx, u, opts)
	if dialErr == nil {
		return conn, resp, nil
	}
	return nil, resp, xerrors.Errorf("failed to websocket dial: %w", dialErr)
}
func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) {
if opts.HTTPClient == nil {
opts.HTTPClient = http.DefaultClient
}
if opts.Header == nil {
opts.Header = http.Header{}
}
parsedURL, err := url.Parse(u)
if err != nil {
return nil, nil, xerrors.Errorf("failed to parse url: %w", err)
return nil, nil, xerrors.Errorf("unexpected url scheme scheme: %q", parsedURL.Scheme)
req, _ := http.NewRequest("GET", parsedURL.String(), nil)
req.Header = opts.Header
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
resp, err := opts.HTTPClient.Do(req)
return nil, nil, xerrors.Errorf("failed to send handshake request: %w", err)
}
defer func() {
respBody := resp.Body
if err != nil {
// We read a bit of the body for better debugging.
r := io.LimitReader(resp.Body, 1024)
b, _ := ioutil.ReadAll(r)
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
err = verifyServerResponse(resp)
if err != nil {
return nil, resp, err
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, resp, xerrors.Errorf("response body is not a read write closer: %T", rwc)
// TODO pool bufio
c := &Conn{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
br: bufio.NewReader(rwc),
bw: bufio.NewWriter(rwc),
closer: rwc,
client: true,
}
c.init()
return c, resp, nil
// verifyServerResponse checks that resp is an acceptable reply to a client
// WebSocket handshake. It returns an error for the first failed check, in
// this order:
//   - the status must be 101 Switching Protocols;
//   - the Connection header must contain the Upgrade token;
//   - the Upgrade header must contain the WebSocket token.
// Token matching is done by headerValuesContainsToken.
//
// Sec-WebSocket-Accept is deliberately not validated: the client always sends
// the fixed secWebSocketKey, so the accept value proves nothing.
func verifyServerResponse(resp *http.Response) error {
	switch {
	case resp.StatusCode != http.StatusSwitchingProtocols:
		return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
	case !headerValuesContainsToken(resp.Header, "Connection", "Upgrade"):
		return xerrors.Errorf("websocket protocol violation: Connection header does not contain Upgrade: %q", resp.Header.Get("Connection"))
	case !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket"):
		return xerrors.Errorf("websocket protocol violation: Upgrade header does not contain websocket: %q", resp.Header.Get("Upgrade"))
	default:
		return nil
	}
}