Newer
Older
// DialOptions holds the optional settings accepted by Dial.
// The zero value is valid.
type DialOptions struct {
	// HTTPClient sends the handshake request. When nil, http.DefaultClient
	// is used. Its Timeout field must be zero; use the context passed to
	// Dial for cancellation instead.
	//
	// The client's Transport has to speak HTTP/1.1 and hand back a writable
	// response body for WebSocket handshakes, a capability added to
	// net/http in Go 1.12. The standard http.Transport satisfies both.
	HTTPClient *http.Client

	// HTTPHeader carries extra HTTP headers to include in the handshake
	// request.
	HTTPHeader http.Header

	// Subprotocols is the list of subprotocols offered to the server
	// during negotiation.
	Subprotocols []string
}
// secWebSocketKey is the Sec-WebSocket-Key value sent with every client
// handshake: the base64 encoding of 16 zero bytes. A fixed key is used
// because the header is useless for the client; see
// https://stackoverflow.com/a/37074398/4283659.
// The same reasoning applies to the mask key, which is likewise identical
// for every message.
var secWebSocketKey = func() string {
	var zeroKey [16]byte
	return base64.StdEncoding.EncodeToString(zeroKey[:])
}()
// Dial opens a WebSocket connection to the url u by performing the client
// side of the WebSocket handshake with the supplied options.
//
// The returned response is the server's handshake response. It may be non nil
// even when an error is returned, in which case only the first 1024 bytes of
// its body can be read.
//
// Dial needs Go 1.12 or newer: it depends on net/http returning a writable
// response body for WebSocket handshakes.
// See https://github.com/golang/go/issues/26937#issuecomment-415855861
func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) {
	conn, resp, err := dial(ctx, u, opts)
	if err == nil {
		return conn, resp, nil
	}
	// Add the common prefix once here so dial can return bare errors.
	return nil, resp, xerrors.Errorf("failed to websocket dial: %w", err)
}
func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) {
if opts.HTTPClient == nil {
opts.HTTPClient = http.DefaultClient
}
if opts.HTTPClient.Timeout > 0 {
return nil, nil, xerrors.Errorf("please use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
}
parsedURL, err := url.Parse(u)
if err != nil {
return nil, nil, xerrors.Errorf("failed to parse url: %w", err)
return nil, nil, xerrors.Errorf("unexpected url scheme: %q", parsedURL.Scheme)
req, _ := http.NewRequest("GET", parsedURL.String(), nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
resp, err := opts.HTTPClient.Do(req)
return nil, nil, xerrors.Errorf("failed to send handshake request: %w", err)
// We read a bit of the body for easier debugging.
r := io.LimitReader(resp.Body, 1024)
b, _ := ioutil.ReadAll(r)
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
}
}()
err = verifyServerResponse(resp)
if err != nil {
return nil, resp, err
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, resp, xerrors.Errorf("response body is not a read write closer: %T", rwc)
c := &Conn{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
closer: rwc,
client: true,
}
c.init()
return c, resp, nil
// verifyServerResponse reports whether resp is an acceptable WebSocket
// handshake response. It returns an error describing the first failed check:
// the status code must be 101 Switching Protocols, the Connection header must
// contain the Upgrade token and the Upgrade header must contain the WebSocket
// token. It returns nil when every check passes.
func verifyServerResponse(resp *http.Response) error {
	// Cases are evaluated top to bottom, so the checks run in the same
	// order as a chain of guard clauses would.
	switch {
	case resp.StatusCode != http.StatusSwitchingProtocols:
		return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
	case !headerValuesContainsToken(resp.Header, "Connection", "Upgrade"):
		return xerrors.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
	case !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket"):
		return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
	}

	// Sec-WebSocket-Accept is intentionally not checked because it does not
	// matter; see the secWebSocketKey global variable.
	return nil
}
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
// The pools below are meant for the client side only: on the server,
// http.Hijacker always supplies a bufio.Reader/Writer already, so pooling on
// top of that would not make sense.

// bufioReaderPool recycles *bufio.Reader values between client connections.
var bufioReaderPool = sync.Pool{
	New: newPooledBufioReader,
}

// newPooledBufioReader creates a fresh *bufio.Reader with no underlying
// reader for bufioReaderPool.
func newPooledBufioReader() interface{} {
	return bufio.NewReader(nil)
}
// getBufioReader takes a *bufio.Reader from bufioReaderPool and resets it to
// read from r before returning it.
func getBufioReader(r io.Reader) *bufio.Reader {
	pooled := bufioReaderPool.Get().(*bufio.Reader)
	pooled.Reset(r)
	return pooled
}
// returnBufioReader puts br back into bufioReaderPool. The reader is stored
// as is; getBufioReader resets it when it is next handed out.
func returnBufioReader(br *bufio.Reader) {
	bufioReaderPool.Put(br)
}
// bufioWriterPool recycles *bufio.Writer values between client connections.
var bufioWriterPool = sync.Pool{
	New: newPooledBufioWriter,
}

// newPooledBufioWriter creates a fresh *bufio.Writer with no underlying
// writer for bufioWriterPool.
func newPooledBufioWriter() interface{} {
	return bufio.NewWriter(nil)
}
// getBufioWriter takes a *bufio.Writer from bufioWriterPool and resets it to
// write to w before returning it.
func getBufioWriter(w io.Writer) *bufio.Writer {
	pooled := bufioWriterPool.Get().(*bufio.Writer)
	pooled.Reset(w)
	return pooled
}
// returnBufioWriter puts bw back into bufioWriterPool. The writer is stored
// as is; getBufioWriter resets it when it is next handed out.
func returnBufioWriter(bw *bufio.Writer) {
	bufioWriterPool.Put(bw)
}