good morning!!!!

Skip to content
Snippets Groups Projects
accept.go 5.11 KiB
Newer Older
package websocket
Anmol Sethi's avatar
Anmol Sethi committed

import (
Anmol Sethi's avatar
Anmol Sethi committed
	"crypto/sha1"
	"encoding/base64"
Anmol Sethi's avatar
Anmol Sethi committed
	"net/http"
Anmol Sethi's avatar
Anmol Sethi committed
	"net/url"
	"strings"

	"golang.org/x/net/http/httpguts"
	"golang.org/x/xerrors"
Anmol Sethi's avatar
Anmol Sethi committed
)

// AcceptOption configures the behavior of Accept. Values are obtained from
// constructors in this package such as AcceptSubprotocols and AcceptOrigins.
// The implementations of this interface are printable.
type AcceptOption interface {
	// acceptOption is unexported so that only types declared in this
	// package can satisfy AcceptOption.
	acceptOption()
}

Anmol Sethi's avatar
Anmol Sethi committed
// acceptSubprotocols is the option value built by AcceptSubprotocols. It holds
// the server's subprotocols in the order they are tried during negotiation.
type acceptSubprotocols []string

func (acceptSubprotocols) acceptOption() {}

Anmol Sethi's avatar
Anmol Sethi committed
// AcceptSubprotocols list the subprotocols that Accept will negotiate with a client.
// The first protocol that a client supports will be negotiated.
// The empty protocol will always be negotiated as per RFC 6455. If you would like to
// reject it, close the connection is c.Subprotocol() == "".
Anmol Sethi's avatar
Anmol Sethi committed
func AcceptSubprotocols(subprotocols ...string) AcceptOption {
Anmol Sethi's avatar
Anmol Sethi committed
	return acceptSubprotocols(subprotocols)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// acceptOrigins is the option value built by AcceptOrigins. It holds the extra
// origin hosts that Accept allows in addition to the request's own host.
type acceptOrigins []string

func (acceptOrigins) acceptOption() {}

Anmol Sethi's avatar
Anmol Sethi committed
// AcceptOrigins lists the origins that Accept will accept.
// Accept will always accept r.Host as the origin so you do not need to
// specify that with this option.
Anmol Sethi's avatar
Anmol Sethi committed
// Use this option with caution to avoid exposing your WebSocket
// server to a CSRF attack.
// See https://stackoverflow.com/a/37837709/4283659
Anmol Sethi's avatar
Anmol Sethi committed
// You can use a * for wildcards.
Anmol Sethi's avatar
Anmol Sethi committed
func AcceptOrigins(origins ...string) AcceptOption {
	return acceptOrigins(origins)
Anmol Sethi's avatar
Anmol Sethi committed
}

// Accept accepts a WebSocket handshake from a client and upgrades the
// the connection to WebSocket.
// Accept will reject the handshake if the Origin is not the same as the Host unless
// InsecureAcceptOrigin is passed.
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
Anmol Sethi's avatar
Anmol Sethi committed
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	var subprotocols []string
	origins := []string{r.Host}
	for _, opt := range opts {
		switch opt := opt.(type) {
		case acceptOrigins:
			origins = []string(opt)
		case acceptSubprotocols:
			subprotocols = []string(opt)
		}
	}

	if !httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") {
		err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection"))
		http.Error(w, err.Error(), http.StatusBadRequest)
		return nil, err
	}

	if !httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") {
		err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade"))
		http.Error(w, err.Error(), http.StatusBadRequest)
		return nil, err
	}

	if r.Method != "GET" {
		err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method)
		http.Error(w, err.Error(), http.StatusBadRequest)
		return nil, err
	}

	if r.Header.Get("Sec-WebSocket-Version") != "13" {
		err := xerrors.Errorf("websocket: unsupported protocol version: %q", r.Header.Get("Sec-WebSocket-Version"))
		http.Error(w, err.Error(), http.StatusBadRequest)
		return nil, err
	}

	if r.Header.Get("Sec-WebSocket-Key") == "" {
		err := xerrors.New("websocket: protocol violation: missing Sec-WebSocket-Key")
		http.Error(w, err.Error(), http.StatusBadRequest)
		return nil, err
	}

	origins = append(origins, r.Host)

	err := authenticateOrigin(r, origins)
	if err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return nil, err
	}

	hj, ok := w.(http.Hijacker)
	if !ok {
		err = xerrors.New("websocket: response writer does not implement http.Hijacker")
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return nil, err
	}

	w.Header().Set("Upgrade", "websocket")
	w.Header().Set("Connection", "Upgrade")

	handleKey(w, r)

	selectSubprotocol(w, r, subprotocols)

	w.WriteHeader(http.StatusSwitchingProtocols)

Anmol Sethi's avatar
Anmol Sethi committed
	netConn, brw, err := hj.Hijack()
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		err = xerrors.Errorf("websocket: failed to hijack connection: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return nil, err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	c := &Conn{
		subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
		br:          brw.Reader,
		bw:          brw.Writer,
		closer:      netConn,
	}
	c.init()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	return c, nil
Anmol Sethi's avatar
Anmol Sethi committed
}

// selectSubprotocol negotiates a subprotocol for the handshake response.
// It walks subprotocols in order and picks the first one that appears among
// the comma-separated, whitespace-trimmed tokens of the client's
// Sec-WebSocket-Protocol request header, writing it to the response's
// Sec-WebSocket-Protocol header. If nothing matches, the response header is
// left untouched.
func selectSubprotocol(w http.ResponseWriter, r *http.Request, subprotocols []string) {
	offered := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ",")
	for i := range offered {
		offered[i] = strings.TrimSpace(offered[i])
	}

	for _, want := range subprotocols {
		for _, got := range offered {
			if got != want {
				continue
			}
			w.Header().Set("Sec-WebSocket-Protocol", want)
			return
		}
	}
}

var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

func handleKey(w http.ResponseWriter, r *http.Request) {
	key := r.Header.Get("Sec-WebSocket-Key")
	h := sha1.New()
	h.Write([]byte(key))
	h.Write(keyGUID)

	responseKey := base64.StdEncoding.EncodeToString(h.Sum(nil))
	w.Header().Set("Sec-WebSocket-Accept", responseKey)
}

func authenticateOrigin(r *http.Request, origins []string) error {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return nil
	}
	u, err := url.Parse(origin)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	for _, o := range origins {
		if strings.EqualFold(u.Host, o) {
Anmol Sethi's avatar
Anmol Sethi committed
			return nil
		}
	}
	return xerrors.Errorf("request origin %q is not authorized", r.Header.Get("Origin"))
Anmol Sethi's avatar
Anmol Sethi committed
}