diff --git a/broker.go b/broker.go index 61bfdd336f..bb360093df 100644 --- a/broker.go +++ b/broker.go @@ -101,6 +101,12 @@ func newBroker(site *Site, editable, noStore, noLog, keepAppLive, debug bool) *B } } +func (b *Broker) getClient(id string) *Client { + b.unicastsMux.RLock() + defer b.unicastsMux.RUnlock() + return b.clientsByID[id] +} + func (b *Broker) addApp(mode, route, addr, keyID, keySecret string) { s := newApp(b, mode, route, addr, keyID, keySecret) diff --git a/client.go b/client.go index df88541228..113d6681be 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "net/http" + "sync" "time" "github.com/google/uuid" @@ -25,11 +26,15 @@ import ( ) const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Maximum message size allowed from peer. - maxMessageSize = 1 * 1024 * 1024 // bytes + writeWait = 10 * time.Second // Time allowed to write a message to the peer. + maxMessageSize = 1 * 1024 * 1024 // bytes Maximum message size allowed from peer. + // TODO: Refactor into iota. + STATE_CREATED = "CREATED" + STATE_TIMEOUT = "TIMEOUT" + STATE_LISTEN = "LISTEN" + STATE_RECONNECT = "RECONNECT" + STATE_DISCONNECT = "DISCONNECT" + STATE_CLOSED = "CLOSED" ) var ( @@ -63,16 +68,17 @@ type Client struct { header *http.Header // forwarded headers from the WS connection appPath string // path of the app this client is connected to, doesn't change throughout WS lifetime pingInterval time.Duration - isReconnect bool - cancel context.CancelFunc reconnectTimeout time.Duration + lock *sync.Mutex + state string } // TODO: Refactor some of the params into a Config struct. func newClient(addr string, auth *Auth, session *Session, broker *Broker, conn *websocket.Conn, editable bool, - baseURL string, header *http.Header, pingInterval time.Duration, isReconnect bool, reconnectTimeout time.Duration) *Client { + baseURL string, header *http.Header, pingInterval time.Duration, reconnectTimeout time.Duration) *Client { id := uuid.New().String() - return &Client{id, auth, addr, session, broker, conn, nil, make(chan []byte, 256), editable, baseURL, header, "", pingInterval, isReconnect, nil, reconnectTimeout} + return &Client{id, auth, addr, session, broker, conn, nil, make(chan []byte, 256), + editable, baseURL, header, "", pingInterval, reconnectTimeout, &sync.Mutex{}, STATE_CREATED} } func (c *Client) refreshToken() error { @@ -90,29 +96,44 @@ func (c *Client) refreshToken() error { return nil } +func (c *Client) setState(newState string) { + c.lock.Lock() + c.state = newState + c.lock.Unlock() +} + func (c *Client) listen() { defer func() { - ctx, cancel := context.WithCancel(context.Background()) - c.cancel = cancel - go func(ctx context.Context) { - select { - // Send disconnect message only if client doesn't reconnect within the specified timeframe. - case <-time.After(c.reconnectTimeout): - app := c.broker.getApp(c.appPath) - if app != nil { - app.forward(c.id, c.session, disconnectMsg) - if err := app.disconnect(c.id); err != nil { - echo(Log{"t": "disconnect", "client": c.addr, "route": c.appPath, "err": err.Error()}) - } - } + c.lock.Lock() + defer c.lock.Unlock() + if c.state != STATE_DISCONNECT { + return + } + // This defer runs to completion. If the client drops, reconnects and drops out again, ignore first drop timeout. + timeoutID := STATE_TIMEOUT + c.addr + c.state = timeoutID + c.lock.Unlock() - c.broker.unsubscribe <- c - case <-ctx.Done(): + select { + // Send disconnect message only if client doesn't reconnect within the specified timeframe. + case <-time.After(c.reconnectTimeout): + c.lock.Lock() + if c.state != timeoutID { return } - }(ctx) + app := c.broker.getApp(c.appPath) + if app != nil { + app.forward(c.id, c.session, disconnectMsg) + if err := app.disconnect(c.id); err != nil { + echo(Log{"t": "disconnect", "client": c.addr, "route": c.appPath, "err": err.Error()}) + } + } - c.conn.Close() + echo(Log{"t": "client_unsubscribe", "client": c.id}) + c.broker.unsubscribe <- c + c.state = STATE_CLOSED + return + } }() // Time allowed to read the next pong message from the peer. Must be greater than ping interval. pongWait := 10 * c.pingInterval / 9 @@ -127,10 +148,8 @@ func (c *Client) listen() { if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { echo(Log{"t": "socket_read", "client": c.addr, "err": err.Error()}) - } else { - // Firefox follows spec closely and requires a close message to be sent before closing the connection. - c.conn.WriteMessage(websocket.CloseMessage, []byte{}) } + c.setState(STATE_DISCONNECT) break } @@ -173,7 +192,10 @@ func (c *Client) listen() { c.broker.sendAll(c.broker.clients[app.route], clearStateMsg) } case watchMsgT: - if c.isReconnect { + c.lock.Lock() + state := c.state + c.lock.Unlock() + if state == STATE_RECONNECT { continue } c.subscribe(m.addr) // subscribe even if page is currently NA @@ -238,10 +260,13 @@ func (c *Client) flush() { defer func() { ticker.Stop() c.conn.Close() + c.lock.Unlock() }() for { select { case data, ok := <-c.data: + // An alternative to the mutex here would be a new channel for closing the connection so it does not race with reconnect. + c.lock.Lock() c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // broker closed the channel. @@ -265,11 +290,14 @@ func (c *Client) flush() { if err := w.Close(); err != nil { return } + c.lock.Unlock() case <-ticker.C: + c.lock.Lock() c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } + c.lock.Unlock() } } } diff --git a/conf.go b/conf.go index 0b5e2485dc..e2543df177 100644 --- a/conf.go +++ b/conf.go @@ -114,6 +114,6 @@ type Conf struct { SkipLogin bool `cfg:"oidc-skip-login" env:"H2O_WAVE_OIDC_SKIP_LOGIN" cfgDefault:"false" cfgHelper:"do not display the login form during OIDC authorization"` KeepAppLive bool `cfg:"keep-app-live" env:"H2O_WAVE_KEEP_APP_LIVE" cfgDefault:"false" cfgHelper:"do not unregister unresponsive apps"` Conf string `cfg:"conf" env:"H2O_WAVE_CONF" cfgDefault:".env" cfgHelper:"path to configuration file"` - ReconnectTimeout string `cfg:"reconnect-timeout" env:"H2O_WAVE_RECONNECT_TIMEOUT" cfgDefault:"2s" cfgHelper:"Time to wait for reconnect before dropping the client"` + ReconnectTimeout string `cfg:"reconnect-timeout" env:"H2O_WAVE_RECONNECT_TIMEOUT" cfgDefault:"5s" cfgHelper:"Time to wait for reconnect before dropping the client"` AllowedOrigins string `cfg:"allowed-origins" env:"H2O_WAVE_ALLOWED_ORIGINS" cfgDefault:"" cfgHelper:"comma-separated list of allowed origins (e.g. http://foo.com) for websocket upgrades"` } diff --git a/socket.go b/socket.go index fe99027da9..35b8fc265e 100644 --- a/socket.go +++ b/socket.go @@ -84,28 +84,29 @@ func (s *SocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } clientID := r.URL.Query().Get("client-id") - client, ok := s.broker.clientsByID[clientID] - if ok { + client := s.broker.getClient(clientID) + if client != nil { + client.lock.Lock() + // Close prev connection gracefully. + client.conn.WriteMessage(websocket.CloseMessage, []byte{}) + client.conn.Close() client.conn = conn - client.isReconnect = true - if client.cancel != nil { - client.cancel() - } - if s.broker.debug { - echo(Log{"t": "socket_reconnect", "client_id": clientID, "addr": getRemoteAddr(r)}) - } + client.state = STATE_RECONNECT + client.addr = getRemoteAddr(r) + client.lock.Unlock() + echo(Log{"t": "client_reconnect", "client_id": client.id, "addr": getRemoteAddr(r)}) } else { - client = newClient(getRemoteAddr(r), s.auth, session, s.broker, conn, s.editable, s.baseURL, &header, s.pingInterval, false, s.reconnectTimeout) - } + client = newClient(getRemoteAddr(r), s.auth, session, s.broker, conn, s.editable, s.baseURL, &header, s.pingInterval, s.reconnectTimeout) - if msg, err := json.Marshal(OpsD{I: client.id}); err == nil { - sw, err := conn.NextWriter(websocket.TextMessage) + helloMsg, err := json.Marshal(OpsD{I: client.id}) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - sw.Write(msg) - sw.Close() + if !client.send(helloMsg) { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } go client.flush() diff --git a/ui/src/core.ts b/ui/src/core.ts index c07883f46a..2dccc6ef50 100644 --- a/ui/src/core.ts +++ b/ui/src/core.ts @@ -946,12 +946,11 @@ export const const slug = window.location.pathname, reconnect = (address: S) => { - if (_clientID && !address.includes('?client-id')) { - address = `${address}?${new URLSearchParams({ 'client-id': _clientID })}` - } + let wsAddr = address + if (_clientID) wsAddr = `${address}?${new URLSearchParams({ 'client-id': _clientID })}` const retry = () => reconnect(address) - const socket = new WebSocket(address) + const socket = new WebSocket(wsAddr) socket.onopen = () => { _reconnectFailures = 0 _socket = socket diff --git a/website/docs/routing.md b/website/docs/routing.md index ce1bd8bd9b..1fdd0d17a9 100644 --- a/website/docs/routing.md +++ b/website/docs/routing.md @@ -345,7 +345,7 @@ Note that when a user logs out of the Wave daemon, all the apps linked to the da ### Handling client (browser tab) disconnect -To get notified when a user closes the tab, use the system-wide `@system.client_disconnect` event. +To get notified when a user closes the tab, use the system-wide `@system.client_disconnect` event. The time if takes for this function to be called depends on the value of `H2O_WAVE_RECONNECT_TIMEOUT` (which defaults to `5s`). ```py @on('@system.client_disconnect')