diff --git a/crates/fs/src/fs_watcher.rs b/crates/fs/src/fs_watcher.rs index fec8b03bfe1..ef676e9ed02 100644 --- a/crates/fs/src/fs_watcher.rs +++ b/crates/fs/src/fs_watcher.rs @@ -227,20 +227,61 @@ struct WatcherRegistrationState { mode: WatcherMode, } +struct PathRegistrationState { + count: u32, + has_os_watcher: bool, +} + struct WatcherState { watchers: HashMap, - native_path_registrations: HashMap, u32>, - poll_path_registrations: HashMap, u32>, + native_path_registrations: HashMap, PathRegistrationState>, + poll_path_registrations: HashMap, PathRegistrationState>, last_registration: WatcherRegistrationId, } impl WatcherState { - fn path_registrations(&mut self, mode: WatcherMode) -> &mut HashMap, u32> { + fn path_registrations( + &mut self, + mode: WatcherMode, + ) -> &mut HashMap, PathRegistrationState> { match mode { WatcherMode::Native => &mut self.native_path_registrations, WatcherMode::Poll => &mut self.poll_path_registrations, } } + + fn remove_registration( + &mut self, + id: WatcherRegistrationId, + ) -> Option<(Arc, WatcherMode)> { + let registration_state = self.watchers.remove(&id)?; + let path_registrations = self.path_registrations(registration_state.mode); + let count = path_registrations.get_mut(®istration_state.path)?; + count.count -= 1; + if count.count != 0 { + return None; + } + + let was_actually_watched = count.has_os_watcher; + path_registrations.remove(®istration_state.path); + + was_actually_watched.then_some((registration_state.path, registration_state.mode)) + } +} + +trait WatchBackend: Send { + fn watch(&mut self, path: &Path, mode: notify::RecursiveMode) -> notify::Result<()>; + fn unwatch(&mut self, path: &Path) -> notify::Result<()>; +} + +impl WatchBackend for T { + fn watch(&mut self, path: &Path, mode: notify::RecursiveMode) -> notify::Result<()> { + notify::Watcher::watch(self, path, mode) + } + + fn unwatch(&mut self, path: &Path) -> notify::Result<()> { + notify::Watcher::unwatch(self, path) + } } pub struct GlobalWatcher { @@ -248,8 +289,8 @@ pub struct GlobalWatcher { // DANGER: never keep state lock while holding watcher lock // two mutexes because calling watcher.add triggers watcher.event, which needs watchers. - native_watcher: Mutex>, - poll_watcher: Mutex>, + native_watcher: Mutex>>, + poll_watcher: Mutex>>, } impl GlobalWatcher { @@ -280,41 +321,28 @@ impl GlobalWatcher { mode, }; state.watchers.insert(id, registration_state); - *state.path_registrations(mode).entry(path).or_insert(0) += 1; + state + .path_registrations(mode) + .entry(path) + .and_modify(|registration| registration.count += 1) + .or_insert(PathRegistrationState { + count: 1, + has_os_watcher: !path_already_covered, + }); Ok(id) } pub fn remove(&self, id: WatcherRegistrationId) { let mut state = self.state.lock(); - let Some(registration_state) = state.watchers.remove(&id) else { + let Some((path, mode)) = state.remove_registration(id) else { return; }; - - let path_registrations = state.path_registrations(registration_state.mode); - let Some(count) = path_registrations.get_mut(®istration_state.path) else { - return; - }; - *count -= 1; - if *count == 0 { - path_registrations.remove(®istration_state.path); - let path_still_covered = path_already_covered( - registration_state.path.as_ref(), - path_registrations, - registration_state.mode, - ); - - if !path_still_covered { - drop(state); - self.unwatch(®istration_state.path, registration_state.mode) - .log_err(); - } - } + drop(state); + self.unwatch(&path, mode).log_err(); } fn watch(&self, path: &Path, mode: WatcherMode) -> anyhow::Result<()> { - use notify::Watcher; - match mode { WatcherMode::Native => { self.ensure_native_watcher()?; @@ -345,8 +373,6 @@ impl GlobalWatcher { } fn unwatch(&self, path: &Path, mode: WatcherMode) -> anyhow::Result<()> { - use notify::Watcher; - match mode { WatcherMode::Native => { if let Some(watcher) = self.native_watcher.lock().as_mut() { @@ -369,7 +395,7 @@ impl GlobalWatcher { } let watcher = notify::recommended_watcher(handle_native_event)?; - *self.native_watcher.lock() = Some(watcher); + *self.native_watcher.lock() = Some(Box::new(watcher)); Ok(()) } @@ -380,14 +406,14 @@ impl GlobalWatcher { let config = notify::Config::default().with_poll_interval(*POLL_INTERVAL); let watcher = notify::PollWatcher::new(handle_poll_event, config)?; - *self.poll_watcher.lock() = Some(watcher); + *self.poll_watcher.lock() = Some(Box::new(watcher)); Ok(()) } } fn path_already_covered( path: &Path, - path_registrations: &HashMap, u32>, + path_registrations: &HashMap, PathRegistrationState>, mode: WatcherMode, ) -> bool { (mode == WatcherMode::Poll || cfg!(any(target_os = "windows", target_os = "macos"))) @@ -470,7 +496,7 @@ fn handle_event(mode: WatcherMode, event: Result) #[cfg(test)] mod tests { use super::*; - use std::path::PathBuf; + use std::{collections::HashSet, path::PathBuf}; fn rescan(path: &str) -> PathEvent { PathEvent { @@ -486,6 +512,49 @@ mod tests { } } + #[derive(Default)] + struct FakeWatchBackend { + watched_paths: HashSet, + watch_calls: Vec, + unwatch_calls: Vec, + } + + struct SharedFakeWatchBackend(Arc>); + + impl WatchBackend for SharedFakeWatchBackend { + fn watch(&mut self, path: &Path, _mode: notify::RecursiveMode) -> notify::Result<()> { + let path = path.to_path_buf(); + let mut backend = self.0.lock(); + backend.watch_calls.push(path.clone()); + backend.watched_paths.insert(path); + Ok(()) + } + + fn unwatch(&mut self, path: &Path) -> notify::Result<()> { + let path = path.to_path_buf(); + let mut backend = self.0.lock(); + backend.unwatch_calls.push(path.clone()); + if backend.watched_paths.remove(&path) { + Ok(()) + } else { + Err(notify::Error::generic("path was not watched")) + } + } + } + + fn test_watcher(poll_watcher: Arc>) -> GlobalWatcher { + GlobalWatcher { + state: Mutex::new(WatcherState { + watchers: Default::default(), + native_path_registrations: Default::default(), + poll_path_registrations: Default::default(), + last_registration: Default::default(), + }), + native_watcher: Mutex::new(None), + poll_watcher: Mutex::new(Some(Box::new(SharedFakeWatchBackend(poll_watcher)))), + } + } + struct TestCase { name: &'static str, pending_paths: Vec, @@ -494,6 +563,28 @@ mod tests { expected_path_events: Vec, } + #[test] + fn covered_child_registration_is_not_unwatched_after_parent_is_removed() { + let backend = Arc::new(Mutex::new(FakeWatchBackend::default())); + let watcher = test_watcher(backend.clone()); + let parent = Arc::::from(Path::new("/repo")); + let child = Arc::::from(Path::new("/repo/foo.csproj")); + + let parent_registration = watcher + .add(parent.as_ref().into(), WatcherMode::Poll, |_| {}) + .expect("add parent watch"); + let child_registration = watcher + .add(child.as_ref().into(), WatcherMode::Poll, |_| {}) + .expect("add covered child watch"); + + watcher.remove(parent_registration); + watcher.remove(child_registration); + + let backend = backend.lock(); + assert_eq!(backend.watch_calls, &[parent.to_path_buf()]); + assert_eq!(backend.unwatch_calls, &[parent.to_path_buf()]); + } + #[test] fn test_coalesce_pending_rescans() { let test_cases = [