diff --git a/accept.go b/accept.go index 6e1f494e32c9027f926d63fa8079ad01b4d713fb..542b61e84e406b9c6b444f049334fc1e303d480b 100644 --- a/accept.go +++ b/accept.go @@ -215,7 +215,10 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { return nil } } - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if u.Host == "" { + return fmt.Errorf("request Origin %q is not a valid URL with a host", origin) + } + return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) } func match(pattern, s string) (bool, error) { diff --git a/accept_test.go b/accept_test.go index d19f54e15c230208420b3e083b4943155a4bca6f..67ece2535f096cac635e708260d92b742ab631b6 100644 --- a/accept_test.go +++ b/accept_test.go @@ -39,7 +39,23 @@ func TestAccept(t *testing.T) { r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) - assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) + assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`) + }) + + // #247 + t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Origin", "https://harhar.com") + + _, err := Accept(w, r, nil) + assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`) }) t.Run("badCompression", func(t *testing.T) {