diff --git a/go.mod b/go.mod index 1a7afd50..c67f3ad9 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/gorilla/websocket go 1.12 + +require golang.org/x/tools v0.0.0-20200619210111-0f592d2728bb diff --git a/go.sum b/go.sum index e69de29b..ec95b5c2 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619210111-0f592d2728bb h1:/7SQoPdMxZ0c/Zu9tBJgMbRE/BmK6i9QXflNJXKAmw0= +golang.org/x/tools v0.0.0-20200619210111-0f592d2728bb/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/server.go b/server.go index 887d5589..41404fb1 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ package websocket import ( "bufio" "errors" + "fmt" "io" "net/http" "net/url" @@ -44,6 +45,7 @@ type Upgrader struct { // WriteBufferSize. WriteBufferPool BufferPool + // Subprotocols have lower priority than NegotiateSuprotocol. // Subprotocols specifies the server's supported protocols in order of // preference. If this field is not nil, then the Upgrade method negotiates a // subprotocol by selecting the first match in this list with a protocol @@ -70,6 +72,13 @@ type Upgrader struct { // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool + // NegotiateSubprotocol has higher priority than Subprotocols. + // NegotiateSubprotocol returns the negotiated subprotocol for the handshake + // request. If the returned string is "", then the the Sec-Websocket-Protocol header + // is not included in the handshake response. If the function returns an error, then + // Upgrade responds to the client with http.StatusBadRequest. + // If this function is not nil, then the Upgrader.Subportocols field is ignored. + NegotiateSubprotocol func(r *http.Request) (string, error) } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -96,7 +105,7 @@ func checkSameOrigin(r *http.Request) bool { return equalASCIIFold(u.Host, r.Host) } -func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { +func (u *Upgrader) selectSubprotocol(r *http.Request) string { if u.Subprotocols != nil { clientProtocols := Subprotocols(r) for _, serverProtocol := range u.Subprotocols { @@ -106,8 +115,6 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header } } } - } else if responseHeader != nil { - return responseHeader.Get("Sec-Websocket-Protocol") } return "" } @@ -115,11 +122,14 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade -// request. Use the responseHeader to specify cookies (Set-Cookie) and the -// application negotiated subprotocol (Sec-WebSocket-Protocol). +// request. Use the responseHeader to specify cookies (Set-Cookie). // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. +// +// The responseHeader does not support negotiated subprotocol(Sec-Websocket-Protocol) +// IF necessary,please use Upgrader.NegotiateSubprotocol and Upgrader.Subprotocols +// Use the method to view the Upgrader struct. func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { const badHandshake = "websocket: the client is not using the websocket protocol: " @@ -156,7 +166,16 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") } - subprotocol := u.selectSubprotocol(r, responseHeader) + subprotocol := "" + if u.NegotiateSubprotocol != nil { + str, err := u.NegotiateSubprotocol(r) + if err != nil { + return u.returnError(w, r, http.StatusBadRequest, fmt.Sprintf("websocket:handshake negotiation protocol error:%s", err)) + } + subprotocol = str + } else { + subprotocol = u.selectSubprotocol(r) + } // Negotiate PMCE var compress bool diff --git a/server_test.go b/server_test.go index 456c1db5..52ae1d19 100644 --- a/server_test.go +++ b/server_test.go @@ -7,8 +7,10 @@ package websocket import ( "bufio" "bytes" + "errors" "net" "net/http" + "net/http/httptest" "reflect" "strings" "testing" @@ -117,3 +119,74 @@ func TestBufioReuse(t *testing.T) { } } } + +var negotiateSubprotocolTests = []struct { + *Upgrader + match bool + shouldErr bool +}{ + { + &Upgrader{ + NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "json", nil }, + }, true, false, + }, + { + &Upgrader{ + Subprotocols: []string{"json"}, + }, true, false, + }, + { + &Upgrader{ + Subprotocols: []string{"not-match"}, + }, false, false, + }, + { + &Upgrader{ + NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "", errors.New("not-match") }, + }, false, true, + }, +} + +func TestNegotiateSubprotocol(t *testing.T) { + for i := range negotiateSubprotocolTests { + upgrade := negotiateSubprotocolTests[i].Upgrader + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrade.Upgrade(w, r, nil) + })) + + req, err := http.NewRequest("GET", s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest retuened error %v", err) + } + + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-Websocket-Version", "13") + req.Header.Set("Sec-Websocket-Protocol", "json") + req.Header.Set("Sec-Websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) + } + + if negotiateSubprotocolTests[i].shouldErr && resp.StatusCode != http.StatusBadRequest { + t.Errorf("The expecred status code is %d,actual status code is %d", http.StatusBadRequest, resp.StatusCode) + } else { + if negotiateSubprotocolTests[i].match { + protocol := resp.Header.Get("Sec-Websocket-Protocol") + if protocol != "json" { + t.Errorf("Negotiation protocol failed,request protocol is json,reponese protocol is %s", protocol) + } + } else { + if _, ok := resp.Header["Sec-Websocket-Protocol"]; ok { + t.Errorf("Negotiation protocol failed,Sec-Websocket-Protocol field should be empty") + } + } + } + s.Close() + resp.Body.Close() + } + +}