diff --git a/internal/api/room/control.go b/internal/api/room/control.go index 7395db33..ad6d27b3 100644 --- a/internal/api/room/control.go +++ b/internal/api/room/control.go @@ -19,23 +19,22 @@ type ControlTargetPayload struct { } func (h *RoomHandler) controlStatus(w http.ResponseWriter, r *http.Request) error { - host := h.sessions.GetHost() + host, hasHost := h.sessions.GetHost() - if host != nil { - return utils.HttpSuccess(w, ControlStatusPayload{ - HasHost: true, - HostId: host.ID(), - }) + var hostId string + if hasHost { + hostId = host.ID() } return utils.HttpSuccess(w, ControlStatusPayload{ - HasHost: false, + HasHost: hasHost, + HostId: hostId, }) } func (h *RoomHandler) controlRequest(w http.ResponseWriter, r *http.Request) error { - host := h.sessions.GetHost() - if host != nil { + _, hasHost := h.sessions.GetHost() + if hasHost { return utils.HttpUnprocessableEntity("there is already a host") } @@ -82,9 +81,9 @@ func (h *RoomHandler) controlGive(w http.ResponseWriter, r *http.Request) error } func (h *RoomHandler) controlReset(w http.ResponseWriter, r *http.Request) error { - host := h.sessions.GetHost() + _, hasHost := h.sessions.GetHost() - if host != nil { + if hasHost { h.desktop.ResetKeys() h.sessions.ClearHost() } diff --git a/internal/http/logger.go b/internal/http/logger.go index 3cdf5922..258eec54 100644 --- a/internal/http/logger.go +++ b/internal/http/logger.go @@ -49,7 +49,7 @@ type logEntry struct { logger zerolog.Logger err error panic *logPanic - session *types.Session + session types.Session } type logPanic struct { @@ -69,7 +69,7 @@ func (e *logEntry) Error(err error) { } func (e *logEntry) SetSession(session types.Session) { - e.session = &session + e.session = session } func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra any) { @@ -83,7 +83,7 @@ func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Dur // add session ID to logs (if exists) if e.session != nil { - logger = logger.With().Str("session_id", (*e.session).ID()).Logger() + logger = logger.With().Str("session_id", e.session.ID()).Logger() } // handle panic error message diff --git a/internal/session/manager.go b/internal/session/manager.go index ee82e7a3..6498d18d 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -3,6 +3,7 @@ package session import ( "errors" "sync" + "sync/atomic" "github.com/kataras/go-events" "github.com/rs/zerolog" @@ -62,8 +63,7 @@ type SessionManagerCtx struct { sessions map[string]*SessionCtx sessionsMu sync.Mutex - host types.Session - hostMu sync.Mutex + hostId atomic.Value cursors map[types.Session][]types.Cursor cursorsMu sync.Mutex @@ -188,24 +188,33 @@ func (manager *SessionManagerCtx) List() []types.Session { // --- func (manager *SessionManagerCtx) SetHost(host types.Session) { - manager.hostMu.Lock() - manager.host = host - manager.hostMu.Unlock() + var hostId string + if host != nil { + hostId = host.ID() + } + manager.hostId.Store(hostId) manager.emmiter.Emit("host_changed", host) } -func (manager *SessionManagerCtx) GetHost() types.Session { - manager.hostMu.Lock() - defer manager.hostMu.Unlock() +func (manager *SessionManagerCtx) GetHost() (types.Session, bool) { + hostId, ok := manager.hostId.Load().(string) + if !ok || hostId == "" { + return nil, false + } - return manager.host + return manager.Get(hostId) } func (manager *SessionManagerCtx) ClearHost() { manager.SetHost(nil) } +func (manager *SessionManagerCtx) isHost(host types.Session) bool { + hostId, ok := manager.hostId.Load().(string) + return ok && hostId == host.ID() +} + // --- // cursors // --- diff --git a/internal/session/session.go b/internal/session/session.go index 43af7173..ec1dff77 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -2,6 +2,7 @@ package session import ( "sync" + "time" "github.com/rs/zerolog" @@ -9,6 +10,10 @@ import ( "github.com/demodesk/neko/pkg/types/event" ) +// client is expected to reconnect within 5 second +// if some unexpected websocket disconnect happens +const WS_DELAYED_DURATION = 5 * time.Second + type SessionCtx struct { id string token string @@ -20,6 +25,10 @@ type SessionCtx struct { websocketPeer types.WebSocketPeer websocketMu sync.Mutex + // websocket delayed set connected events + wsDelayedMu sync.Mutex + wsDelayedTimer *time.Timer + webrtcPeer types.WebRTCPeer webrtcMu sync.Mutex } @@ -56,7 +65,7 @@ func (session *SessionCtx) State() types.SessionState { } func (session *SessionCtx) IsHost() bool { - return session.manager.GetHost() == session + return session.manager.isHost(session) } func (session *SessionCtx) PrivateModeEnabled() bool { @@ -83,7 +92,7 @@ func (session *SessionCtx) SetWebSocketPeer(websocketPeer types.WebSocketPeer) { } } -func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPeer, connected bool) { +func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPeer, connected bool, delayed bool) { session.websocketMu.Lock() isCurrentPeer := websocketPeer == session.websocketPeer session.websocketMu.Unlock() @@ -94,8 +103,36 @@ func (session *SessionCtx) SetWebSocketConnected(websocketPeer types.WebSocketPe session.logger.Info(). Bool("connected", connected). + Bool("delayed", delayed). Msg("set websocket connected") + // + // ws delayed + // + + var wsDelayedTimer *time.Timer + + if delayed { + wsDelayedTimer = time.AfterFunc(WS_DELAYED_DURATION, func() { + session.SetWebSocketConnected(websocketPeer, connected, false) + }) + } + + session.wsDelayedMu.Lock() + if session.wsDelayedTimer != nil { + session.wsDelayedTimer.Stop() + } + session.wsDelayedTimer = wsDelayedTimer + session.wsDelayedMu.Unlock() + + if delayed { + return + } + + // + // not delayed + // + session.state.IsConnected = connected if connected { diff --git a/internal/websocket/filechooserdialog.go b/internal/websocket/filechooserdialog.go index a2fab187..e134930e 100644 --- a/internal/websocket/filechooserdialog.go +++ b/internal/websocket/filechooserdialog.go @@ -13,8 +13,8 @@ func (manager *WebSocketManagerCtx) fileChooserDialogEvents() { manager.desktop.OnFileChooserDialogOpened(func() { manager.logger.Info().Msg("file chooser dialog opened") - host := manager.sessions.GetHost() - if host == nil { + host, hasHost := manager.sessions.GetHost() + if !hasHost { manager.logger.Warn().Msg("no host for file chooser dialog found, closing") go manager.desktop.CloseFileChooserDialog() return diff --git a/internal/websocket/handler/control.go b/internal/websocket/handler/control.go index 67cfcfa4..9a21cd1d 100644 --- a/internal/websocket/handler/control.go +++ b/internal/websocket/handler/control.go @@ -42,7 +42,7 @@ func (h *MessageHandlerCtx) controlRequest(session types.Session) error { if !h.sessions.Settings().ImplicitHosting { // tell session if there is a host - if host := h.sessions.GetHost(); host != nil { + if host, hasHost := h.sessions.GetHost(); hasHost { session.Send( event.CONTROL_HOST, message.ControlHost{ diff --git a/internal/websocket/handler/system.go b/internal/websocket/handler/system.go index 6d0f2f66..446c4330 100644 --- a/internal/websocket/handler/system.go +++ b/internal/websocket/handler/system.go @@ -12,14 +12,16 @@ import ( ) func (h *MessageHandlerCtx) systemInit(session types.Session) error { - host := h.sessions.GetHost() + host, hasHost := h.sessions.GetHost() - controlHost := message.ControlHost{ - HasHost: host != nil, + var hostID string + if hasHost { + hostID = host.ID() } - if controlHost.HasHost { - controlHost.HostID = host.ID() + controlHost := message.ControlHost{ + HasHost: hasHost, + HostID: hostID, } size := h.desktop.GetScreenSize() diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index b5eb0ad2..ac9b3656 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -132,8 +132,8 @@ func (manager *WebSocketManagerCtx) Start() { }) manager.desktop.OnClipboardUpdated(func() { - session := manager.sessions.GetHost() - if session == nil || !session.Profile().CanAccessClipboard { + host, hasHost := manager.sessions.GetHost() + if !hasHost || !host.Profile().CanAccessClipboard { return } @@ -145,7 +145,7 @@ func (manager *WebSocketManagerCtx) Start() { return } - session.Send( + host.Send( event.CLIPBOARD_UPDATED, message.ClipboardData{ Text: data.Text, @@ -232,26 +232,47 @@ func (manager *WebSocketManagerCtx) connect(connection *websocket.Conn, r *http. Str("agent", r.UserAgent()). Msg("connection started") - session.SetWebSocketConnected(peer, true) + session.SetWebSocketConnected(peer, true, false) - defer func() { - logger.Info(). - Str("address", connection.RemoteAddr().String()). - Str("agent", r.UserAgent()). - Msg("connection ended") + // this is a blocking function that lives + // throughout whole websocket connection + err = manager.handle(connection, peer, session) - session.SetWebSocketConnected(peer, false) - }() + logger.Info(). + Str("address", connection.RemoteAddr().String()). + Str("agent", r.UserAgent()). + Msg("connection ended") - manager.handle(connection, peer, session) + delayedDisconnect := false + + e, ok := err.(*websocket.CloseError) + if !ok { + logger.Err(err).Msg("read message error") + // client is expected to reconnect soon + delayedDisconnect = true + } else { + switch e.Code { + case websocket.CloseNormalClosure: + logger.Info().Str("reason", e.Text).Msg("websocket close") + case websocket.CloseGoingAway: + logger.Info().Str("reason", "going away").Msg("websocket close") + default: + logger.Warn().Err(err).Msg("websocket close") + // abnormal websocket closure: + // client is expected to reconnect soon + delayedDisconnect = true + } + } + + session.SetWebSocketConnected(peer, false, delayedDisconnect) } -func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, peer types.WebSocketPeer, session types.Session) { +func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, peer types.WebSocketPeer, session types.Session) error { // add session id to logger context logger := manager.logger.With().Str("session_id", session.ID()).Logger() bytes := make(chan []byte) - cancel := make(chan struct{}) + cancel := make(chan error) ticker := time.NewTicker(pingPeriod) defer ticker.Stop() @@ -263,13 +284,7 @@ func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, peer type for { _, raw, err := connection.ReadMessage() if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - logger.Warn().Err(err).Msg("read message error") - } else { - logger.Debug().Err(err).Msg("read message error") - } - - close(cancel) + cancel <- err break } @@ -306,15 +321,14 @@ func (manager *WebSocketManagerCtx) handle(connection *websocket.Conn, peer type if !handled { logger.Warn().Str("event", data.Event).Msg("unhandled message") } - case <-cancel: - return + case err := <-cancel: + return err case <-manager.shutdown: peer.Destroy("connection shutdown") - return + return nil case <-ticker.C: if err := peer.Ping(); err != nil { - logger.Err(err).Msg("ping message has failed") - return + return err } } } diff --git a/pkg/types/session.go b/pkg/types/session.go index 10e93244..97b8183c 100644 --- a/pkg/types/session.go +++ b/pkg/types/session.go @@ -41,7 +41,7 @@ type Session interface { // websocket SetWebSocketPeer(websocketPeer WebSocketPeer) - SetWebSocketConnected(websocketPeer WebSocketPeer, connected bool) + SetWebSocketConnected(websocketPeer WebSocketPeer, connected bool, delayed bool) GetWebSocketPeer() WebSocketPeer Send(event string, payload any) @@ -60,7 +60,7 @@ type SessionManager interface { List() []Session SetHost(host Session) - GetHost() Session + GetHost() (Session, bool) ClearHost() SetCursor(cursor Cursor, session Session)