diff --git a/accept.go b/accept.go index f505c03be7fc02053fc3e6a58b245989f2a0cc4b..3120690a54b88cdd2519e1d599f519a52fb5ee3e 100644 --- a/accept.go +++ b/accept.go @@ -29,21 +29,26 @@ func AcceptSubprotocols(protocols ...string) AcceptOption { return acceptSubprotocols(protocols) } -type acceptOrigins []string +type acceptInsecureOrigin struct{} -func (o acceptOrigins) acceptOption() {} +func (o acceptInsecureOrigin) acceptOption() {} -// AcceptOrigins lists the origins that Accept will accept. -// Accept will always accept r.Host as the origin. Use this -// option when you want to accept an origin with a different domain -// than the one the WebSocket server is running on. +// AcceptInsecureOrigin disables Accept's origin verification +// behaviour. By default Accept only allows the handshake to +// succeed if the javascript that is initiating the handshake +// is on the same domain as the server. This is to prevent CSRF +// when secure data is stored in cookies. // -// Use this option with caution to avoid exposing your WebSocket -// server to a CSRF attack. // See https://stackoverflow.com/a/37837709/4283659 -// TODO remove in favour of AcceptInsecureOrigin -func AcceptOrigins(origins ...string) AcceptOption { - return acceptOrigins(origins) +// +// Use this if you want a WebSocket server any javascript can +// connect to or you want to perform Origin verification yourself +// and allow some whitelist of domains. +// +// Ensure you understand exactly what the above means before you use +// this option in conjugation with cookies containing secure data. +func AcceptInsecureOrigin() AcceptOption { + return acceptInsecureOrigin{} } func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { @@ -87,11 +92,11 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept uses w to write the handshake response so the timeouts on the http.Server apply. func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) { var subprotocols []string - origins := []string{r.Host} + verifyOrigin := true for _, opt := range opts { switch opt := opt.(type) { - case acceptOrigins: - origins = []string(opt) + case acceptInsecureOrigin: + verifyOrigin = false case acceptSubprotocols: subprotocols = []string(opt) } @@ -102,12 +107,12 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn return nil, err } - origins = append(origins, r.Host) - - err = authenticateOrigin(r, origins) - if err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return nil, err + if verifyOrigin { + err = authenticateOrigin(r) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return nil, err + } } hj, ok := w.(http.Hijacker) @@ -173,7 +178,7 @@ func handleKey(w http.ResponseWriter, r *http.Request) { w.Header().Set("Sec-WebSocket-Accept", responseKey) } -func authenticateOrigin(r *http.Request, origins []string) error { +func authenticateOrigin(r *http.Request) error { origin := r.Header.Get("Origin") if origin == "" { return nil @@ -182,10 +187,8 @@ func authenticateOrigin(r *http.Request, origins []string) error { if err != nil { return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err) } - for _, o := range origins { - if strings.EqualFold(u.Host, o) { - return nil - } + if strings.EqualFold(u.Host, r.Host) { + return nil } - return xerrors.Errorf("request origin %q is not authorized", r.Header.Get("Origin")) + return xerrors.Errorf("request origin %q is not authorized", origin) } diff --git a/accept_test.go b/accept_test.go index 4b5214dde7a45045e130a2c472545ba8390536f4..6f5c3fb9e9f3896330ab692d2bee4808a1f5a4b2 100644 --- a/accept_test.go +++ b/accept_test.go @@ -140,37 +140,39 @@ func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { - name string - origin string - authorizedOrigins []string - success bool + name string + origin string + host string + success bool }{ { name: "none", success: true, + host: "example.com", }, { name: "invalid", origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}", + host: "example.com", success: false, }, { - name: "unauthorized", - origin: "https://example.com", - authorizedOrigins: []string{"example1.com"}, - success: false, + name: "unauthorized", + origin: "https://example.com", + host: "example1.com", + success: false, }, { - name: "authorized", - origin: "https://example.com", - authorizedOrigins: []string{"example.com"}, - success: true, + name: "authorized", + origin: "https://example.com", + host: "example.com", + success: true, }, { - name: "authorizedCaseInsensitive", - origin: "https://examplE.com", - authorizedOrigins: []string{"example.com"}, - success: true, + name: "authorizedCaseInsensitive", + origin: "https://examplE.com", + host: "example.com", + success: true, }, } @@ -179,10 +181,10 @@ func Test_authenticateOrigin(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) - err := authenticateOrigin(r, tc.authorizedOrigins) + err := authenticateOrigin(r) if (err == nil) != tc.success { t.Fatalf("unexpected error value: %+v", err) } diff --git a/websocket_test.go b/websocket_test.go index 2133482fb44403f733dca3834acb5d2071a78c9d..868b69a37ccb7e0700c89627eeb1403a02d6da4f 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -143,9 +143,30 @@ func TestHandshake(t *testing.T) { }, }, { - name: "authorizedOrigin", + name: "acceptSecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOrigins("har.bar.com", "example.com")) + c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin()) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + h := http.Header{} + h.Set("Origin", "https://127.0.0.1") + c, _, err := websocket.Dial(ctx, u, websocket.DialHeader(h)) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + { + name: "acceptInsecureOrigin", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin()) if err != nil { return err }