good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 3dd723ae authored by Mathias Fredriksson's avatar Mathias Fredriksson Committed by GitHub
Browse files

accept: Add unwrapping for hijack like http.ResponseController (#472)

Since we rely on the connection not being hijacked too early (i.e.
detecting the presence of http.Hijacker) to set headers, we must
manually implement the unwrapping of the http.ResponseController. By
doing so, we also retain Go 1.19 compatibility without build tags.

Closes #455
parent 641f4f5c
No related branches found
No related tags found
No related merge requests found
...@@ -105,7 +105,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -105,7 +105,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
} }
} }
hj, ok := w.(http.Hijacker) hj, ok := hijacker(w)
if !ok { if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker") err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
......
...@@ -143,6 +143,33 @@ func TestAccept(t *testing.T) { ...@@ -143,6 +143,33 @@ func TestAccept(t *testing.T) {
_, err := Accept(w, r, nil) _, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`) assert.Contains(t, err, `failed to hijack connection`)
}) })
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, "failed to hijack connection")
})
t.Run("closeRace", func(t *testing.T) { t.Run("closeRace", func(t *testing.T) {
t.Parallel() t.Parallel()
...@@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{} ...@@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{}
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return mj.hijack() return mj.hijack()
} }
type mockUnwrapper struct {
http.ResponseWriter
unwrap func() http.ResponseWriter
}
var _ rwUnwrapper = mockUnwrapper{}
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
return mu.unwrap()
}
//go:build !js
package websocket
import (
"net/http"
)
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
// hijacker returns the Hijacker interface of the http.ResponseWriter.
// It follows the Unwrap method of the http.ResponseWriter if available,
// matching the behavior of http.ResponseController. If the Hijacker
// interface is not found, it returns false.
//
// Since the http.ResponseController is not available in Go 1.19, and
// does not support checking the presence of the Hijacker interface,
// this function is used to provide a consistent way to check for the
// Hijacker interface across Go versions.
func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t, true
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil, false
}
}
}
//go:build !js && go1.20
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
_, _, err := http.NewResponseController(w).Hijack()
assert.Contains(t, err, "haha")
hj, ok := hijacker(w)
assert.Equal(t, "hijacker found", ok, true)
_, _, err = hj.Hijack()
assert.Contains(t, err, "haha")
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment