diff --git a/accept.go b/accept.go index 31f104b23a5d2301f3f26bcd3c232fb275c8cd6d..cc9babb0e6ec960e19260f27b63e319d7ecabf07 100644 --- a/accept.go +++ b/accept.go @@ -65,9 +65,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con opts.CompressionOptions = &CompressionOptions{} } - err = verifyClientRequest(r) + errCode, err := verifyClientRequest(w, r) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + http.Error(w, err.Error(), errCode) return nil, err } @@ -127,32 +127,37 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con }), nil } -func verifyClientRequest(r *http.Request) error { +func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { - return xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsToken(r.Header, "Connection", "Upgrade") { - return xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsToken(r.Header, "Upgrade", "websocket") { - return xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { - return xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) + return http.StatusMethodNotAllowed, xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { - return xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + w.Header().Set("Sec-WebSocket-Version", "13") + return http.StatusBadRequest, xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { - return xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") + return http.StatusBadRequest, xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } - return nil + return 0, nil } func authenticateOrigin(r *http.Request) error { diff --git a/accept_test.go b/accept_test.go index 18302da5b7653a14b1f868dd280af356d4be22c3..354e95ec46fc46c14b3401a93596ee6042e1b3bf 100644 --- a/accept_test.go +++ b/accept_test.go @@ -192,7 +192,7 @@ func Test_verifyClientHandshake(t *testing.T) { r.Header.Set(k, v) } - err := verifyClientRequest(r) + _, err := verifyClientRequest(httptest.NewRecorder(), r) if tc.success != (err == nil) { t.Fatalf("unexpected error value: %v", err) }