From 3dd723ae69f00f042c4fd6961ed6cee3e9e11659 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 22 Aug 2024 11:10:19 +0300 Subject: [PATCH] accept: Add unwrapping for hijack like http.ResponseController (#472) Since we rely on the connection not being hijacked too early (i.e. detecting the presence of http.Hijacker) to set headers, we must manually implement the unwrapping of the http.ResponseController. By doing so, we also retain Go 1.19 compatibility without build tags. Closes #455 --- accept.go | 2 +- accept_test.go | 38 ++++++++++++++++++++++++++++++++++++++ hijack.go | 33 +++++++++++++++++++++++++++++++++ hijack_go120_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 hijack.go create mode 100644 hijack_go120_test.go diff --git a/accept.go b/accept.go index 5b997be6..68c00ed3 100644 --- a/accept.go +++ b/accept.go @@ -105,7 +105,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } } - hj, ok := w.(http.Hijacker) + hj, ok := hijacker(w) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) diff --git a/accept_test.go b/accept_test.go index 4f799126..3b45ac5c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -143,6 +143,33 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + + t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + _, err := Accept(w, r, nil) + assert.Contains(t, err, "failed to hijack connection") + }) + t.Run("closeRace", func(t *testing.T) { t.Parallel() @@ -534,3 +561,14 @@ var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } + +type mockUnwrapper struct { + http.ResponseWriter + unwrap func() http.ResponseWriter +} + +var _ rwUnwrapper = mockUnwrapper{} + +func (mu mockUnwrapper) Unwrap() http.ResponseWriter { + return mu.unwrap() +} diff --git a/hijack.go b/hijack.go new file mode 100644 index 00000000..9cce45ca --- /dev/null +++ b/hijack.go @@ -0,0 +1,33 @@ +//go:build !js + +package websocket + +import ( + "net/http" +) + +type rwUnwrapper interface { + Unwrap() http.ResponseWriter +} + +// hijacker returns the Hijacker interface of the http.ResponseWriter. +// It follows the Unwrap method of the http.ResponseWriter if available, +// matching the behavior of http.ResponseController. If the Hijacker +// interface is not found, it returns false. +// +// Since the http.ResponseController is not available in Go 1.19, and +// does not support checking the presence of the Hijacker interface, +// this function is used to provide a consistent way to check for the +// Hijacker interface across Go versions. +func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t, true + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, false + } + } +} diff --git a/hijack_go120_test.go b/hijack_go120_test.go new file mode 100644 index 00000000..0f0673a9 --- /dev/null +++ b/hijack_go120_test.go @@ -0,0 +1,38 @@ +//go:build !js && go1.20 + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coder/websocket/internal/test/assert" +) + +func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + _, _, err := http.NewResponseController(w).Hijack() + assert.Contains(t, err, "haha") + hj, ok := hijacker(w) + assert.Equal(t, "hijacker found", ok, true) + _, _, err = hj.Hijack() + assert.Contains(t, err, "haha") +}