Newer
Older
package websocket
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
// HTTPClient is used for the connection.
// Its Transport must return writable bodies for WebSocket handshakes.
// http.Transport does so beginning with Go 1.12.
HTTPClient *http.Client
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
// CompressionOptions controls the compression options.
// See docs on the CompressionOptions type.
// The response is the WebSocket handshake response from the server.
// If an error occurs, the returned response may be non-nil.
// However, you can only read the first 1024 bytes of the body.
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
if opts == nil {
opts = &DialOptions{}
}
if opts.HTTPClient == nil {
opts.HTTPClient = http.DefaultClient
}
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
}
if opts.CompressionOptions == nil {
opts.CompressionOptions = &CompressionOptions{}
}
return nil, nil, xerrors.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
}
}()
copts, err := verifyServerResponse(opts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
rwc: rwc,
client: true,
copts: copts,
flateThreshold: opts.CompressionOptions.Threshold,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
if opts.HTTPClient.Timeout > 0 {
return nil, xerrors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
return nil, xerrors.Errorf("failed to parse url: %w", err)
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return nil, xerrors.Errorf("unexpected url scheme: %q", u.Scheme)
}
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if opts.CompressionOptions.Mode != CompressionDisabled {
copts := opts.CompressionOptions.Mode.opts()
copts.setHeader(req.Header)
}
resp, err := opts.HTTPClient.Do(req)
if err != nil {
return nil, xerrors.Errorf("failed to send handshake request: %w", err)
func secWebSocketKey(rr io.Reader) (string, error) {
if rr == nil {
rr = rand.Reader
}
return "", xerrors.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
return nil, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}
if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
return nil, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, xerrors.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
if err != nil {
return nil, err
}
return verifyServerExtensions(resp.Header)
}
func verifySubprotocol(subprotos []string, resp *http.Response) error {
proto := resp.Header.Get("Sec-WebSocket-Protocol")
if proto == "" {
return nil
}
for _, sp2 := range subprotos {
if strings.EqualFold(sp2, proto) {
return nil
}
}
return xerrors.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil, nil
}
ext := exts[0]
return nil, xerrors.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
return nil, xerrors.Errorf("unsupported permessage-deflate parameter: %q", p)
}
}
return copts, nil
}
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
// readerPool recycles *bufio.Reader values so that each connection does not
// have to allocate a fresh read buffer.
var readerPool sync.Pool

// getBufioReader returns a *bufio.Reader reading from r. A pooled reader is
// reused when available; otherwise a new one with the default size is
// allocated.
func getBufioReader(r io.Reader) *bufio.Reader {
	br, ok := readerPool.Get().(*bufio.Reader)
	if !ok {
		return bufio.NewReader(r)
	}
	// Reset drops anything still buffered from the previous owner and
	// points the reader at r.
	br.Reset(r)
	return br
}

// putBufioReader returns br to the pool; the caller must not use br
// afterwards.
//
// br is detached from its underlying reader before pooling so the pool does
// not keep the previous connection (and any unread buffered bytes) reachable
// until the reader is reused or the pool is cleared. A nil br is ignored:
// pooling a typed nil would make a later getBufioReader panic in Reset.
func putBufioReader(br *bufio.Reader) {
	if br == nil {
		return
	}
	br.Reset(nil)
	readerPool.Put(br)
}
// writerPool recycles *bufio.Writer values so that each connection does not
// have to allocate a fresh write buffer.
var writerPool sync.Pool

// getBufioWriter returns a *bufio.Writer writing to w. A pooled writer is
// reused when available; otherwise a new one with the default size is
// allocated.
func getBufioWriter(w io.Writer) *bufio.Writer {
	bw, ok := writerPool.Get().(*bufio.Writer)
	if !ok {
		return bufio.NewWriter(w)
	}
	// Reset clears leftover state from the previous owner and points the
	// writer at w.
	bw.Reset(w)
	return bw
}

// putBufioWriter returns bw to the pool; the caller must not use bw
// afterwards. Any data not yet flushed is discarded, so callers must Flush
// first if it needs to reach the underlying writer.
//
// bw is detached from its underlying writer before pooling so the pool does
// not keep the previous connection reachable until the writer is reused or
// the pool is cleared. A nil bw is ignored: pooling a typed nil would make a
// later getBufioWriter panic in Reset.
func putBufioWriter(bw *bufio.Writer) {
	if bw == nil {
		return
	}
	bw.Reset(nil)
	writerPool.Put(bw)
}