diff --git a/accept.go b/accept.go index 5b997be6..68c00ed3 100644 --- a/accept.go +++ b/accept.go @@ -105,7 +105,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } } - hj, ok := w.(http.Hijacker) + hj, ok := hijacker(w) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) diff --git a/accept_test.go b/accept_test.go index 4f799126..3b45ac5c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -143,6 +143,33 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + + t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + _, err := Accept(w, r, nil) + assert.Contains(t, err, "failed to hijack connection") + }) + t.Run("closeRace", func(t *testing.T) { t.Parallel() @@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } + +type mockUnwrapper struct { + http.ResponseWriter + unwrap func() http.ResponseWriter +} + +var _ rwUnwrapper = mockUnwrapper{} + +func (mu mockUnwrapper) Unwrap() http.ResponseWriter { + return mu.unwrap() +} diff --git a/hijack.go b/hijack.go new file mode 100644 index 00000000..9cce45ca --- /dev/null +++ b/hijack.go @@ -0,0 +1,33 @@ +//go:build !js + +package websocket + +import ( + "net/http" +) + +type rwUnwrapper interface { + Unwrap() http.ResponseWriter +} + +// hijacker returns the Hijacker interface of the http.ResponseWriter. +// It follows the Unwrap method of the http.ResponseWriter if available, +// matching the behavior of http.ResponseController. If the Hijacker +// interface is not found, it returns false. +// +// Since the http.ResponseController is not available in Go 1.19, and +// does not support checking the presence of the Hijacker interface, +// this function is used to provide a consistent way to check for the +// Hijacker interface across Go versions. +func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t, true + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, false + } + } +} diff --git a/hijack_go120_test.go b/hijack_go120_test.go new file mode 100644 index 00000000..0f0673a9 --- /dev/null +++ b/hijack_go120_test.go @@ -0,0 +1,38 @@ +//go:build !js && go1.20 + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coder/websocket/internal/test/assert" +) + +func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + _, _, err := http.NewResponseController(w).Hijack() + assert.Contains(t, err, "haha") + hj, ok := hijacker(w) + assert.Equal(t, "hijacker found", ok, true) + _, _, err = hj.Hijack() + assert.Contains(t, err, "haha") +}