good morning!!!!

Skip to content
Snippets Groups Projects
Select Git revision
  • f7bed7c75ec38a4cee43efdbba57da76daf5305e
  • master default protected
  • fix-license-1
  • mafredri/build-add-makefile
  • mafredri/build-update-to-go1.22
  • mafredri/fix-ci-script-tool-version
  • dependabot/go_modules/internal/examples/internal-deps-4cadc2be3d
  • dependabot/go_modules/internal/examples/internal-deps-46eeb9c117
  • dependabot/go_modules/internal/examples/internal-deps-022ca1aea3
  • mafredri/chore-update-dependabot
  • dependabot/go_modules/internal/thirdparty/github.com/lesismal/nbio-1.5.12
  • dependabot/go_modules/internal/examples/golang.org/x/time-0.7.0
  • dependabot/go_modules/internal/thirdparty/github.com/lesismal/nbio-1.5.11
  • ethan/close-order
  • mafredri/chore-remove-funding
  • mafredri/fix-ci
  • dependabot/go_modules/internal/thirdparty/github.com/gobwas/ws-1.4.0
  • dependabot/go_modules/internal/thirdparty/github.com/gorilla/websocket-1.5.3
  • dependabot/go_modules/internal/thirdparty/github.com/lesismal/nbio-1.5.10
  • dependabot/go_modules/internal/thirdparty/github.com/gin-gonic/gin-1.10.0
  • dependabot/go_modules/internal/examples/golang.org/x/time-0.6.0
  • v1.8.13
  • v1.8.12
  • v1.8.11
  • v1.8.10
  • v1.8.9
  • v1.8.8
  • v1.8.7
  • v1.8.6
  • v1.8.5
  • v1.8.4
  • v1.8.3
  • v1.8.2
  • v1.8.1
  • v1.8.0
  • v1.7.4
  • v1.7.3
  • v1.7.2
  • v1.7.1
  • v1.7.0
  • v1.6.5
41 results

dial.go

Blame
  • user avatar
    Martin Benda authored and Anmol Sethi committed
    f7bed7c7
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    dial.go 8.19 KiB
    //go:build !js
    // +build !js
    
    package websocket
    
    import (
    	"bufio"
    	"bytes"
    	"context"
    	"crypto/rand"
    	"encoding/base64"
    	"fmt"
    	"io"
    	"net/http"
    	"net/url"
    	"strings"
    	"sync"
    	"time"
    
    	"nhooyr.io/websocket/internal/errd"
    )
    
    // DialOptions represents Dial's options.
    type DialOptions struct {
    	// HTTPClient is used for the connection.
    	// Its Transport must return writable bodies for WebSocket handshakes.
    	// http.Transport does beginning with Go 1.12.
    	HTTPClient *http.Client
    
    	// HTTPHeader specifies the HTTP headers included in the handshake request.
    	HTTPHeader http.Header
    
    	// Host optionally overrides the Host HTTP header to send. If empty, the value
    	// of URL.Host will be used.
    	Host string
    
    	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
    	Subprotocols []string
    
    	// CompressionMode controls the compression mode.
    	// Defaults to CompressionDisabled.
    	//
    	// See docs on CompressionMode for details.
    	CompressionMode CompressionMode
    
    	// CompressionThreshold controls the minimum size of a message before compression is applied.
    	//
    	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
    	// for CompressionContextTakeover.
    	CompressionThreshold int
    }
    
    func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
    	var cancel context.CancelFunc
    
    	var o DialOptions
    	if opts != nil {
    		o = *opts
    	}
    	if o.HTTPClient == nil {
    		o.HTTPClient = http.DefaultClient
    	} else if opts.HTTPClient.Timeout > 0 {
    		ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
    
    		newClient := *opts.HTTPClient
    		newClient.Timeout = 0
    		opts.HTTPClient = &newClient
    	}
    	if o.HTTPHeader == nil {
    		o.HTTPHeader = http.Header{}
    	}
    
    	return ctx, cancel, &o
    }
    
    // Dial performs a WebSocket handshake on url.
    //
    // The response is the WebSocket handshake response from the server.
    // You never need to close resp.Body yourself.
    //
    // If an error occurs, the returned response may be non nil.
    // However, you can only read the first 1024 bytes of the body.
    //
    // This function requires at least Go 1.12 as it uses a new feature
    // in net/http to perform WebSocket handshakes.
    // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
    //
    // URLs with http/https schemes will work and are interpreted as ws/wss.
    func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
    	return dial(ctx, u, opts, nil)
    }
    
    func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
    	defer errd.Wrap(&err, "failed to WebSocket dial")
    
    	var cancel context.CancelFunc
    	ctx, cancel, opts = opts.cloneWithDefaults(ctx)
    	if cancel != nil {
    		defer cancel()
    	}
    
    	secWebSocketKey, err := secWebSocketKey(rand)
    	if err != nil {
    		return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
    	}
    
    	var copts *compressionOptions
    	if opts.CompressionMode != CompressionDisabled {
    		copts = opts.CompressionMode.opts()
    	}
    
    	resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
    	if err != nil {
    		return nil, resp, err
    	}
    	respBody := resp.Body
    	resp.Body = nil
    	defer func() {
    		if err != nil {
    			// We read a bit of the body for easier debugging.
    			r := io.LimitReader(respBody, 1024)
    
    			timer := time.AfterFunc(time.Second*3, func() {
    				respBody.Close()
    			})
    			defer timer.Stop()
    
    			b, _ := io.ReadAll(r)
    			respBody.Close()
    			resp.Body = io.NopCloser(bytes.NewReader(b))
    		}
    	}()
    
    	copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
    	if err != nil {
    		return nil, resp, err
    	}
    
    	rwc, ok := respBody.(io.ReadWriteCloser)
    	if !ok {
    		return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
    	}
    
    	return newConn(connConfig{
    		subprotocol:    resp.Header.Get("Sec-WebSocket-Protocol"),
    		rwc:            rwc,
    		client:         true,
    		copts:          copts,
    		flateThreshold: opts.CompressionThreshold,
    		br:             getBufioReader(rwc),
    		bw:             getBufioWriter(rwc),
    	}), resp, nil
    }
    
    func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
    	u, err := url.Parse(urls)
    	if err != nil {
    		return nil, fmt.Errorf("failed to parse url: %w", err)
    	}
    
    	switch u.Scheme {
    	case "ws":
    		u.Scheme = "http"
    	case "wss":
    		u.Scheme = "https"
    	case "http", "https":
    	default:
    		return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
    	}
    
    	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
    	if err != nil {
    		return nil, fmt.Errorf("failed to create new http request: %w", err)
    	}
    	if len(opts.Host) > 0 {
    		req.Host = opts.Host
    	}
    	req.Header = opts.HTTPHeader.Clone()
    	req.Header.Set("Connection", "Upgrade")
    	req.Header.Set("Upgrade", "websocket")
    	req.Header.Set("Sec-WebSocket-Version", "13")
    	req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
    	if len(opts.Subprotocols) > 0 {
    		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
    	}
    	if copts != nil {
    		copts.setHeader(req.Header)
    	}
    
    	resp, err := opts.HTTPClient.Do(req)
    	if err != nil {
    		return nil, fmt.Errorf("failed to send handshake request: %w", err)
    	}
    	return resp, nil
    }
    
    func secWebSocketKey(rr io.Reader) (string, error) {
    	if rr == nil {
    		rr = rand.Reader
    	}
    	b := make([]byte, 16)
    	_, err := io.ReadFull(rr, b)
    	if err != nil {
    		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
    	}
    	return base64.StdEncoding.EncodeToString(b), nil
    }
    
    func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
    	if resp.StatusCode != http.StatusSwitchingProtocols {
    		return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
    	}
    
    	if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
    		return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
    	}
    
    	if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
    		return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
    	}
    
    	if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
    		return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
    			resp.Header.Get("Sec-WebSocket-Accept"),
    			secWebSocketKey,
    		)
    	}
    
    	err := verifySubprotocol(opts.Subprotocols, resp)
    	if err != nil {
    		return nil, err
    	}
    
    	return verifyServerExtensions(copts, resp.Header)
    }
    
    func verifySubprotocol(subprotos []string, resp *http.Response) error {
    	proto := resp.Header.Get("Sec-WebSocket-Protocol")
    	if proto == "" {
    		return nil
    	}
    
    	for _, sp2 := range subprotos {
    		if strings.EqualFold(sp2, proto) {
    			return nil
    		}
    	}
    
    	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
    }
    
    func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
    	exts := websocketExtensions(h)
    	if len(exts) == 0 {
    		return nil, nil
    	}
    
    	ext := exts[0]
    	if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
    		return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
    	}
    
    	_copts := *copts
    	copts = &_copts
    
    	for _, p := range ext.params {
    		switch p {
    		case "client_no_context_takeover":
    			copts.clientNoContextTakeover = true
    			continue
    		case "server_no_context_takeover":
    			copts.serverNoContextTakeover = true
    			continue
    		}
    
    		return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
    	}
    
    	return copts, nil
    }
    
    var bufioReaderPool sync.Pool
    
    func getBufioReader(r io.Reader) *bufio.Reader {
    	br, ok := bufioReaderPool.Get().(*bufio.Reader)
    	if !ok {
    		return bufio.NewReader(r)
    	}
    	br.Reset(r)
    	return br
    }
    
    func putBufioReader(br *bufio.Reader) {
    	bufioReaderPool.Put(br)
    }
    
    var bufioWriterPool sync.Pool
    
    func getBufioWriter(w io.Writer) *bufio.Writer {
    	bw, ok := bufioWriterPool.Get().(*bufio.Writer)
    	if !ok {
    		return bufio.NewWriter(w)
    	}
    	bw.Reset(w)
    	return bw
    }
    
    func putBufioWriter(bw *bufio.Writer) {
    	bufioWriterPool.Put(bw)
    }