mod db; mod legacy_thread; mod native_agent_server; pub mod outline; mod pattern_extraction; mod templates; #[cfg(test)] mod tests; mod thread; mod thread_store; mod tool_permissions; mod tools; use context_server::ContextServerId; pub use db::*; use itertools::Itertools; pub use native_agent_server::NativeAgentServer; pub use pattern_extraction::*; pub use shell_command_parser::extract_commands; pub use templates::*; pub use thread::*; pub use thread_store::*; pub use tool_permissions::*; pub use tools::*; use acp_thread::{ AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest, AgentSessionListResponse, TokenUsageRatio, UserMessageId, }; use agent_client_protocol::schema as acp; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use collections::{HashMap, HashSet, IndexMap}; use fs::Fs; use futures::channel::{mpsc, oneshot}; use futures::future::Shared; use futures::{FutureExt as _, StreamExt as _, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task, WeakEntity, }; use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry}; use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext, WorktreeContext, }; use serde::{Deserialize, Serialize}; use settings::{LanguageModelSelection, Settings as _, update_settings_file}; use std::any::Any; use std::path::PathBuf; use std::rc::Rc; use std::sync::{Arc, LazyLock}; use util::ResultExt; use util::path_list::PathList; use util::rel_path::RelPath; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ProjectSnapshot { pub worktree_snapshots: Vec, pub timestamp: DateTime, } pub struct RulesLoadingError { pub message: SharedString, } struct ProjectState { project: Entity, project_context: Entity, project_context_needs_refresh: watch::Sender<()>, 
_maintain_project_context: Task>, context_server_registry: Entity, _subscriptions: Vec, } /// Holds both the internal Thread and the AcpThread for a session struct Session { /// The internal thread that processes messages thread: Entity, /// The ACP thread that handles protocol communication acp_thread: Entity, project_id: EntityId, pending_save: Task>, _subscriptions: Vec, ref_count: usize, } struct PendingSession { task: Shared, Arc>>>, ref_count: usize, } pub struct LanguageModels { /// Access language model by ID models: HashMap>, /// Cached list for returning language model information model_list: acp_thread::AgentModelList, refresh_models_rx: watch::Receiver<()>, refresh_models_tx: watch::Sender<()>, _authenticate_all_providers_task: Task<()>, } impl LanguageModels { fn new(cx: &mut App) -> Self { let (refresh_models_tx, refresh_models_rx) = watch::channel(()); let mut this = Self { models: HashMap::default(), model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()), refresh_models_rx, refresh_models_tx, _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx), }; this.refresh_list(cx); this } fn refresh_list(&mut self, cx: &App) { let providers = LanguageModelRegistry::global(cx) .read(cx) .visible_providers() .into_iter() .filter(|provider| provider.is_authenticated(cx)) .collect::>(); let mut language_model_list = IndexMap::default(); let mut recommended_models = HashSet::default(); let mut recommended = Vec::new(); for provider in &providers { for model in provider.recommended_models(cx) { recommended_models.insert((model.provider_id(), model.id())); recommended.push(Self::map_language_model_to_info(&model, provider)); } } if !recommended.is_empty() { language_model_list.insert( acp_thread::AgentModelGroupName("Recommended".into()), recommended, ); } let mut models = HashMap::default(); for provider in providers { let mut provider_models = Vec::new(); for model in provider.provided_models(cx) { let model_info = 
Self::map_language_model_to_info(&model, &provider); let model_id = model_info.id.clone(); provider_models.push(model_info); models.insert(model_id, model); } if !provider_models.is_empty() { language_model_list.insert( acp_thread::AgentModelGroupName(provider.name().0.clone()), provider_models, ); } } self.models = models; self.model_list = acp_thread::AgentModelList::Grouped(language_model_list); self.refresh_models_tx.send(()).ok(); } fn watch(&self) -> watch::Receiver<()> { self.refresh_models_rx.clone() } pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option> { self.models.get(model_id).cloned() } fn map_language_model_to_info( model: &Arc, provider: &Arc, ) -> acp_thread::AgentModelInfo { acp_thread::AgentModelInfo { id: Self::model_id(model), name: model.name().0, description: None, icon: Some(match provider.icon() { IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path), IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name), }), is_latest: model.is_latest(), cost: model.model_cost_info().map(|cost| cost.to_shared_string()), } } fn model_id(model: &Arc) -> acp::ModelId { acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0)) } fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> { let authenticate_all_providers = LanguageModelRegistry::global(cx) .read(cx) .visible_providers() .iter() .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) .collect::>(); cx.spawn(async move |cx| { for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { if let Err(err) = authenticate_task.await { match err { language_model::AuthenticateError::CredentialsNotFound => { // Since we're authenticating these providers in the // background for the purposes of populating the // language selector, we don't care about providers // where the credentials are not found. 
} language_model::AuthenticateError::ConnectionRefused => { // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures. // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it. // TODO: Better manage LM Studio auth logic to avoid these noisy failures. } _ => { // Some providers have noisy failure states that we // don't want to spam the logs with every time the // language model selector is initialized. // // Ideally these should have more clear failure modes // that we know are safe to ignore here, like what we do // with `CredentialsNotFound` above. match provider_id.0.as_ref() { "lmstudio" | "ollama" => { // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated". // // These fail noisily, so we don't log them. } "copilot_chat" => { // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors. } _ => { log::error!( "Failed to authenticate provider: {}: {err:#}", provider_name.0 ); } } } } } } cx.update(language_models::update_environment_fallback_model); }) } } pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, pending_sessions: HashMap, thread_store: Entity, /// Project-specific state keyed by project EntityId projects: HashMap, /// Shared templates for all threads templates: Arc, /// Cached model information models: LanguageModels, prompt_store: Option>, fs: Arc, _subscriptions: Vec, } impl NativeAgent { pub fn new( thread_store: Entity, templates: Arc, prompt_store: Option>, fs: Arc, cx: &mut App, ) -> Entity { log::debug!("Creating new NativeAgent"); cx.new(|cx| { let mut subscriptions = vec![cx.subscribe( &LanguageModelRegistry::global(cx), Self::handle_models_updated_event, )]; if let Some(prompt_store) = prompt_store.as_ref() { subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) } Self { sessions: HashMap::default(), pending_sessions: 
HashMap::default(), thread_store, projects: HashMap::default(), templates, models: LanguageModels::new(cx), prompt_store, fs, _subscriptions: subscriptions, } }) } fn new_session( &mut self, project: Entity, cx: &mut Context, ) -> Entity { let project_id = self.get_or_create_project_state(&project, cx); let project_state = &self.projects[&project_id]; let registry = LanguageModelRegistry::read_global(cx); let available_count = registry.available_models(cx).count(); log::debug!("Total available models: {}", available_count); let default_model = registry.default_model().and_then(|default_model| { self.models .model_from_id(&LanguageModels::model_id(&default_model.model)) }); let thread = cx.new(|cx| { Thread::new( project, project_state.project_context.clone(), project_state.context_server_registry.clone(), self.templates.clone(), default_model, cx, ) }); self.register_session(thread, project_id, 1, cx) } fn register_session( &mut self, thread_handle: Entity, project_id: EntityId, ref_count: usize, cx: &mut Context, ) -> Entity { let connection = Rc::new(NativeAgentConnection(cx.entity())); let thread = thread_handle.read(cx); let session_id = thread.id().clone(); let parent_session_id = thread.parent_thread_id(); let title = thread.title(); let draft_prompt = thread.draft_prompt().map(Vec::from); let scroll_position = thread.ui_scroll_position(); let token_usage = thread.latest_token_usage(); let project = thread.project.clone(); let action_log = thread.action_log.clone(); let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone(); let acp_thread = cx.new(|cx| { let mut acp_thread = acp_thread::AcpThread::new( parent_session_id, title, None, connection, project.clone(), action_log.clone(), session_id.clone(), prompt_capabilities_rx, cx, ); acp_thread.set_draft_prompt(draft_prompt, cx); acp_thread.set_ui_scroll_position(scroll_position); acp_thread.update_token_usage(token_usage, cx); acp_thread }); let registry = LanguageModelRegistry::read_global(cx); let 
summarization_model = registry.thread_summary_model(cx).map(|c| c.model); let weak = cx.weak_entity(); let weak_thread = thread_handle.downgrade(); thread_handle.update(cx, |thread, cx| { thread.set_summarization_model(summarization_model, cx); thread.add_default_tools( Rc::new(NativeThreadEnvironment { acp_thread: acp_thread.downgrade(), thread: weak_thread, agent: weak, }) as _, cx, ) }); let subscriptions = vec![ cx.subscribe(&thread_handle, Self::handle_thread_title_updated), cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated), cx.observe(&thread_handle, move |this, thread, cx| { this.save_thread(thread, cx) }), ]; self.sessions.insert( session_id, Session { thread: thread_handle, acp_thread: acp_thread.clone(), project_id, _subscriptions: subscriptions, pending_save: Task::ready(Ok(())), ref_count, }, ); self.update_available_commands_for_project(project_id, cx); acp_thread } pub fn models(&self) -> &LanguageModels { &self.models } fn get_or_create_project_state( &mut self, project: &Entity, cx: &mut Context, ) -> EntityId { let project_id = project.entity_id(); if self.projects.contains_key(&project_id) { return project_id; } let project_context = cx.new(|_| ProjectContext::new(vec![], vec![])); self.register_project_with_initial_context(project.clone(), project_context, cx); if let Some(state) = self.projects.get_mut(&project_id) { state.project_context_needs_refresh.send(()).ok(); } project_id } fn register_project_with_initial_context( &mut self, project: Entity, project_context: Entity, cx: &mut Context, ) { let project_id = project.entity_id(); let context_server_store = project.read(cx).context_server_store(); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); let subscriptions = vec![ cx.subscribe(&project, Self::handle_project_event), cx.subscribe( &context_server_store, Self::handle_context_server_store_updated, ), cx.subscribe( &context_server_registry, 
Self::handle_context_server_registry_event, ), ]; let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = watch::channel(()); self.projects.insert( project_id, ProjectState { project, project_context, project_context_needs_refresh: project_context_needs_refresh_tx, _maintain_project_context: cx.spawn(async move |this, cx| { Self::maintain_project_context( this, project_id, project_context_needs_refresh_rx, cx, ) .await }), context_server_registry, _subscriptions: subscriptions, }, ); } fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> { self.sessions .get(session_id) .and_then(|session| self.projects.get(&session.project_id)) } async fn maintain_project_context( this: WeakEntity, project_id: EntityId, mut needs_refresh: watch::Receiver<()>, cx: &mut AsyncApp, ) -> Result<()> { while needs_refresh.changed().await.is_ok() { let project_context = this .update(cx, |this, cx| { let state = this .projects .get(&project_id) .context("project state not found")?; anyhow::Ok(Self::build_project_context( &state.project, this.prompt_store.as_ref(), cx, )) })?? 
.await; this.update(cx, |this, cx| { if let Some(state) = this.projects.get(&project_id) { state .project_context .update(cx, |current_project_context, _cx| { *current_project_context = project_context; }); } })?; } Ok(()) } fn build_project_context( project: &Entity, prompt_store: Option<&Entity>, cx: &mut App, ) -> Task { let worktrees = project.read(cx).visible_worktrees(cx).collect::>(); let worktree_tasks = worktrees .into_iter() .map(|worktree| { Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx) }) .collect::>(); let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() { prompt_store.read_with(cx, |prompt_store, cx| { let prompts = prompt_store.default_prompt_metadata(); let load_tasks = prompts.into_iter().map(|prompt_metadata| { let contents = prompt_store.load(prompt_metadata.id, cx); async move { (contents.await, prompt_metadata) } }); cx.background_spawn(future::join_all(load_tasks)) }) } else { Task::ready(vec![]) }; cx.spawn(async move |_cx| { let (worktrees, default_user_rules) = future::join(future::join_all(worktree_tasks), default_user_rules_task).await; let worktrees = worktrees .into_iter() .map(|(worktree, _rules_error)| { // TODO: show error message // if let Some(rules_error) = rules_error { // this.update(cx, |_, cx| cx.emit(rules_error)).ok(); // } worktree }) .collect::>(); let default_user_rules = default_user_rules .into_iter() .flat_map(|(contents, prompt_metadata)| match contents { Ok(contents) => Some(UserRulesContext { uuid: prompt_metadata.id.as_user()?, title: prompt_metadata.title.map(|title| title.to_string()), contents, }), Err(_err) => { // TODO: show error message // this.update(cx, |_, cx| { // cx.emit(RulesLoadingError { // message: format!("{err:?}").into(), // }); // }) // .ok(); None } }) .collect::>(); ProjectContext::new(worktrees, default_user_rules) }) } fn load_worktree_info_for_system_prompt( worktree: Entity, project: Entity, cx: &mut App, ) -> Task<(WorktreeContext, 
Option)> { let tree = worktree.read(cx); let root_name = tree.root_name_str().into(); let abs_path = tree.abs_path(); let scan_complete = tree.as_local().map(|local| local.scan_complete()); let mut context = WorktreeContext { root_name, abs_path, rules_file: None, }; cx.spawn(async move |cx| { if let Some(scan_complete) = scan_complete { scan_complete.await; } let rules_task = cx.update(|cx| Self::load_worktree_rules_file(worktree, project, cx)); let (rules_file, rules_file_error) = match rules_task { Some(rules_task) => match rules_task.await { Ok(rules_file) => (Some(rules_file), None), Err(err) => ( None, Some(RulesLoadingError { message: format!("{err}").into(), }), ), }, None => (None, None), }; context.rules_file = rules_file; (context, rules_file_error) }) } fn load_worktree_rules_file( worktree: Entity, project: Entity, cx: &mut App, ) -> Option>> { let worktree = worktree.read(cx); let worktree_id = worktree.id(); let selected_rules_file = RULES_FILE_NAMES .into_iter() .filter_map(|name| { worktree .entry_for_path(RelPath::unix(name).unwrap()) .filter(|entry| entry.is_file()) .map(|entry| entry.path.clone()) }) .next(); // Note that Cline supports `.clinerules` being a directory, but that is not currently // supported. This doesn't seem to occur often in GitHub repositories. selected_rules_file.map(|path_in_worktree| { let project_path = ProjectPath { worktree_id, path: path_in_worktree.clone(), }; let buffer_task = project.update(cx, |project, cx| project.open_buffer(project_path, cx)); let rope_task = cx.spawn(async move |cx| { let buffer = buffer_task.await?; let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| { let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?; anyhow::Ok((project_entry_id, buffer.as_rope().clone())) })?; anyhow::Ok((project_entry_id, rope)) }); // Build a string from the rope on a background thread. 
cx.background_spawn(async move { let (project_entry_id, rope) = rope_task.await?; anyhow::Ok(RulesFileContext { path_in_worktree, text: rope.to_string().trim().to_string(), project_entry_id: project_entry_id.to_usize(), }) }) }) } fn handle_thread_title_updated( &mut self, thread: Entity, _: &TitleUpdated, cx: &mut Context, ) { let session_id = thread.read(cx).id(); let Some(session) = self.sessions.get(session_id) else { return; }; let thread = thread.downgrade(); let acp_thread = session.acp_thread.downgrade(); cx.spawn(async move |_, cx| { let title = thread.read_with(cx, |thread, _| thread.title())?; if let Some(title) = title { let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?; task.await?; } anyhow::Ok(()) }) .detach_and_log_err(cx); } fn handle_thread_token_usage_updated( &mut self, thread: Entity, usage: &TokenUsageUpdated, cx: &mut Context, ) { let Some(session) = self.sessions.get(thread.read(cx).id()) else { return; }; session.acp_thread.update(cx, |acp_thread, cx| { acp_thread.update_token_usage(usage.0.clone(), cx); }); } fn handle_project_event( &mut self, project: Entity, event: &project::Event, _cx: &mut Context, ) { let project_id = project.entity_id(); let Some(state) = self.projects.get_mut(&project_id) else { return; }; match event { project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { state.project_context_needs_refresh.send(()).ok(); } project::Event::WorktreeUpdatedEntries(_, items) => { if items.iter().any(|(path, _, _)| { RULES_FILE_NAMES .iter() .any(|name| path.as_ref() == RelPath::unix(name).unwrap()) }) { state.project_context_needs_refresh.send(()).ok(); } } _ => {} } } fn handle_prompts_updated_event( &mut self, _prompt_store: Entity, _event: &prompt_store::PromptsUpdatedEvent, _cx: &mut Context, ) { for state in self.projects.values_mut() { state.project_context_needs_refresh.send(()).ok(); } } fn handle_models_updated_event( &mut self, _registry: Entity, event: 
&language_model::Event, cx: &mut Context, ) { self.models.refresh_list(cx); let registry = LanguageModelRegistry::read_global(cx); let default_model = registry.default_model().map(|m| m.model); let summarization_model = registry.thread_summary_model(cx).map(|m| m.model); for session in self.sessions.values_mut() { session.thread.update(cx, |thread, cx| { if thread.model().is_none() && let Some(model) = default_model.clone() { thread.set_model(model, cx); cx.notify(); } if let Some(model) = summarization_model.clone() { if thread.summarization_model().is_none() || matches!(event, language_model::Event::ThreadSummaryModelChanged) { thread.set_summarization_model(Some(model), cx); } } }); } } fn handle_context_server_store_updated( &mut self, store: Entity, _event: &project::context_server_store::ServerStatusChangedEvent, cx: &mut Context, ) { let project_id = self.projects.iter().find_map(|(id, state)| { if *state.context_server_registry.read(cx).server_store() == store { Some(*id) } else { None } }); if let Some(project_id) = project_id { self.update_available_commands_for_project(project_id, cx); } } fn handle_context_server_registry_event( &mut self, registry: Entity, event: &ContextServerRegistryEvent, cx: &mut Context, ) { match event { ContextServerRegistryEvent::ToolsChanged => {} ContextServerRegistryEvent::PromptsChanged => { let project_id = self.projects.iter().find_map(|(id, state)| { if state.context_server_registry == registry { Some(*id) } else { None } }); if let Some(project_id) = project_id { self.update_available_commands_for_project(project_id, cx); } } } } fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context) { let available_commands = Self::build_available_commands_for_project(self.projects.get(&project_id), cx); for session in self.sessions.values() { if session.project_id != project_id { continue; } session.acp_thread.update(cx, |thread, cx| { thread .handle_session_update( 
acp::SessionUpdate::AvailableCommandsUpdate( acp::AvailableCommandsUpdate::new(available_commands.clone()), ), cx, ) .log_err(); }); } } fn build_available_commands_for_project( project_state: Option<&ProjectState>, cx: &App, ) -> Vec { let Some(state) = project_state else { return vec![]; }; let registry = state.context_server_registry.read(cx); let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default(); for context_server_prompt in registry.prompts() { *prompt_name_counts .entry(context_server_prompt.prompt.name.as_str()) .or_insert(0) += 1; } registry .prompts() .flat_map(|context_server_prompt| { let prompt = &context_server_prompt.prompt; let should_prefix = prompt_name_counts .get(prompt.name.as_str()) .copied() .unwrap_or(0) > 1; let name = if should_prefix { format!("{}.{}", context_server_prompt.server_id, prompt.name) } else { prompt.name.clone() }; let mut command = acp::AvailableCommand::new( name, prompt.description.clone().unwrap_or_default(), ); match prompt.arguments.as_deref() { Some([arg]) => { let hint = format!("<{}>", arg.name); command = command.input(acp::AvailableCommandInput::Unstructured( acp::UnstructuredCommandInput::new(hint), )); } Some([]) | None => {} Some(_) => { // skip >1 argument commands since we don't support them yet return None; } } Some(command) }) .collect() } pub fn load_thread( &mut self, id: acp::SessionId, project: Entity, cx: &mut Context, ) -> Task>> { let database_future = ThreadsDatabase::connect(cx); cx.spawn(async move |this, cx| { let database = database_future.await.map_err(|err| anyhow!(err))?; let db_thread = database .load_thread(id.clone()) .await? 
.with_context(|| format!("no thread found with ID: {id:?}"))?; this.update(cx, |this, cx| { let project_id = this.get_or_create_project_state(&project, cx); let project_state = this .projects .get(&project_id) .context("project state not found")?; let summarization_model = LanguageModelRegistry::read_global(cx) .thread_summary_model(cx) .map(|c| c.model); Ok(cx.new(|cx| { let mut thread = Thread::from_db( id.clone(), db_thread, project_state.project.clone(), project_state.project_context.clone(), project_state.context_server_registry.clone(), this.templates.clone(), cx, ); thread.set_summarization_model(summarization_model, cx); thread })) })? }) } pub fn open_thread( &mut self, id: acp::SessionId, project: Entity, cx: &mut Context, ) -> Task>> { if let Some(session) = self.sessions.get_mut(&id) { session.ref_count += 1; return Task::ready(Ok(session.acp_thread.clone())); } if let Some(pending) = self.pending_sessions.get_mut(&id) { pending.ref_count += 1; let task = pending.task.clone(); return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) }); } let task = self.load_thread(id.clone(), project.clone(), cx); let shared_task = cx .spawn({ let id = id.clone(); async move |this, cx| { let thread = match task.await { Ok(thread) => thread, Err(err) => { this.update(cx, |this, _cx| { this.pending_sessions.remove(&id); }) .ok(); return Err(Arc::new(err)); } }; let acp_thread = this .update(cx, |this, cx| { let project_id = this.get_or_create_project_state(&project, cx); let ref_count = this .pending_sessions .remove(&id) .map_or(1, |pending| pending.ref_count); this.register_session(thread.clone(), project_id, ref_count, cx) }) .map_err(Arc::new)?; let events = thread.update(cx, |thread, cx| thread.replay(cx)); cx.update(|cx| { NativeAgentConnection::handle_thread_events( events, acp_thread.downgrade(), cx, ) }) .await .map_err(Arc::new)?; acp_thread.update(cx, |thread, cx| { thread.snapshot_completed_plan(cx); }); Ok(acp_thread) } }) .shared(); 
self.pending_sessions.insert( id, PendingSession { task: shared_task.clone(), ref_count: 1, }, ); cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) }) } pub fn thread_summary( &mut self, id: acp::SessionId, project: Entity, cx: &mut Context, ) -> Task> { let thread = self.open_thread(id.clone(), project, cx); cx.spawn(async move |this, cx| { let acp_thread = thread.await?; let result = this .update(cx, |this, cx| { this.sessions .get(&id) .unwrap() .thread .update(cx, |thread, cx| thread.summary(cx)) })? .await .context("Failed to generate summary")?; this.update(cx, |this, cx| this.close_session(&id, cx))? .await?; drop(acp_thread); Ok(result) }) } fn close_session( &mut self, session_id: &acp::SessionId, cx: &mut Context, ) -> Task> { let Some(session) = self.sessions.get_mut(session_id) else { return Task::ready(Ok(())); }; session.ref_count -= 1; if session.ref_count > 0 { return Task::ready(Ok(())); } let thread = session.thread.clone(); self.save_thread(thread, cx); let Some(session) = self.sessions.remove(session_id) else { return Task::ready(Ok(())); }; let project_id = session.project_id; let has_remaining = self.sessions.values().any(|s| s.project_id == project_id); if !has_remaining { self.projects.remove(&project_id); } session.pending_save } fn save_thread(&mut self, thread: Entity, cx: &mut Context) { if thread.read(cx).is_empty() { return; } let id = thread.read(cx).id().clone(); let Some(session) = self.sessions.get_mut(&id) else { return; }; let project_id = session.project_id; let Some(state) = self.projects.get(&project_id) else { return; }; let folder_paths = PathList::new( &state .project .read(cx) .visible_worktrees(cx) .map(|worktree| worktree.read(cx).abs_path().to_path_buf()) .collect::>(), ); let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from); let database_future = ThreadsDatabase::connect(cx); let db_thread = thread.update(cx, |thread, cx| { thread.set_draft_prompt(draft_prompt); 
thread.to_db(cx) }); let thread_store = self.thread_store.clone(); session.pending_save = cx.spawn(async move |_, cx| { let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { return Ok(()); }; let db_thread = db_thread.await; database .save_thread(id, db_thread, folder_paths) .await .log_err(); thread_store.update(cx, |store, cx| store.reload(cx)); Ok(()) }); } fn send_mcp_prompt( &self, message_id: UserMessageId, session_id: acp::SessionId, prompt_name: String, server_id: ContextServerId, arguments: HashMap, original_content: Vec, cx: &mut Context, ) -> Task> { let Some(state) = self.session_project_state(&session_id) else { return Task::ready(Err(anyhow!("Project state not found for session"))); }; let server_store = state .context_server_registry .read(cx) .server_store() .clone(); let path_style = state.project.read(cx).path_style(cx); cx.spawn(async move |this, cx| { let prompt = crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?; let (acp_thread, thread) = this.update(cx, |this, _cx| { let session = this .sessions .get(&session_id) .context("Failed to get session")?; anyhow::Ok((session.acp_thread.clone(), session.thread.clone())) })??; let mut last_is_user = true; thread.update(cx, |thread, cx| { thread.push_acp_user_block( message_id, original_content.into_iter().skip(1), path_style, cx, ); }); for message in prompt.messages { let context_server::types::PromptMessage { role, content } = message; let block = mcp_message_content_to_acp_content_block(content); match role { context_server::types::Role::User => { let id = acp_thread::UserMessageId::new(); acp_thread.update(cx, |acp_thread, cx| { acp_thread.push_user_content_block_with_indent( Some(id.clone()), block.clone(), true, cx, ); }); thread.update(cx, |thread, cx| { thread.push_acp_user_block(id, [block], path_style, cx); }); } context_server::types::Role::Assistant => { acp_thread.update(cx, |acp_thread, cx| { 
acp_thread.push_assistant_content_block_with_indent( block.clone(), false, true, cx, ); }); thread.update(cx, |thread, cx| { thread.push_acp_agent_block(block, cx); }); } } last_is_user = role == context_server::types::Role::User; } let response_stream = thread.update(cx, |thread, cx| { if last_is_user { thread.send_existing(cx) } else { // Resume if MCP prompt did not end with a user message thread.resume(cx) } })?; cx.update(|cx| { NativeAgentConnection::handle_thread_events( response_stream, acp_thread.downgrade(), cx, ) }) .await }) } } /// Wrapper struct that implements the AgentConnection trait #[derive(Clone)] pub struct NativeAgentConnection(pub Entity); impl NativeAgentConnection { pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option> { self.0 .read(cx) .sessions .get(session_id) .map(|session| session.thread.clone()) } pub fn load_thread( &self, id: acp::SessionId, project: Entity, cx: &mut App, ) -> Task>> { self.0 .update(cx, |this, cx| this.load_thread(id, project, cx)) } fn run_turn( &self, session_id: acp::SessionId, cx: &mut App, f: impl 'static + FnOnce(Entity, &mut App) -> Result>>, ) -> Task> { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { agent .sessions .get_mut(&session_id) .map(|s| (s.thread.clone(), s.acp_thread.clone())) }) else { log::error!("Session not found in run_turn: {}", session_id); return Task::ready(Err(anyhow!("Session not found"))); }; log::debug!("Found session for: {}", session_id); let response_stream = match f(thread, cx) { Ok(stream) => stream, Err(err) => return Task::ready(Err(err)), }; Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx) } fn handle_thread_events( mut events: mpsc::UnboundedReceiver>, acp_thread: WeakEntity, cx: &App, ) -> Task> { cx.spawn(async move |cx| { // Handle response stream and forward to session.acp_thread while let Some(result) = events.next().await { match result { Ok(event) => { log::trace!("Received completion event: {:?}", event); 
match event { ThreadEvent::UserMessage(message) => { acp_thread.update(cx, |thread, cx| { for content in message.content { thread.push_user_content_block( Some(message.id.clone()), content.into(), cx, ); } })?; } ThreadEvent::AgentText(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block(text.into(), false, cx) })?; } ThreadEvent::AgentThinking(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block(text.into(), true, cx) })?; } ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { tool_call, options, response, context: _, }) => { let outcome_task = acp_thread.update(cx, |thread, cx| { thread.request_tool_call_authorization(tool_call, options, cx) })??; cx.background_spawn(async move { if let acp_thread::RequestPermissionOutcome::Selected(outcome) = outcome_task.await { response .send(outcome) .map_err(|_| { anyhow!("authorization receiver was dropped") }) .log_err(); } }) .detach(); } ThreadEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.upsert_tool_call(tool_call, cx) })??; } ThreadEvent::ToolCallUpdate(update) => { acp_thread.update(cx, |thread, cx| { thread.update_tool_call(update, cx) })??; } ThreadEvent::Plan(plan) => { acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?; } ThreadEvent::SubagentSpawned(session_id) => { acp_thread.update(cx, |thread, cx| { thread.subagent_spawned(session_id, cx); })?; } ThreadEvent::Retry(status) => { acp_thread.update(cx, |thread, cx| { thread.update_retry_status(status, cx) })?; } ThreadEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse::new(stop_reason)); } } } Err(e) => { log::error!("Error in model response stream: {:?}", e); return Err(e); } } } log::debug!("Response stream completed"); anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) }) } } struct Command<'a> { prompt_name: &'a str, arg_value: &'a str, explicit_server_id: Option<&'a 
str>,
}

impl<'a> Command<'a> {
    /// Parses a slash command from the start of a prompt.
    ///
    /// Only the first content block is inspected, and only when it is text. The
    /// text is trimmed and must begin with `/`. Everything up to the first
    /// whitespace is the command; the remainder (or `""` when there is no
    /// whitespace) becomes `arg_value`. A command of the form `server.prompt`
    /// is split at its first `.` into an explicit server id and a prompt name.
    ///
    /// Returns `None` when the prompt is empty or its first block is not text
    /// starting with `/`.
    fn parse(prompt: &'a [acp::ContentBlock]) -> Option {
        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
            return None;
        };
        let text = text_content.text.trim();
        let command = text.strip_prefix('/')?;
        let (command, arg_value) = command
            .split_once(char::is_whitespace)
            .unwrap_or((command, ""));
        // `/server_id.prompt_name` targets a specific context server.
        if let Some((server_id, prompt_name)) = command.split_once('.') {
            Some(Self {
                prompt_name,
                arg_value,
                explicit_server_id: Some(server_id),
            })
        } else {
            Some(Self {
                prompt_name: command,
                arg_value,
                explicit_server_id: None,
            })
        }
    }
}

/// Model selector bound to a single session of a `NativeAgentConnection`.
struct NativeAgentModelSelector {
    session_id: acp::SessionId,
    connection: NativeAgentConnection,
}

impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
    /// Returns the agent's cached model list, or an error if it is empty.
    fn list_models(&self, cx: &mut App) -> Task> {
        log::debug!("NativeAgentConnection::list_models called");
        let list = self.connection.0.read(cx).models.model_list.clone();
        Task::ready(if list.is_empty() {
            Err(anyhow::anyhow!("No models available"))
        } else {
            Ok(list)
        })
    }

    /// Switches this session's thread to `model_id` and writes the choice to
    /// the settings file as the agent's model.
    ///
    /// Fails if the session or the model id is unknown. Thinking, effort and
    /// speed come from `language_model_to_selection`, which is given the
    /// favorite entry matching this model's provider and id (if any).
    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task> {
        log::debug!(
            "Setting model for session {}: {}",
            self.session_id,
            model_id
        );
        let Some(thread) = self
            .connection
            .0
            .read(cx)
            .sessions
            .get(&self.session_id)
            .map(|session| session.thread.clone())
        else {
            return Task::ready(Err(anyhow!("Session not found")));
        };

        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
        };

        // Look for a favorite entry with the same provider and model id.
        let favorite = agent_settings::AgentSettings::get_global(cx)
            .favorite_models
            .iter()
            .find(|favorite| {
                favorite.provider.0 == model.provider_id().0.as_ref()
                    && favorite.model == model.id().0.as_ref()
            })
            .cloned();

        let LanguageModelSelection {
            enable_thinking,
            effort,
            speed,
            ..
        } = agent_settings::language_model_to_selection(&model, favorite.as_ref());

        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_thinking_effort(effort.clone(), cx);
            thread.set_thinking_enabled(enable_thinking, cx);
            // Speed is only overridden when the selection specifies one.
            if let Some(speed) = speed {
                thread.set_speed(speed, cx);
            }
        });

        // Persist the selection. `enable_thinking` and `speed` are read back
        // from the thread when the closure runs rather than taken from the
        // locals above; `effort` is the value computed above.
        update_settings_file(
            self.connection.0.read(cx).fs.clone(),
            cx,
            move |settings, cx| {
                let provider = model.provider_id().0.to_string();
                let model = model.id().0.to_string();
                let enable_thinking = thread.read(cx).thinking_enabled();
                let speed = thread.read(cx).speed();
                settings
                    .agent
                    .get_or_insert_default()
                    .set_model(LanguageModelSelection {
                        provider: provider.into(),
                        model,
                        enable_thinking,
                        effort,
                        speed,
                    });
            },
        );

        Task::ready(Ok(()))
    }

    /// Returns info for the model currently set on this session's thread.
    ///
    /// Fails if the session, the thread's model, or that model's provider
    /// cannot be found.
    fn selected_model(&self, cx: &mut App) -> Task> {
        let Some(thread) = self
            .connection
            .0
            .read(cx)
            .sessions
            .get(&self.session_id)
            .map(|session| session.thread.clone())
        else {
            return Task::ready(Err(anyhow!("Session not found")));
        };
        let Some(model) = thread.read(cx).model() else {
            return Task::ready(Err(anyhow!("Model not found")));
        };
        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
        else {
            return Task::ready(Err(anyhow!("Provider not found")));
        };
        Task::ready(Ok(LanguageModels::map_language_model_to_info(
            model, &provider,
        )))
    }

    /// Delegates to the agent's `LanguageModels::watch`.
    fn watch(&self, cx: &mut App) -> Option> {
        Some(self.connection.0.read(cx).models.watch())
    }

    fn should_render_footer(&self) -> bool {
        true
    }
}

/// Agent id returned by the native connection's `agent_id`.
pub static ZED_AGENT_ID: LazyLock = LazyLock::new(|| AgentId::new("Zed Agent"));

impl acp_thread::AgentConnection for NativeAgentConnection {
    fn agent_id(&self) -> AgentId {
        ZED_AGENT_ID.clone()
    }

    fn telemetry_id(&self) -> SharedString {
        "zed".into()
    }

    /// Creates a new session for `project`; `work_dirs` is only logged here.
    fn new_session(
        self: Rc,
        project: Entity,
        work_dirs: PathList,
        cx: &mut App,
    ) -> Task>> {
        log::debug!("Creating new thread for project at: {work_dirs:?}");
        Task::ready(Ok(self
            .0
            .update(cx, |agent, cx| agent.new_session(project, cx))))
    }

    fn
supports_load_session(&self) -> bool { true } fn load_session( self: Rc, session_id: acp::SessionId, project: Entity, _work_dirs: PathList, _title: Option, cx: &mut App, ) -> Task>> { self.0 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx)) } fn supports_close_session(&self) -> bool { true } fn close_session( self: Rc, session_id: &acp::SessionId, cx: &mut App, ) -> Task> { self.0 .update(cx, |agent, cx| agent.close_session(session_id, cx)) } fn auth_methods(&self) -> &[acp::AuthMethod] { &[] // No auth for in-process } fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { Task::ready(Ok(())) } fn model_selector(&self, session_id: &acp::SessionId) -> Option> { Some(Rc::new(NativeAgentModelSelector { session_id: session_id.clone(), connection: self.clone(), }) as Rc) } fn prompt( &self, id: acp_thread::UserMessageId, params: acp::PromptRequest, cx: &mut App, ) -> Task> { let session_id = params.session_id.clone(); log::info!("Received prompt request for session: {}", session_id); log::debug!("Prompt blocks count: {}", params.prompt.len()); let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else { log::error!("Session not found in prompt: {}", session_id); if self.0.read(cx).sessions.contains_key(&session_id) { log::error!( "Session found in sessions map, but not in project state: {}", session_id ); } return Task::ready(Err(anyhow::anyhow!("Session not found"))); }; if let Some(parsed_command) = Command::parse(¶ms.prompt) { let registry = project_state.context_server_registry.read(cx); let explicit_server_id = parsed_command .explicit_server_id .map(|server_id| ContextServerId(server_id.into())); if let Some(prompt) = registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name) { let arguments = if !parsed_command.arg_value.is_empty() && let Some(arg_name) = prompt .prompt .arguments .as_ref() .and_then(|args| args.first()) .map(|arg| arg.name.clone()) { HashMap::from_iter([(arg_name, 
parsed_command.arg_value.to_string())]) } else { Default::default() }; let prompt_name = prompt.prompt.name.clone(); let server_id = prompt.server_id.clone(); return self.0.update(cx, |agent, cx| { agent.send_mcp_prompt( id, session_id.clone(), prompt_name, server_id, arguments, params.prompt, cx, ) }); } }; let path_style = project_state.project.read(cx).path_style(cx); self.run_turn(session_id, cx, move |thread, cx| { let content: Vec = params .prompt .into_iter() .map(|block| UserMessageContent::from_content_block(block, path_style)) .collect::>(); log::debug!("Converted prompt to message: {} chars", content.len()); log::debug!("Message id: {:?}", id); log::debug!("Message content: {:?}", content); thread.update(cx, |thread, cx| thread.send(id, content, cx)) }) } fn retry( &self, session_id: &acp::SessionId, _cx: &App, ) -> Option> { Some(Rc::new(NativeAgentSessionRetry { connection: self.clone(), session_id: session_id.clone(), }) as _) } fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { if let Some(session) = agent.sessions.get(session_id) { session .thread .update(cx, |thread, cx| thread.cancel(cx)) .detach(); } }); } fn truncate( &self, session_id: &acp::SessionId, cx: &App, ) -> Option> { self.0.read_with(cx, |agent, _cx| { agent.sessions.get(session_id).map(|session| { Rc::new(NativeAgentSessionTruncate { thread: session.thread.clone(), acp_thread: session.acp_thread.downgrade(), }) as _ }) }) } fn set_title( &self, session_id: &acp::SessionId, cx: &App, ) -> Option> { self.0.read_with(cx, |agent, _cx| { agent .sessions .get(session_id) .filter(|s| !s.thread.read(cx).is_subagent()) .map(|session| { Rc::new(NativeAgentSessionSetTitle { thread: session.thread.clone(), }) as _ }) }) } fn session_list(&self, cx: &mut App) -> Option> { let thread_store = self.0.read(cx).thread_store.clone(); Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _) } fn 
telemetry(&self) -> Option> {
        Some(Rc::new(self.clone()) as Rc)
    }

    fn into_any(self: Rc) -> Rc {
        self
    }
}

impl acp_thread::AgentTelemetry for NativeAgentConnection {
    /// Serializes the session's thread (via its `to_db` task) into a JSON value
    /// on a background task. Fails if the session is unknown.
    fn thread_data(
        &self,
        session_id: &acp::SessionId,
        cx: &mut App,
    ) -> Task> {
        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
            return Task::ready(Err(anyhow!("Session not found")));
        };

        let task = session.thread.read(cx).to_db(cx);
        cx.background_spawn(async move {
            serde_json::to_value(task.await).context("Failed to serialize thread")
        })
    }
}

/// Session list backed by the agent's `ThreadStore`.
///
/// Whenever the observed thread store notifies, a
/// `SessionListUpdate::Refresh` is pushed onto an unbounded channel whose
/// receiver is handed out by `watch`.
pub struct NativeAgentSessionList {
    thread_store: Entity,
    updates_tx: async_channel::Sender,
    updates_rx: async_channel::Receiver,
    // Holds the `cx.observe` subscription on the thread store.
    _subscription: Subscription,
}

impl NativeAgentSessionList {
    fn new(thread_store: Entity, cx: &mut App) -> Self {
        let (tx, rx) = async_channel::unbounded();
        let this_tx = tx.clone();
        // Send failures are deliberately ignored (`.ok()`).
        let subscription = cx.observe(&thread_store, move |_, _| {
            this_tx
                .try_send(acp_thread::SessionListUpdate::Refresh)
                .ok();
        });
        Self {
            thread_store,
            updates_tx: tx,
            updates_rx: rx,
            _subscription: subscription,
        }
    }

    pub fn thread_store(&self) -> &Entity {
        &self.thread_store
    }
}

impl AgentSessionList for NativeAgentSessionList {
    /// Lists every entry in the thread store; `_request` is not used.
    fn list_sessions(
        &self,
        _request: AgentSessionListRequest,
        cx: &mut App,
    ) -> Task> {
        let sessions = self
            .thread_store
            .read(cx)
            .entries()
            .map(|entry| AgentSessionInfo::from(&entry))
            .collect();
        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
    }

    fn supports_delete(&self) -> bool {
        true
    }

    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task> {
        self.thread_store
            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
    }

    fn delete_sessions(&self, cx: &mut App) -> Task> {
        self.thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
    }

    /// Returns a clone of the receiver fed by the store observer and `notify_refresh`.
    fn watch(
        &self,
        _cx: &mut App,
    ) -> Option> {
        Some(self.updates_rx.clone())
    }

    /// Pushes a `Refresh` update to watchers; a failed send is ignored.
    fn notify_refresh(&self) {
        self.updates_tx
            .try_send(acp_thread::SessionListUpdate::Refresh)
            .ok();
    }

    fn into_any(self: Rc) -> Rc {
        self
    }
}

struct
NativeAgentSessionTruncate {
    thread: Entity,
    // Weak reference; a failed update in `run` is ignored.
    acp_thread: WeakEntity,
}

impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
    /// Truncates the thread at `message_id`, then pushes the thread's latest
    /// token usage to the ACP thread (skipped silently if it is gone).
    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> {
        // `message_id` is owned and moved into the `FnOnce` closure; no clone needed.
        match self.thread.update(cx, |thread, cx| {
            thread.truncate(message_id, cx)?;
            Ok(thread.latest_token_usage())
        }) {
            Ok(usage) => {
                self.acp_thread
                    .update(cx, |thread, cx| {
                        thread.update_token_usage(usage, cx);
                    })
                    .ok();
                Task::ready(Ok(()))
            }
            Err(error) => Task::ready(Err(error)),
        }
    }
}

/// Retries a session by running a new turn that calls `resume` on its thread.
struct NativeAgentSessionRetry {
    connection: NativeAgentConnection,
    session_id: acp::SessionId,
}

impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
    fn run(&self, cx: &mut App) -> Task> {
        self.connection
            .run_turn(self.session_id.clone(), cx, |thread, cx| {
                thread.update(cx, |thread, cx| thread.resume(cx))
            })
    }
}

/// Sets the title of a session's thread.
struct NativeAgentSessionSetTitle {
    thread: Entity,
}

impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
    fn run(&self, title: SharedString, cx: &mut App) -> Task> {
        self.thread
            .update(cx, |thread, cx| thread.set_title(title, cx));
        Task::ready(Ok(()))
    }
}

/// `ThreadEnvironment` implementation for native threads: terminals are
/// created through the ACP thread and subagents through the agent. All
/// references are held weakly.
pub struct NativeThreadEnvironment {
    agent: WeakEntity,
    thread: WeakEntity,
    acp_thread: WeakEntity,
}

impl NativeThreadEnvironment {
    /// Creates a subagent thread under this environment's thread, registers it
    /// as a session in the parent session's project, and returns a handle for
    /// prompting it.
    ///
    /// Fails if the parent thread has been dropped, if the parent is already at
    /// `MAX_SUBAGENT_DEPTH`, if the agent has been dropped, or if the parent
    /// session is no longer registered.
    pub(crate) fn create_subagent_thread(
        &self,
        label: String,
        cx: &mut App,
    ) -> Result> {
        let Some(parent_thread_entity) = self.thread.upgrade() else {
            anyhow::bail!("Parent thread no longer exists");
        };
        let parent_thread = parent_thread_entity.read(cx);
        let current_depth = parent_thread.depth();
        let parent_session_id = parent_thread.id().clone();

        if current_depth >= MAX_SUBAGENT_DEPTH {
            return Err(anyhow!(
                "Maximum subagent depth ({}) reached",
                MAX_SUBAGENT_DEPTH
            ));
        }

        let subagent_thread: Entity = cx.new(|cx| {
            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
            thread.set_title(label.into(), cx);
            thread
        });

        let session_id = subagent_thread.read(cx).id().clone();

        // The subagent session is registered under the parent session's
        // project. The outer `?` covers a dropped agent, the inner one a missing
        // parent session. NOTE(review): the literal `1` passed to
        // `register_session` presumably seeds the session's ref count — confirm.
        let acp_thread = self
            .agent
            .update(cx, |agent, cx| -> Result> {
                let project_id = agent
                    .sessions
                    .get(&parent_session_id)
                    .map(|s| s.project_id)
                    .context("parent session not found")?;
                Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
            })??;

        let depth = current_depth + 1;

        telemetry::event!(
            "Subagent Started",
            session = parent_thread_entity.read(cx).id().to_string(),
            subagent_session = session_id.to_string(),
            depth,
            is_resumed = false,
        );

        self.prompt_subagent(session_id, subagent_thread, acp_thread)
    }

    /// Returns a handle for prompting an existing subagent session again.
    ///
    /// Fails if the agent has been dropped, if no session with `session_id`
    /// exists, or (via `prompt_subagent`) if the parent thread has been
    /// dropped. Telemetry is recorded only while the parent thread is alive.
    pub(crate) fn resume_subagent_thread(
        &self,
        session_id: acp::SessionId,
        cx: &mut App,
    ) -> Result> {
        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
            let session = agent
                .sessions
                .get(&session_id)
                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
        })??;

        let depth = subagent_thread.read(cx).depth();

        if let Some(parent_thread_entity) = self.thread.upgrade() {
            telemetry::event!(
                "Subagent Started",
                session = parent_thread_entity.read(cx).id().to_string(),
                subagent_session = session_id.to_string(),
                depth,
                is_resumed = true,
            );
        }

        self.prompt_subagent(session_id, subagent_thread, acp_thread)
    }

    /// Wraps a subagent session in a `NativeSubagentHandle` tied to this
    /// environment's thread. Fails if that thread has been dropped.
    fn prompt_subagent(
        &self,
        session_id: acp::SessionId,
        subagent_thread: Entity,
        acp_thread: Entity,
    ) -> Result> {
        let Some(parent_thread_entity) = self.thread.upgrade() else {
            anyhow::bail!("Parent thread no longer exists");
        };
        Ok(Rc::new(NativeSubagentHandle::new(
            session_id,
            subagent_thread,
            acp_thread,
            parent_thread_entity,
        )) as _)
    }
}

impl ThreadEnvironment for NativeThreadEnvironment {
    /// Creates a terminal through the ACP thread and wraps it in an
    /// `AcpTerminalHandle`. A spawned task releases the terminal on the ACP
    /// thread once the handle's drop sender is dropped.
    fn create_terminal(
        &self,
        command: String,
        cwd: Option,
        output_byte_limit: Option,
        cx: &mut AsyncApp,
    ) -> Task>> {
        let task = self.acp_thread.update(cx, |thread, cx| {
            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
        });

        let acp_thread = self.acp_thread.clone();
        cx.spawn(async move |cx| {
            let terminal = task?.await?;

            let (drop_tx, drop_rx) = oneshot::channel();
let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
            // Resolves once the handle's `_drop_tx` is dropped; the terminal is
            // then released on the ACP thread.
            cx.spawn(async move |cx| {
                drop_rx.await.ok();
                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
            })
            .detach();

            let handle = AcpTerminalHandle {
                terminal,
                _drop_tx: Some(drop_tx),
            };

            Ok(Rc::new(handle) as _)
        })
    }

    fn create_subagent(&self, label: String, cx: &mut App) -> Result> {
        self.create_subagent_thread(label, cx)
    }

    fn resume_subagent(
        &self,
        session_id: acp::SessionId,
        cx: &mut App,
    ) -> Result> {
        self.resume_subagent_thread(session_id, cx)
    }
}

/// Outcome of waiting on one subagent prompt in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with a stop reason that is not mapped to an error.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// The token-limit signal resolved before the turn finished.
    ContextWindowWarning,
    /// The turn failed; the message is surfaced to the caller.
    Error(String),
}

/// Handle for prompting a subagent session on behalf of its parent thread.
pub struct NativeSubagentHandle {
    session_id: acp::SessionId,
    // Held as a weak reference (downgraded in `new`).
    parent_thread: WeakEntity,
    subagent_thread: Entity,
    acp_thread: Entity,
}

impl NativeSubagentHandle {
    fn new(
        session_id: acp::SessionId,
        subagent_thread: Entity,
        acp_thread: Entity,
        parent_thread_entity: Entity,
    ) -> Self {
        NativeSubagentHandle {
            session_id,
            subagent_thread,
            parent_thread: parent_thread_entity.downgrade(),
            acp_thread,
        }
    }
}

impl SubagentHandle for NativeSubagentHandle {
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Number of entries currently in the subagent's ACP thread.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves with the text of its last
    /// agent message (text parts joined by blank lines).
    ///
    /// While the prompt runs, the subagent is registered as running on the
    /// parent thread (if it still exists) and unregistered afterwards. If the
    /// subagent's token usage moves from `Normal` to `Warning` during the turn,
    /// the thread is cancelled and an error explaining this is returned.
    fn send(&self, message: String, cx: &AsyncApp) -> Task> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` stays bound until the end of this block so the
            // token-usage observer remains registered while the prompt runs.
            let (task, _subscription) = cx.update(|cx| {
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // Signalled at most once (the sender is `take`n), the first time
                // usage goes from `Normal` (as of before this prompt) to `Warning`.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the token-limit signal.
                let wait_for_prompt = cx.background_spawn(async move {
                    futures::select! {
                        response = task.fuse() => match response {
                            Ok(Some(response)) => {
                                match response.stop_reason {
                                    acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                    acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                    acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                    acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                    // `EndTurn` and any other stop reason count as completion.
                                    _ => SubagentPromptResult::Completed,
                                }
                            }
                            Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                            Err(error) => SubagentPromptResult::Error(error.to_string()),
                        },
                        _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                    }
                });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message
                                .as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");

                            if content.is_empty() {
                                None
                            } else {
                                Some(content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                         stopped. You can prompt the thread again to have the agent wrap up \
                         or hand off its work."
                    ))
                }
            };

            // Always unregister, whatever the outcome (ignored if the parent is gone).
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}

/// `TerminalHandle` over an ACP terminal entity.
///
/// Dropping the handle drops `_drop_tx`, which lets the task spawned in
/// `NativeThreadEnvironment::create_terminal` release the terminal.
pub struct AcpTerminalHandle {
    terminal: Entity,
    _drop_tx: Option>,
}

impl TerminalHandle for AcpTerminalHandle {
    fn id(&self, cx: &AsyncApp) -> Result {
        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
    }

    fn wait_for_exit(&self, cx: &AsyncApp) -> Result>> {
        Ok(self
            .terminal
            .read_with(cx, |term, _cx| term.wait_for_exit()))
    }

    fn current_output(&self, cx: &AsyncApp) -> Result {
        Ok(self
            .terminal
            .read_with(cx, |term, cx| term.current_output(cx)))
    }

    fn kill(&self, cx: &AsyncApp) -> Result<()> {
        cx.update(|cx| {
            self.terminal.update(cx, |terminal, cx| {
                terminal.kill(cx);
            });
        });
        Ok(())
    }

    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result {
        Ok(self
            .terminal
            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
    }
}

#[cfg(test)]
mod internal_tests {
    use std::path::Path;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
    use language_model::{
        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
    };
    use serde_json::json;
    use settings::SettingsStore;
    use
util::{path, rel_path::rel_path}; #[gpui::test] async fn test_maintaining_project_context(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/", json!({ "a": {} }), ) .await; let project = Project::test(fs.clone(), [], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)); // Creating a session registers the project and triggers context building. let connection = NativeAgentConnection(agent.clone()); let _acp_thread = cx .update(|cx| { Rc::new(connection).new_session( project.clone(), PathList::new(&[Path::new("/")]), cx, ) }) .await .unwrap(); cx.run_until_parked(); let thread = agent.read_with(cx, |agent, _cx| { agent.sessions.values().next().unwrap().thread.clone() }); agent.read_with(cx, |agent, cx| { let project_id = project.entity_id(); let state = agent.projects.get(&project_id).unwrap(); assert_eq!(state.project_context.read(cx).worktrees, vec![]); assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]); }); let worktree = project .update(cx, |project, cx| project.create_worktree("/a", true, cx)) .await .unwrap(); cx.run_until_parked(); agent.read_with(cx, |agent, cx| { let project_id = project.entity_id(); let state = agent.projects.get(&project_id).unwrap(); let expected_worktrees = vec![WorktreeContext { root_name: "a".into(), abs_path: Path::new("/a").into(), rules_file: None, }]; assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees); assert_eq!( thread.read(cx).project_context().read(cx).worktrees, expected_worktrees ); }); // Creating `/a/.rules` updates the project context. 
fs.insert_file("/a/.rules", Vec::new()).await; cx.run_until_parked(); agent.read_with(cx, |agent, cx| { let project_id = project.entity_id(); let state = agent.projects.get(&project_id).unwrap(); let rules_entry = worktree .read(cx) .entry_for_path(rel_path(".rules")) .unwrap(); let expected_worktrees = vec![WorktreeContext { root_name: "a".into(), abs_path: Path::new("/a").into(), rules_file: Some(RulesFileContext { path_in_worktree: rel_path(".rules").into(), text: "".into(), project_entry_id: rules_entry.id.to_usize(), }), }]; assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees); assert_eq!( thread.read(cx).project_context().read(cx).worktrees, expected_worktrees ); }); } #[gpui::test] async fn test_listing_models(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree("/", json!({ "a": {} })).await; let project = Project::test(fs.clone(), [], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let connection = NativeAgentConnection(cx.update(|cx| { NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx) })); // Create a thread/session let acp_thread = cx .update(|cx| { Rc::new(connection.clone()).new_session( project.clone(), PathList::new(&[Path::new("/a")]), cx, ) }) .await .unwrap(); let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); let models = cx .update(|cx| { connection .model_selector(&session_id) .unwrap() .list_models(cx) }) .await .unwrap(); let acp_thread::AgentModelList::Grouped(models) = models else { panic!("Unexpected model group"); }; assert_eq!( models, IndexMap::from_iter([( AgentModelGroupName("Fake".into()), vec![AgentModelInfo { id: acp::ModelId::new("fake/fake"), name: "Fake".into(), description: None, icon: Some(acp_thread::AgentModelIcon::Named( ui::IconName::ZedAssistant )), is_latest: false, cost: None, }] )]) ); } #[gpui::test] async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) { init_test(cx); let 
fs = FakeFs::new(cx.executor()); fs.create_dir(paths::settings_file().parent().unwrap()) .await .unwrap(); fs.insert_file( paths::settings_file(), json!({ "agent": { "default_model": { "provider": "foo", "model": "bar" } } }) .to_string() .into_bytes(), ) .await; let project = Project::test(fs.clone(), [], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); // Create the agent and connection let agent = cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)); let connection = NativeAgentConnection(agent.clone()); // Create a thread/session let acp_thread = cx .update(|cx| { Rc::new(connection.clone()).new_session( project.clone(), PathList::new(&[Path::new("/a")]), cx, ) }) .await .unwrap(); let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); // Select a model let selector = connection.model_selector(&session_id).unwrap(); let model_id = acp::ModelId::new("fake/fake"); cx.update(|cx| selector.select_model(model_id.clone(), cx)) .await .unwrap(); // Verify the thread has the selected model agent.read_with(cx, |agent, _| { let session = agent.sessions.get(&session_id).unwrap(); session.thread.read_with(cx, |thread, _| { assert_eq!(thread.model().unwrap().id().0, "fake"); }); }); cx.run_until_parked(); // Verify settings file was updated let settings_content = fs.load(paths::settings_file()).await.unwrap(); let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap(); // Check that the agent settings contain the selected model assert_eq!( settings_json["agent"]["default_model"]["model"], json!("fake") ); assert_eq!( settings_json["agent"]["default_model"]["provider"], json!("fake") ); // Register a thinking model and select it. 
cx.update(|cx| { let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking( "fake-corp", "fake-thinking", "Fake Thinking", true, )); let thinking_provider = Arc::new( FakeLanguageModelProvider::new( LanguageModelProviderId::from("fake-corp".to_string()), LanguageModelProviderName::from("Fake Corp".to_string()), ) .with_models(vec![thinking_model]), ); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.register_provider(thinking_provider, cx); }); }); agent.update(cx, |agent, cx| agent.models.refresh_list(cx)); let selector = connection.model_selector(&session_id).unwrap(); cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx)) .await .unwrap(); cx.run_until_parked(); // Verify enable_thinking was written to settings as true. let settings_content = fs.load(paths::settings_file()).await.unwrap(); let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap(); assert_eq!( settings_json["agent"]["default_model"]["enable_thinking"], json!(true), "selecting a thinking model should persist enable_thinking: true to settings" ); } #[gpui::test] async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.create_dir(paths::settings_file().parent().unwrap()) .await .unwrap(); fs.insert_file(paths::settings_file(), b"{}".to_vec()).await; let project = Project::test(fs.clone(), [], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)); let connection = NativeAgentConnection(agent.clone()); let acp_thread = cx .update(|cx| { Rc::new(connection.clone()).new_session( project.clone(), PathList::new(&[Path::new("/a")]), cx, ) }) .await .unwrap(); let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); // Register a second provider with a thinking model. 
cx.update(|cx| { let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking( "fake-corp", "fake-thinking", "Fake Thinking", true, )); let thinking_provider = Arc::new( FakeLanguageModelProvider::new( LanguageModelProviderId::from("fake-corp".to_string()), LanguageModelProviderName::from("Fake Corp".to_string()), ) .with_models(vec![thinking_model]), ); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.register_provider(thinking_provider, cx); }); }); // Refresh the agent's model list so it picks up the new provider. agent.update(cx, |agent, cx| agent.models.refresh_list(cx)); // Thread starts with thinking_enabled = false (the default). agent.read_with(cx, |agent, _| { let session = agent.sessions.get(&session_id).unwrap(); session.thread.read_with(cx, |thread, _| { assert!(!thread.thinking_enabled(), "thinking defaults to false"); }); }); // Select the thinking model via select_model. let selector = connection.model_selector(&session_id).unwrap(); cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx)) .await .unwrap(); // select_model should have enabled thinking based on the model's supports_thinking(). agent.read_with(cx, |agent, _| { let session = agent.sessions.get(&session_id).unwrap(); session.thread.read_with(cx, |thread, _| { assert!( thread.thinking_enabled(), "select_model should enable thinking when model supports it" ); }); }); // Switch back to the non-thinking model. let selector = connection.model_selector(&session_id).unwrap(); cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx)) .await .unwrap(); // select_model should have disabled thinking. 
agent.read_with(cx, |agent, _| { let session = agent.sessions.get(&session_id).unwrap(); session.thread.read_with(cx, |thread, _| { assert!( !thread.thinking_enabled(), "select_model should disable thinking when model does not support it" ); }); }); } #[gpui::test] async fn test_summarization_model_survives_transient_registry_clearing( cx: &mut TestAppContext, ) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree("/", json!({ "a": {} })).await; let project = Project::test(fs.clone(), [], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)); let connection = Rc::new(NativeAgentConnection(agent.clone())); let acp_thread = cx .update(|cx| { connection.clone().new_session( project.clone(), PathList::new(&[Path::new("/a")]), cx, ) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); thread.read_with(cx, |thread, _| { assert!( thread.summarization_model().is_some(), "session should have a summarization model from the test registry" ); }); // Simulate what happens during a provider blip: // update_active_language_model_from_settings calls set_default_model(None) // when it can't resolve the model, clearing all fallbacks. 
cx.update(|cx| { LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.set_default_model(None, cx); }); }); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert!( thread.summarization_model().is_some(), "summarization model should survive a transient default model clearing" ); }); } #[gpui::test] async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree("/", json!({ "a": {} })).await; let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| { NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) }); let connection = Rc::new(NativeAgentConnection(agent.clone())); // Register a thinking model. let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking( "fake-corp", "fake-thinking", "Fake Thinking", true, )); let thinking_provider = Arc::new( FakeLanguageModelProvider::new( LanguageModelProviderId::from("fake-corp".to_string()), LanguageModelProviderName::from("Fake Corp".to_string()), ) .with_models(vec![thinking_model.clone()]), ); cx.update(|cx| { LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.register_provider(thinking_provider, cx); }); }); agent.update(cx, |agent, cx| agent.models.refresh_list(cx)); // Create a thread and select the thinking model. let acp_thread = cx .update(|cx| { connection.clone().new_session( project.clone(), PathList::new(&[Path::new("/a")]), cx, ) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let selector = connection.model_selector(&session_id).unwrap(); cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx)) .await .unwrap(); // Verify thinking is enabled after selecting the thinking model. 
let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); thread.read_with(cx, |thread, _| { assert!( thread.thinking_enabled(), "thinking should be enabled after selecting thinking model" ); }); // Send a message so the thread gets persisted. let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx)); let send = cx.foreground_executor().spawn(send); cx.run_until_parked(); thinking_model.send_last_completion_stream_text_chunk("Response."); thinking_model.end_last_completion_stream(); send.await.unwrap(); cx.run_until_parked(); // Close the session so it can be reloaded from disk. cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); drop(thread); drop(acp_thread); agent.read_with(cx, |agent, _| { assert!(agent.sessions.is_empty()); }); // Reload the thread and verify thinking_enabled is still true. let reloaded_acp_thread = agent .update(cx, |agent, cx| { agent.open_thread(session_id.clone(), project.clone(), cx) }) .await .unwrap(); let reloaded_thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); reloaded_thread.read_with(cx, |thread, _| { assert!( thread.thinking_enabled(), "thinking_enabled should be preserved when reloading a thread with a thinking model" ); }); drop(reloaded_acp_thread); } #[gpui::test] async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree("/", json!({ "a": {} })).await; let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| { NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) }); let connection = Rc::new(NativeAgentConnection(agent.clone())); // Register a model where id() != name(), like real Anthropic models // (e.g. 
id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking"). let model = Arc::new(FakeLanguageModel::with_id_and_thinking( "fake-corp", "custom-model-id", "Custom Model Display Name", false, )); let provider = Arc::new( FakeLanguageModelProvider::new( LanguageModelProviderId::from("fake-corp".to_string()), LanguageModelProviderName::from("Fake Corp".to_string()), ) .with_models(vec![model.clone()]), ); cx.update(|cx| { LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.register_provider(provider, cx); }); }); agent.update(cx, |agent, cx| agent.models.refresh_list(cx)); // Create a thread and select the model. let acp_thread = cx .update(|cx| { connection.clone().new_session( project.clone(), PathList::new(&[Path::new("/a")]), cx, ) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let selector = connection.model_selector(&session_id).unwrap(); cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx)) .await .unwrap(); let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); thread.read_with(cx, |thread, _| { assert_eq!( thread.model().unwrap().id().0.as_ref(), "custom-model-id", "model should be set before persisting" ); }); // Send a message so the thread gets persisted. let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx)); let send = cx.foreground_executor().spawn(send); cx.run_until_parked(); model.send_last_completion_stream_text_chunk("Response."); model.end_last_completion_stream(); send.await.unwrap(); cx.run_until_parked(); // Close the session so it can be reloaded from disk. cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); drop(thread); drop(acp_thread); agent.read_with(cx, |agent, _| { assert!(agent.sessions.is_empty()); }); // Reload the thread and verify the model was preserved. 
let reloaded_acp_thread = agent .update(cx, |agent, cx| { agent.open_thread(session_id.clone(), project.clone(), cx) }) .await .unwrap(); let reloaded_thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); reloaded_thread.read_with(cx, |thread, _| { let reloaded_model = thread .model() .expect("model should be present after reload"); assert_eq!( reloaded_model.id().0.as_ref(), "custom-model-id", "reloaded thread should have the same model, not fall back to the default" ); }); drop(reloaded_acp_thread); } #[gpui::test] async fn test_save_load_thread(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/", json!({ "a": { "b.md": "Lorem" } }), ) .await; let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| { NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) }); let connection = Rc::new(NativeAgentConnection(agent.clone())); let acp_thread = cx .update(|cx| { connection .clone() .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); // Ensure empty threads are not saved, even if they get mutated. 
let model = Arc::new(FakeLanguageModel::default()); let summary_model = Arc::new(FakeLanguageModel::default()); thread.update(cx, |thread, cx| { thread.set_model(model.clone(), cx); thread.set_summarization_model(Some(summary_model.clone()), cx); }); cx.run_until_parked(); assert_eq!(thread_entries(&thread_store, cx), vec![]); let send = acp_thread.update(cx, |thread, cx| { thread.send( vec![ "What does ".into(), acp::ContentBlock::ResourceLink(acp::ResourceLink::new( "b.md", MentionUri::File { abs_path: path!("/a/b.md").into(), } .to_uri() .to_string(), )), " mean?".into(), ], cx, ) }); let send = cx.foreground_executor().spawn(send); cx.run_until_parked(); model.send_last_completion_stream_text_chunk("Lorem."); model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( language_model::TokenUsage { input_tokens: 150, output_tokens: 75, ..Default::default() }, )); model.end_last_completion_stream(); cx.run_until_parked(); summary_model .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md"))); summary_model.end_last_completion_stream(); send.await.unwrap(); let uri = MentionUri::File { abs_path: path!("/a/b.md").into(), } .to_uri(); acp_thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), formatdoc! {" ## User What does [@b.md]({uri}) mean? ## Assistant Lorem. "} ) }); cx.run_until_parked(); // Set a draft prompt with rich content blocks and scroll position // AFTER run_until_parked, so the only save that captures these // changes is the one performed by close_session itself. 
let draft_blocks = vec![ acp::ContentBlock::Text(acp::TextContent::new("Check out ")), acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())), acp::ContentBlock::Text(acp::TextContent::new(" please")), ]; acp_thread.update(cx, |thread, cx| { thread.set_draft_prompt(Some(draft_blocks.clone()), cx); }); thread.update(cx, |thread, _cx| { thread.set_ui_scroll_position(Some(gpui::ListOffset { item_ix: 5, offset_in_item: gpui::px(12.5), })); }); // Close the session so it can be reloaded from disk. cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); drop(thread); drop(acp_thread); agent.read_with(cx, |agent, _| { assert_eq!(agent.sessions.keys().cloned().collect::>(), []); }); // Ensure the thread can be reloaded from disk. assert_eq!( thread_entries(&thread_store, cx), vec![( session_id.clone(), format!("Explaining {}", path!("/a/b.md")) )] ); let acp_thread = agent .update(cx, |agent, cx| { agent.open_thread(session_id.clone(), project.clone(), cx) }) .await .unwrap(); acp_thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), formatdoc! {" ## User What does [@b.md]({uri}) mean? ## Assistant Lorem. "} ) }); // Ensure the draft prompt with rich content blocks survived the round-trip. acp_thread.read_with(cx, |thread, _| { assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice())); }); // Ensure token usage survived the round-trip. acp_thread.read_with(cx, |thread, _| { let usage = thread .token_usage() .expect("token usage should be restored after reload"); assert_eq!(usage.input_tokens, 150); assert_eq!(usage.output_tokens, 75); }); // Ensure scroll position survived the round-trip. 
acp_thread.read_with(cx, |thread, _| { let scroll = thread .ui_scroll_position() .expect("scroll position should be restored after reload"); assert_eq!(scroll.item_ix, 5); assert_eq!(scroll.offset_in_item, gpui::px(12.5)); }); } #[gpui::test] async fn test_close_session_saves_thread(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/", json!({ "a": { "file.txt": "hello" } }), ) .await; let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| { NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) }); let connection = Rc::new(NativeAgentConnection(agent.clone())); let acp_thread = cx .update(|cx| { connection .clone() .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); let model = Arc::new(FakeLanguageModel::default()); thread.update(cx, |thread, cx| { thread.set_model(model.clone(), cx); }); // Send a message so the thread is non-empty (empty threads aren't saved). let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); let send = cx.foreground_executor().spawn(send); cx.run_until_parked(); model.send_last_completion_stream_text_chunk("world"); model.end_last_completion_stream(); send.await.unwrap(); cx.run_until_parked(); // Set a draft prompt WITHOUT calling run_until_parked afterwards. // This means no observe-triggered save has run for this change. // The only way this data gets persisted is if close_session // itself performs the save. 
let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new( "unsaved draft", ))]; acp_thread.update(cx, |thread, cx| { thread.set_draft_prompt(Some(draft_blocks.clone()), cx); }); // Close the session immediately — no run_until_parked in between. cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); cx.run_until_parked(); // Reopen and verify the draft prompt was saved. let reloaded = agent .update(cx, |agent, cx| { agent.open_thread(session_id.clone(), project.clone(), cx) }) .await .unwrap(); reloaded.read_with(cx, |thread, _| { assert_eq!( thread.draft_prompt(), Some(draft_blocks.as_slice()), "close_session must save the thread; draft prompt was lost" ); }); } #[gpui::test] async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/", json!({ "a": { "file.txt": "hello" } }), ) .await; let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| { NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) }); let connection = Rc::new(NativeAgentConnection(agent.clone())); let acp_thread = cx .update(|cx| { connection .clone() .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); let model = Arc::new(FakeLanguageModel::default()); let summary_model = Arc::new(FakeLanguageModel::default()); thread.update(cx, |thread, cx| { thread.set_model(model.clone(), cx); thread.set_summarization_model(Some(summary_model.clone()), cx); }); let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); let send = cx.foreground_executor().spawn(send); cx.run_until_parked(); 
model.send_last_completion_stream_text_chunk("world"); model.end_last_completion_stream(); send.await.unwrap(); cx.run_until_parked(); let summary = agent.update(cx, |agent, cx| { agent.thread_summary(session_id.clone(), project.clone(), cx) }); cx.run_until_parked(); summary_model.send_last_completion_stream_text_chunk("summary"); summary_model.end_last_completion_stream(); assert_eq!(summary.await.unwrap(), "summary"); cx.run_until_parked(); agent.read_with(cx, |agent, _| { let session = agent .sessions .get(&session_id) .expect("thread_summary should not close the active session"); assert_eq!( session.ref_count, 1, "thread_summary should release its temporary session reference" ); }); cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); cx.run_until_parked(); agent.read_with(cx, |agent, _| { assert!( agent.sessions.is_empty(), "closing the active session after thread_summary should unload it" ); }); } #[gpui::test] async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( "/", json!({ "a": { "file.txt": "hello" } }), ) .await; let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| { NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) }); let connection = Rc::new(NativeAgentConnection(agent.clone())); let acp_thread = cx .update(|cx| { connection .clone() .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); let model = cx.update(|cx| { LanguageModelRegistry::read_global(cx) .default_model() .map(|default_model| default_model.model) .expect("default test model should be available") }); let 
fake_model = model.as_fake(); thread.update(cx, |thread, cx| { thread.set_model(model.clone(), cx); }); let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); let send = cx.foreground_executor().spawn(send); cx.run_until_parked(); fake_model.send_last_completion_stream_text_chunk("world"); fake_model.end_last_completion_stream(); send.await.unwrap(); cx.run_until_parked(); cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); drop(thread); drop(acp_thread); agent.read_with(cx, |agent, _| { assert!(agent.sessions.is_empty()); }); let first_loaded_thread = cx.update(|cx| { connection.clone().load_session( session_id.clone(), project.clone(), PathList::new(&[Path::new("")]), None, cx, ) }); let second_loaded_thread = cx.update(|cx| { connection.clone().load_session( session_id.clone(), project.clone(), PathList::new(&[Path::new("")]), None, cx, ) }); let first_loaded_thread = first_loaded_thread.await.unwrap(); let second_loaded_thread = second_loaded_thread.await.unwrap(); cx.run_until_parked(); assert_eq!( first_loaded_thread.entity_id(), second_loaded_thread.entity_id(), "concurrent loads for the same session should share one AcpThread" ); cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); agent.read_with(cx, |agent, _| { assert!( agent.sessions.contains_key(&session_id), "closing one loaded session should not drop shared session state" ); }); let follow_up = second_loaded_thread.update(cx, |thread, cx| { thread.send(vec!["still there?".into()], cx) }); let follow_up = cx.foreground_executor().spawn(follow_up); cx.run_until_parked(); fake_model.send_last_completion_stream_text_chunk("yes"); fake_model.end_last_completion_stream(); follow_up.await.unwrap(); cx.run_until_parked(); second_loaded_thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), formatdoc! {" ## User hello ## Assistant world ## User still there? 
## Assistant yes "} ); }); cx.update(|cx| connection.clone().close_session(&session_id, cx)) .await .unwrap(); cx.run_until_parked(); drop(first_loaded_thread); drop(second_loaded_thread); agent.read_with(cx, |agent, _| { assert!(agent.sessions.is_empty()); }); } #[gpui::test] async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) { // Regression test: rapid title changes must not cause a propagation loop // between Thread and AcpThread via handle_thread_title_updated. init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree("/", json!({ "a": {} })).await; let project = Project::test(fs.clone(), [], cx).await; let thread_store = cx.new(|cx| ThreadStore::new(cx)); let agent = cx.update(|cx| { NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) }); let connection = Rc::new(NativeAgentConnection(agent.clone())); let acp_thread = cx .update(|cx| { connection .clone() .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) }) .await .unwrap(); let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let thread = agent.read_with(cx, |agent, _| { agent.sessions.get(&session_id).unwrap().thread.clone() }); let title_updated_count = Rc::new(std::cell::RefCell::new(0usize)); cx.update(|cx| { let count = title_updated_count.clone(); cx.subscribe( &thread, move |_entity: Entity, _event: &TitleUpdated, _cx: &mut App| { let new_count = { let mut count = count.borrow_mut(); *count += 1; *count }; assert!( new_count <= 2, "TitleUpdated fired {new_count} times; \ title updates are looping" ); }, ) .detach(); }); thread.update(cx, |thread, cx| thread.set_title("first".into(), cx)); thread.update(cx, |thread, cx| thread.set_title("second".into(), cx)); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!(thread.title(), Some("second".into())); }); acp_thread.read_with(cx, |acp_thread, _| { assert_eq!(acp_thread.title(), Some("second".into())); }); 
assert_eq!(*title_updated_count.borrow(), 2);
    }

    /// Returns `(session id, title)` pairs for every entry currently held by
    /// the thread store, in the order the store yields them.
    fn thread_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(acp::SessionId, String)> {
        thread_store.read_with(cx, |store, _| {
            store
                .entries()
                .map(|entry| (entry.id.clone(), entry.title.to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared test setup: initializes logging (ignoring the error returned
    /// when a logger is already installed), installs a test `SettingsStore`
    /// global, and sets up the test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            LanguageModelRegistry::test(cx);
        });
    }
}

/// Converts one piece of MCP `MessageContent` (from the `context_server`
/// crate) into the corresponding ACP content block.
///
/// - Text, image, and audio content map onto their direct ACP counterparts.
/// - A resource becomes an [`acp::ResourceLink`] whose name and URI are both
///   the resource's URI; the resource's MIME type is carried over when set.
///
/// Annotations are dropped for every variant.
fn mcp_message_content_to_acp_content_block(
    content: context_server::types::MessageContent,
) -> acp::ContentBlock {
    use context_server::types::MessageContent as McpContent;

    match content {
        McpContent::Text { text, annotations: _ } => text.into(),
        McpContent::Image { data, mime_type, annotations: _ } => {
            acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type))
        }
        McpContent::Audio { data, mime_type, annotations: _ } => {
            acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type))
        }
        McpContent::Resource { resource, annotations: _ } => {
            // The URI doubles as the link's display name.
            let uri = resource.uri.to_string();
            let link = acp::ResourceLink::new(uri.clone(), uri);
            let link = match resource.mime_type {
                Some(mime_type) => link.mime_type(mime_type),
                None => link,
            };
            acp::ContentBlock::ResourceLink(link)
        }
    }
}