Newer
Older
"net/url"
"strings"
"golang.org/x/net/http/httpguts"
"golang.org/x/xerrors"
)
// AcceptOption is an option that can be passed to Accept.
// The implementations of this interface are printable.
type acceptSubprotocols []string // server-supported subprotocols, consulted in this order by Accept

// acceptOption marks acceptSubprotocols as usable where an AcceptOption is expected.
func (acceptSubprotocols) acceptOption() {}
// AcceptSubprotocols lists the subprotocols that Accept will negotiate with a client.
// The first protocol in this list that the client also supports will be negotiated.
// If the client supports none of them, the empty protocol is negotiated, as RFC 6455
// permits. If you would like to reject that, close the connection if c.Subprotocol() == "".
type acceptOrigins []string // origins Accept permits in addition to r.Host

// acceptOption marks acceptOrigins as usable where an AcceptOption is expected.
func (acceptOrigins) acceptOption() {}
// AcceptOrigins lists the origins that Accept will accept.
// Accept always accepts r.Host as an origin, so you do not need to
// specify it with this option.
// Use this option with caution to avoid exposing your WebSocket
// server to a CSRF attack.
// See https://stackoverflow.com/a/37837709/4283659
}
// Accept accepts a WebSocket handshake from a client and upgrades the
// the connection to WebSocket.
// Accept will reject the handshake if the Origin is not the same as the Host unless
// InsecureAcceptOrigin is passed.
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
var subprotocols []string
origins := []string{r.Host}
for _, opt := range opts {
switch opt := opt.(type) {
case acceptOrigins:
origins = []string(opt)
case acceptSubprotocols:
subprotocols = []string(opt)
}
}
if !httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") {
err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection"))
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
if !httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") {
err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade"))
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
if r.Method != "GET" {
err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
if r.Header.Get("Sec-WebSocket-Version") != "13" {
err := xerrors.Errorf("websocket: unsupported protocol version: %q", r.Header.Get("Sec-WebSocket-Version"))
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
if r.Header.Get("Sec-WebSocket-Key") == "" {
err := xerrors.New("websocket: protocol violation: missing Sec-WebSocket-Key")
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
origins = append(origins, r.Host)
err := authenticateOrigin(r, origins)
if err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return nil, err
}
hj, ok := w.(http.Hijacker)
if !ok {
err = xerrors.New("websocket: response writer does not implement http.Hijacker")
http.Error(w, err.Error(), http.StatusInternalServerError)
return nil, err
}
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade")
handleKey(w, r)
selectSubprotocol(w, r, subprotocols)
w.WriteHeader(http.StatusSwitchingProtocols)
err = xerrors.Errorf("websocket: failed to hijack connection: %w", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return nil, err
}
c := &Conn{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
br: brw.Reader,
bw: brw.Writer,
closer: netConn,
}
c.init()
}
// selectSubprotocol negotiates the subprotocol for the handshake response.
//
// The client's offers are read from every Sec-WebSocket-Protocol header line
// (RFC 6455 allows the field to repeat, equivalent to one comma-joined list).
// Each entry is trimmed of surrounding whitespace, and empty entries are ignored.
// The first element of subprotocols, in the server's order, that the client
// also offered is written to the response's Sec-WebSocket-Protocol header.
// If nothing matches, the header is left unset, which negotiates the empty
// subprotocol.
func selectSubprotocol(w http.ResponseWriter, r *http.Request, subprotocols []string) {
	var clientSubprotocols []string
	for _, line := range r.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
		for _, cp := range strings.Split(line, ",") {
			cp = strings.TrimSpace(cp)
			// An empty token (absent value, stray comma) is not a subprotocol
			// and must never be echoed back as one.
			if cp != "" {
				clientSubprotocols = append(clientSubprotocols, cp)
			}
		}
	}

	for _, sp := range subprotocols {
		for _, cp := range clientSubprotocols {
			if sp == cp {
				w.Header().Set("Sec-WebSocket-Protocol", sp)
				return
			}
		}
	}
}
// keyGUID is the fixed GUID from RFC 6455 that is concatenated with the
// client's Sec-WebSocket-Key when deriving Sec-WebSocket-Accept.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// handleKey sets the Sec-WebSocket-Accept response header to the standard
// base64 encoding of SHA-1(Sec-WebSocket-Key || keyGUID). The request key is
// used exactly as received; this function does not check that it is present.
func handleKey(w http.ResponseWriter, r *http.Request) {
	// []byte(...) allocates a fresh slice, so appending never touches keyGUID.
	challenge := append([]byte(r.Header.Get("Sec-WebSocket-Key")), keyGUID...)
	digest := sha1.Sum(challenge)
	w.Header().Set("Sec-WebSocket-Accept", base64.StdEncoding.EncodeToString(digest[:]))
}
func authenticateOrigin(r *http.Request, origins []string) error {
origin := r.Header.Get("Origin")
if origin == "" {
return nil
}
u, err := url.Parse(origin)
if err != nil {
return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err)
if strings.EqualFold(u.Host, o) {
return xerrors.Errorf("request origin %q is not authorized", r.Header.Get("Origin"))