mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-31 19:05:00 +07:00
Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - agent: Improve reliability when LLM edits file
3457 lines
121 KiB
Rust
3457 lines
121 KiB
Rust
mod db;
|
|
mod legacy_thread;
|
|
mod native_agent_server;
|
|
pub mod outline;
|
|
mod pattern_extraction;
|
|
mod templates;
|
|
#[cfg(test)]
|
|
mod tests;
|
|
mod thread;
|
|
mod thread_store;
|
|
mod tool_permissions;
|
|
mod tools;
|
|
|
|
use context_server::ContextServerId;
|
|
pub use db::*;
|
|
use itertools::Itertools;
|
|
pub use native_agent_server::NativeAgentServer;
|
|
pub use pattern_extraction::*;
|
|
pub use shell_command_parser::extract_commands;
|
|
pub use templates::*;
|
|
pub use thread::*;
|
|
pub use thread_store::*;
|
|
pub use tool_permissions::*;
|
|
pub use tools::*;
|
|
|
|
use acp_thread::{
|
|
AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
|
|
AgentSessionListResponse, TokenUsageRatio, UserMessageId,
|
|
};
|
|
use agent_client_protocol::schema as acp;
|
|
use anyhow::{Context as _, Result, anyhow};
|
|
use chrono::{DateTime, Utc};
|
|
use collections::{HashMap, HashSet, IndexMap};
|
|
use fs::Fs;
|
|
use futures::channel::{mpsc, oneshot};
|
|
use futures::future::Shared;
|
|
use futures::{FutureExt as _, StreamExt as _, future};
|
|
use gpui::{
|
|
App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
|
|
WeakEntity,
|
|
};
|
|
use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
|
use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
|
|
use prompt_store::{
|
|
ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
|
|
WorktreeContext,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use settings::{LanguageModelSelection, Settings as _, update_settings_file};
|
|
use std::any::Any;
|
|
use std::path::PathBuf;
|
|
use std::rc::Rc;
|
|
use std::sync::{Arc, LazyLock};
|
|
use util::ResultExt;
|
|
use util::path_list::PathList;
|
|
use util::rel_path::RelPath;
|
|
|
|
/// Point-in-time capture of a project's worktrees, built from
/// `project::telemetry_snapshot` worktree snapshots.
///
/// This type derives `Serialize`/`Deserialize`, so renaming or reordering
/// fields likely changes its serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken, in UTC.
    pub timestamp: DateTime<Utc>,
}
|
|
|
|
/// Error describing a worktree rules file that was found but could not be loaded.
///
/// `load_worktree_info_for_system_prompt` produces it from the underlying
/// error's `Display` output.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
|
|
|
|
/// Per-project state shared by every session opened against the same project.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// System-prompt context (worktree info and default user rules) handed to threads.
    project_context: Entity<ProjectContext>,
    /// Sending on this wakes `maintain_project_context`, which rebuilds `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// The `maintain_project_context` task, kept for as long as this state exists.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server (MCP) registry for this project's context-server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server store, and registry events.
    _subscriptions: Vec<Subscription>,
}
|
|
|
|
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session runs in.
    project_id: EntityId,
    /// Most recent save started by `save_thread`. `close_session` returns it so
    /// callers can await the final write.
    pending_save: Task<Result<()>>,
    /// Subscriptions to the thread's title and token-usage events and its notifications.
    _subscriptions: Vec<Subscription>,
    /// Number of open handles. `close_session` removes the session when this reaches zero.
    ref_count: usize,
}
|
|
|
|
/// A session whose thread is still being loaded by `open_thread`.
struct PendingSession {
    /// Shared load task. Every concurrent `open_thread` call for the same ID awaits it.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// Number of callers waiting on the load. It becomes the session's
    /// `ref_count` once the session is registered.
    ref_count: usize,
}
|
|
|
|
/// Cache of the language models offered by authenticated, visible providers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch()`. It is signalled after each `refresh_list`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background authentication of all visible providers, started in `new`.
    _authenticate_all_providers_task: Task<()>,
}
|
|
|
|
impl LanguageModels {
|
|
fn new(cx: &mut App) -> Self {
|
|
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
|
|
|
|
let mut this = Self {
|
|
models: HashMap::default(),
|
|
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
|
|
refresh_models_rx,
|
|
refresh_models_tx,
|
|
_authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
|
|
};
|
|
this.refresh_list(cx);
|
|
this
|
|
}
|
|
|
|
fn refresh_list(&mut self, cx: &App) {
|
|
let providers = LanguageModelRegistry::global(cx)
|
|
.read(cx)
|
|
.visible_providers()
|
|
.into_iter()
|
|
.filter(|provider| provider.is_authenticated(cx))
|
|
.collect::<Vec<_>>();
|
|
|
|
let mut language_model_list = IndexMap::default();
|
|
let mut recommended_models = HashSet::default();
|
|
|
|
let mut recommended = Vec::new();
|
|
for provider in &providers {
|
|
for model in provider.recommended_models(cx) {
|
|
recommended_models.insert((model.provider_id(), model.id()));
|
|
recommended.push(Self::map_language_model_to_info(&model, provider));
|
|
}
|
|
}
|
|
if !recommended.is_empty() {
|
|
language_model_list.insert(
|
|
acp_thread::AgentModelGroupName("Recommended".into()),
|
|
recommended,
|
|
);
|
|
}
|
|
|
|
let mut models = HashMap::default();
|
|
for provider in providers {
|
|
let mut provider_models = Vec::new();
|
|
for model in provider.provided_models(cx) {
|
|
let model_info = Self::map_language_model_to_info(&model, &provider);
|
|
let model_id = model_info.id.clone();
|
|
provider_models.push(model_info);
|
|
models.insert(model_id, model);
|
|
}
|
|
if !provider_models.is_empty() {
|
|
language_model_list.insert(
|
|
acp_thread::AgentModelGroupName(provider.name().0.clone()),
|
|
provider_models,
|
|
);
|
|
}
|
|
}
|
|
|
|
self.models = models;
|
|
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
|
|
self.refresh_models_tx.send(()).ok();
|
|
}
|
|
|
|
fn watch(&self) -> watch::Receiver<()> {
|
|
self.refresh_models_rx.clone()
|
|
}
|
|
|
|
pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
|
|
self.models.get(model_id).cloned()
|
|
}
|
|
|
|
fn map_language_model_to_info(
|
|
model: &Arc<dyn LanguageModel>,
|
|
provider: &Arc<dyn LanguageModelProvider>,
|
|
) -> acp_thread::AgentModelInfo {
|
|
acp_thread::AgentModelInfo {
|
|
id: Self::model_id(model),
|
|
name: model.name().0,
|
|
description: None,
|
|
icon: Some(match provider.icon() {
|
|
IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
|
|
IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
|
|
}),
|
|
is_latest: model.is_latest(),
|
|
cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
|
|
}
|
|
}
|
|
|
|
fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
|
|
acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
|
|
}
|
|
|
|
fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
|
|
let authenticate_all_providers = LanguageModelRegistry::global(cx)
|
|
.read(cx)
|
|
.visible_providers()
|
|
.iter()
|
|
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
|
|
.collect::<Vec<_>>();
|
|
|
|
cx.spawn(async move |cx| {
|
|
for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
|
|
if let Err(err) = authenticate_task.await {
|
|
match err {
|
|
language_model::AuthenticateError::CredentialsNotFound => {
|
|
// Since we're authenticating these providers in the
|
|
// background for the purposes of populating the
|
|
// language selector, we don't care about providers
|
|
// where the credentials are not found.
|
|
}
|
|
language_model::AuthenticateError::ConnectionRefused => {
|
|
// Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
|
|
// LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
|
|
// TODO: Better manage LM Studio auth logic to avoid these noisy failures.
|
|
}
|
|
_ => {
|
|
// Some providers have noisy failure states that we
|
|
// don't want to spam the logs with every time the
|
|
// language model selector is initialized.
|
|
//
|
|
// Ideally these should have more clear failure modes
|
|
// that we know are safe to ignore here, like what we do
|
|
// with `CredentialsNotFound` above.
|
|
match provider_id.0.as_ref() {
|
|
"lmstudio" | "ollama" => {
|
|
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
|
|
//
|
|
// These fail noisily, so we don't log them.
|
|
}
|
|
"copilot_chat" => {
|
|
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
|
|
}
|
|
_ => {
|
|
log::error!(
|
|
"Failed to authenticate provider: {}: {err:#}",
|
|
provider_name.0
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
cx.update(language_models::update_environment_fallback_model);
|
|
})
|
|
}
|
|
}
|
|
|
|
/// The built-in agent. It owns every open session, the per-project state those
/// sessions share, and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions whose threads are still being loaded by `open_thread`.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Thread store, reloaded after each background save.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional source of default user rules for the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language-model registry and, when present, the prompt store.
    _subscriptions: Vec<Subscription>,
}
|
|
|
|
impl NativeAgent {
|
|
pub fn new(
|
|
thread_store: Entity<ThreadStore>,
|
|
templates: Arc<Templates>,
|
|
prompt_store: Option<Entity<PromptStore>>,
|
|
fs: Arc<dyn Fs>,
|
|
cx: &mut App,
|
|
) -> Entity<NativeAgent> {
|
|
log::debug!("Creating new NativeAgent");
|
|
|
|
cx.new(|cx| {
|
|
let mut subscriptions = vec![cx.subscribe(
|
|
&LanguageModelRegistry::global(cx),
|
|
Self::handle_models_updated_event,
|
|
)];
|
|
if let Some(prompt_store) = prompt_store.as_ref() {
|
|
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
|
|
}
|
|
|
|
Self {
|
|
sessions: HashMap::default(),
|
|
pending_sessions: HashMap::default(),
|
|
thread_store,
|
|
projects: HashMap::default(),
|
|
templates,
|
|
models: LanguageModels::new(cx),
|
|
prompt_store,
|
|
fs,
|
|
_subscriptions: subscriptions,
|
|
}
|
|
})
|
|
}
|
|
|
|
fn new_session(
|
|
&mut self,
|
|
project: Entity<Project>,
|
|
cx: &mut Context<Self>,
|
|
) -> Entity<AcpThread> {
|
|
let project_id = self.get_or_create_project_state(&project, cx);
|
|
let project_state = &self.projects[&project_id];
|
|
|
|
let registry = LanguageModelRegistry::read_global(cx);
|
|
let available_count = registry.available_models(cx).count();
|
|
log::debug!("Total available models: {}", available_count);
|
|
|
|
let default_model = registry.default_model().and_then(|default_model| {
|
|
self.models
|
|
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
|
});
|
|
let thread = cx.new(|cx| {
|
|
Thread::new(
|
|
project,
|
|
project_state.project_context.clone(),
|
|
project_state.context_server_registry.clone(),
|
|
self.templates.clone(),
|
|
default_model,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
self.register_session(thread, project_id, 1, cx)
|
|
}
|
|
|
|
/// Wraps `thread_handle` in a new `AcpThread` and records both as a session
/// for `project_id`, starting with `ref_count` references.
///
/// - The `AcpThread` is seeded with the thread's parent session ID, title,
///   draft prompt, UI scroll position, and latest token usage.
/// - The thread gets the registry's current thread-summary model and its
///   default tools, backed by a `NativeThreadEnvironment`.
/// - The session subscribes to the thread's title and token-usage events and
///   saves the thread whenever it notifies.
/// - The project's available commands are then pushed to its sessions.
///
/// Returns the new `AcpThread`.
fn register_session(
    &mut self,
    thread_handle: Entity<Thread>,
    project_id: EntityId,
    ref_count: usize,
    cx: &mut Context<Self>,
) -> Entity<AcpThread> {
    let connection = Rc::new(NativeAgentConnection(cx.entity()));

    // Copy out everything the AcpThread needs while only a read borrow of `cx` is held.
    let thread = thread_handle.read(cx);
    let session_id = thread.id().clone();
    let parent_session_id = thread.parent_thread_id();
    let title = thread.title();
    let draft_prompt = thread.draft_prompt().map(Vec::from);
    let scroll_position = thread.ui_scroll_position();
    let token_usage = thread.latest_token_usage();
    let project = thread.project.clone();
    let action_log = thread.action_log.clone();
    let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
    let acp_thread = cx.new(|cx| {
        let mut acp_thread = acp_thread::AcpThread::new(
            parent_session_id,
            title,
            None,
            connection,
            project.clone(),
            action_log.clone(),
            session_id.clone(),
            prompt_capabilities_rx,
            cx,
        );
        acp_thread.set_draft_prompt(draft_prompt, cx);
        acp_thread.set_ui_scroll_position(scroll_position);
        acp_thread.update_token_usage(token_usage, cx);
        acp_thread
    });

    let registry = LanguageModelRegistry::read_global(cx);
    let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);

    // The tool environment holds only weak handles to the agent and both threads.
    let weak = cx.weak_entity();
    let weak_thread = thread_handle.downgrade();
    thread_handle.update(cx, |thread, cx| {
        thread.set_summarization_model(summarization_model, cx);
        thread.add_default_tools(
            Rc::new(NativeThreadEnvironment {
                acp_thread: acp_thread.downgrade(),
                thread: weak_thread,
                agent: weak,
            }) as _,
            cx,
        )
    });

    // Mirror title and usage changes onto the AcpThread, and save on every notification.
    let subscriptions = vec![
        cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
        cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
        cx.observe(&thread_handle, move |this, thread, cx| {
            this.save_thread(thread, cx)
        }),
    ];

    self.sessions.insert(
        session_id,
        Session {
            thread: thread_handle,
            acp_thread: acp_thread.clone(),
            project_id,
            _subscriptions: subscriptions,
            pending_save: Task::ready(Ok(())),
            ref_count,
        },
    );

    // The session must be in `self.sessions` before this so it receives the update.
    self.update_available_commands_for_project(project_id, cx);

    acp_thread
}
|
|
|
|
/// Returns the agent's cached language-model list.
pub fn models(&self) -> &LanguageModels {
    &self.models
}
|
|
|
|
fn get_or_create_project_state(
|
|
&mut self,
|
|
project: &Entity<Project>,
|
|
cx: &mut Context<Self>,
|
|
) -> EntityId {
|
|
let project_id = project.entity_id();
|
|
if self.projects.contains_key(&project_id) {
|
|
return project_id;
|
|
}
|
|
|
|
let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
|
|
self.register_project_with_initial_context(project.clone(), project_context, cx);
|
|
if let Some(state) = self.projects.get_mut(&project_id) {
|
|
state.project_context_needs_refresh.send(()).ok();
|
|
}
|
|
project_id
|
|
}
|
|
|
|
fn register_project_with_initial_context(
|
|
&mut self,
|
|
project: Entity<Project>,
|
|
project_context: Entity<ProjectContext>,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
let project_id = project.entity_id();
|
|
|
|
let context_server_store = project.read(cx).context_server_store();
|
|
let context_server_registry =
|
|
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
|
|
|
let subscriptions = vec![
|
|
cx.subscribe(&project, Self::handle_project_event),
|
|
cx.subscribe(
|
|
&context_server_store,
|
|
Self::handle_context_server_store_updated,
|
|
),
|
|
cx.subscribe(
|
|
&context_server_registry,
|
|
Self::handle_context_server_registry_event,
|
|
),
|
|
];
|
|
|
|
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
|
|
watch::channel(());
|
|
|
|
self.projects.insert(
|
|
project_id,
|
|
ProjectState {
|
|
project,
|
|
project_context,
|
|
project_context_needs_refresh: project_context_needs_refresh_tx,
|
|
_maintain_project_context: cx.spawn(async move |this, cx| {
|
|
Self::maintain_project_context(
|
|
this,
|
|
project_id,
|
|
project_context_needs_refresh_rx,
|
|
cx,
|
|
)
|
|
.await
|
|
}),
|
|
context_server_registry,
|
|
_subscriptions: subscriptions,
|
|
},
|
|
);
|
|
}
|
|
|
|
/// Returns the project state backing `session_id`, if both are still registered.
fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
    let session = self.sessions.get(session_id)?;
    self.projects.get(&session.project_id)
}
|
|
|
|
/// Rebuilds the `ProjectContext` of `project_id` every time `needs_refresh` is signalled.
///
/// - Returns `Ok(())` once `changed()` reports an error (presumably because
///   the sender was dropped along with the project state).
/// - Returns an error if the weak agent handle can no longer be updated, or if
///   the project's state is missing when a rebuild starts.
async fn maintain_project_context(
    this: WeakEntity<Self>,
    project_id: EntityId,
    mut needs_refresh: watch::Receiver<()>,
    cx: &mut AsyncApp,
) -> Result<()> {
    while needs_refresh.changed().await.is_ok() {
        // Start the build inside a short agent update, then await it outside of it.
        let project_context = this
            .update(cx, |this, cx| {
                let state = this
                    .projects
                    .get(&project_id)
                    .context("project state not found")?;
                anyhow::Ok(Self::build_project_context(
                    &state.project,
                    this.prompt_store.as_ref(),
                    cx,
                ))
            })??
            .await;
        // The project may have been removed while the build was in flight; if so, drop the result.
        this.update(cx, |this, cx| {
            if let Some(state) = this.projects.get(&project_id) {
                state
                    .project_context
                    .update(cx, |current_project_context, _cx| {
                        *current_project_context = project_context;
                    });
            }
        })?;
    }

    Ok(())
}
|
|
|
|
fn build_project_context(
|
|
project: &Entity<Project>,
|
|
prompt_store: Option<&Entity<PromptStore>>,
|
|
cx: &mut App,
|
|
) -> Task<ProjectContext> {
|
|
let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
|
|
let worktree_tasks = worktrees
|
|
.into_iter()
|
|
.map(|worktree| {
|
|
Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
|
|
prompt_store.read_with(cx, |prompt_store, cx| {
|
|
let prompts = prompt_store.default_prompt_metadata();
|
|
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
|
|
let contents = prompt_store.load(prompt_metadata.id, cx);
|
|
async move { (contents.await, prompt_metadata) }
|
|
});
|
|
cx.background_spawn(future::join_all(load_tasks))
|
|
})
|
|
} else {
|
|
Task::ready(vec![])
|
|
};
|
|
|
|
cx.spawn(async move |_cx| {
|
|
let (worktrees, default_user_rules) =
|
|
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
|
|
|
|
let worktrees = worktrees
|
|
.into_iter()
|
|
.map(|(worktree, _rules_error)| {
|
|
// TODO: show error message
|
|
// if let Some(rules_error) = rules_error {
|
|
// this.update(cx, |_, cx| cx.emit(rules_error)).ok();
|
|
// }
|
|
worktree
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let default_user_rules = default_user_rules
|
|
.into_iter()
|
|
.flat_map(|(contents, prompt_metadata)| match contents {
|
|
Ok(contents) => Some(UserRulesContext {
|
|
uuid: prompt_metadata.id.as_user()?,
|
|
title: prompt_metadata.title.map(|title| title.to_string()),
|
|
contents,
|
|
}),
|
|
Err(_err) => {
|
|
// TODO: show error message
|
|
// this.update(cx, |_, cx| {
|
|
// cx.emit(RulesLoadingError {
|
|
// message: format!("{err:?}").into(),
|
|
// });
|
|
// })
|
|
// .ok();
|
|
None
|
|
}
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
ProjectContext::new(worktrees, default_user_rules)
|
|
})
|
|
}
|
|
|
|
fn load_worktree_info_for_system_prompt(
|
|
worktree: Entity<Worktree>,
|
|
project: Entity<Project>,
|
|
cx: &mut App,
|
|
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
|
|
let tree = worktree.read(cx);
|
|
let root_name = tree.root_name_str().into();
|
|
let abs_path = tree.abs_path();
|
|
let scan_complete = tree.as_local().map(|local| local.scan_complete());
|
|
|
|
let mut context = WorktreeContext {
|
|
root_name,
|
|
abs_path,
|
|
rules_file: None,
|
|
};
|
|
|
|
cx.spawn(async move |cx| {
|
|
if let Some(scan_complete) = scan_complete {
|
|
scan_complete.await;
|
|
}
|
|
|
|
let rules_task = cx.update(|cx| Self::load_worktree_rules_file(worktree, project, cx));
|
|
|
|
let (rules_file, rules_file_error) = match rules_task {
|
|
Some(rules_task) => match rules_task.await {
|
|
Ok(rules_file) => (Some(rules_file), None),
|
|
Err(err) => (
|
|
None,
|
|
Some(RulesLoadingError {
|
|
message: format!("{err}").into(),
|
|
}),
|
|
),
|
|
},
|
|
None => (None, None),
|
|
};
|
|
context.rules_file = rules_file;
|
|
(context, rules_file_error)
|
|
})
|
|
}
|
|
|
|
fn load_worktree_rules_file(
|
|
worktree: Entity<Worktree>,
|
|
project: Entity<Project>,
|
|
cx: &mut App,
|
|
) -> Option<Task<Result<RulesFileContext>>> {
|
|
let worktree = worktree.read(cx);
|
|
let worktree_id = worktree.id();
|
|
let selected_rules_file = RULES_FILE_NAMES
|
|
.into_iter()
|
|
.filter_map(|name| {
|
|
worktree
|
|
.entry_for_path(RelPath::unix(name).unwrap())
|
|
.filter(|entry| entry.is_file())
|
|
.map(|entry| entry.path.clone())
|
|
})
|
|
.next();
|
|
|
|
// Note that Cline supports `.clinerules` being a directory, but that is not currently
|
|
// supported. This doesn't seem to occur often in GitHub repositories.
|
|
selected_rules_file.map(|path_in_worktree| {
|
|
let project_path = ProjectPath {
|
|
worktree_id,
|
|
path: path_in_worktree.clone(),
|
|
};
|
|
let buffer_task =
|
|
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
|
let rope_task = cx.spawn(async move |cx| {
|
|
let buffer = buffer_task.await?;
|
|
let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
|
|
let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
|
|
anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
|
|
})?;
|
|
anyhow::Ok((project_entry_id, rope))
|
|
});
|
|
// Build a string from the rope on a background thread.
|
|
cx.background_spawn(async move {
|
|
let (project_entry_id, rope) = rope_task.await?;
|
|
anyhow::Ok(RulesFileContext {
|
|
path_in_worktree,
|
|
text: rope.to_string().trim().to_string(),
|
|
project_entry_id: project_entry_id.to_usize(),
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
fn handle_thread_title_updated(
|
|
&mut self,
|
|
thread: Entity<Thread>,
|
|
_: &TitleUpdated,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
let session_id = thread.read(cx).id();
|
|
let Some(session) = self.sessions.get(session_id) else {
|
|
return;
|
|
};
|
|
|
|
let thread = thread.downgrade();
|
|
let acp_thread = session.acp_thread.downgrade();
|
|
cx.spawn(async move |_, cx| {
|
|
let title = thread.read_with(cx, |thread, _| thread.title())?;
|
|
if let Some(title) = title {
|
|
let task =
|
|
acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
|
|
task.await?;
|
|
}
|
|
anyhow::Ok(())
|
|
})
|
|
.detach_and_log_err(cx);
|
|
}
|
|
|
|
fn handle_thread_token_usage_updated(
|
|
&mut self,
|
|
thread: Entity<Thread>,
|
|
usage: &TokenUsageUpdated,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
let Some(session) = self.sessions.get(thread.read(cx).id()) else {
|
|
return;
|
|
};
|
|
session.acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.update_token_usage(usage.0.clone(), cx);
|
|
});
|
|
}
|
|
|
|
fn handle_project_event(
|
|
&mut self,
|
|
project: Entity<Project>,
|
|
event: &project::Event,
|
|
_cx: &mut Context<Self>,
|
|
) {
|
|
let project_id = project.entity_id();
|
|
let Some(state) = self.projects.get_mut(&project_id) else {
|
|
return;
|
|
};
|
|
match event {
|
|
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
|
|
state.project_context_needs_refresh.send(()).ok();
|
|
}
|
|
project::Event::WorktreeUpdatedEntries(_, items) => {
|
|
if items.iter().any(|(path, _, _)| {
|
|
RULES_FILE_NAMES
|
|
.iter()
|
|
.any(|name| path.as_ref() == RelPath::unix(name).unwrap())
|
|
}) {
|
|
state.project_context_needs_refresh.send(()).ok();
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
fn handle_prompts_updated_event(
|
|
&mut self,
|
|
_prompt_store: Entity<PromptStore>,
|
|
_event: &prompt_store::PromptsUpdatedEvent,
|
|
_cx: &mut Context<Self>,
|
|
) {
|
|
for state in self.projects.values_mut() {
|
|
state.project_context_needs_refresh.send(()).ok();
|
|
}
|
|
}
|
|
|
|
fn handle_models_updated_event(
|
|
&mut self,
|
|
_registry: Entity<LanguageModelRegistry>,
|
|
event: &language_model::Event,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
self.models.refresh_list(cx);
|
|
|
|
let registry = LanguageModelRegistry::read_global(cx);
|
|
let default_model = registry.default_model().map(|m| m.model);
|
|
let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
|
|
|
|
for session in self.sessions.values_mut() {
|
|
session.thread.update(cx, |thread, cx| {
|
|
if thread.model().is_none()
|
|
&& let Some(model) = default_model.clone()
|
|
{
|
|
thread.set_model(model, cx);
|
|
cx.notify();
|
|
}
|
|
if let Some(model) = summarization_model.clone() {
|
|
if thread.summarization_model().is_none()
|
|
|| matches!(event, language_model::Event::ThreadSummaryModelChanged)
|
|
{
|
|
thread.set_summarization_model(Some(model), cx);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
fn handle_context_server_store_updated(
|
|
&mut self,
|
|
store: Entity<project::context_server_store::ContextServerStore>,
|
|
_event: &project::context_server_store::ServerStatusChangedEvent,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
let project_id = self.projects.iter().find_map(|(id, state)| {
|
|
if *state.context_server_registry.read(cx).server_store() == store {
|
|
Some(*id)
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
if let Some(project_id) = project_id {
|
|
self.update_available_commands_for_project(project_id, cx);
|
|
}
|
|
}
|
|
|
|
fn handle_context_server_registry_event(
|
|
&mut self,
|
|
registry: Entity<ContextServerRegistry>,
|
|
event: &ContextServerRegistryEvent,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
match event {
|
|
ContextServerRegistryEvent::ToolsChanged => {}
|
|
ContextServerRegistryEvent::PromptsChanged => {
|
|
let project_id = self.projects.iter().find_map(|(id, state)| {
|
|
if state.context_server_registry == registry {
|
|
Some(*id)
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
if let Some(project_id) = project_id {
|
|
self.update_available_commands_for_project(project_id, cx);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
|
|
let available_commands =
|
|
Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
|
|
for session in self.sessions.values() {
|
|
if session.project_id != project_id {
|
|
continue;
|
|
}
|
|
session.acp_thread.update(cx, |thread, cx| {
|
|
thread
|
|
.handle_session_update(
|
|
acp::SessionUpdate::AvailableCommandsUpdate(
|
|
acp::AvailableCommandsUpdate::new(available_commands.clone()),
|
|
),
|
|
cx,
|
|
)
|
|
.log_err();
|
|
});
|
|
}
|
|
}
|
|
|
|
fn build_available_commands_for_project(
|
|
project_state: Option<&ProjectState>,
|
|
cx: &App,
|
|
) -> Vec<acp::AvailableCommand> {
|
|
let Some(state) = project_state else {
|
|
return vec![];
|
|
};
|
|
let registry = state.context_server_registry.read(cx);
|
|
|
|
let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
|
|
for context_server_prompt in registry.prompts() {
|
|
*prompt_name_counts
|
|
.entry(context_server_prompt.prompt.name.as_str())
|
|
.or_insert(0) += 1;
|
|
}
|
|
|
|
registry
|
|
.prompts()
|
|
.flat_map(|context_server_prompt| {
|
|
let prompt = &context_server_prompt.prompt;
|
|
|
|
let should_prefix = prompt_name_counts
|
|
.get(prompt.name.as_str())
|
|
.copied()
|
|
.unwrap_or(0)
|
|
> 1;
|
|
|
|
let name = if should_prefix {
|
|
format!("{}.{}", context_server_prompt.server_id, prompt.name)
|
|
} else {
|
|
prompt.name.clone()
|
|
};
|
|
|
|
let mut command = acp::AvailableCommand::new(
|
|
name,
|
|
prompt.description.clone().unwrap_or_default(),
|
|
);
|
|
|
|
match prompt.arguments.as_deref() {
|
|
Some([arg]) => {
|
|
let hint = format!("<{}>", arg.name);
|
|
|
|
command = command.input(acp::AvailableCommandInput::Unstructured(
|
|
acp::UnstructuredCommandInput::new(hint),
|
|
));
|
|
}
|
|
Some([]) | None => {}
|
|
Some(_) => {
|
|
// skip >1 argument commands since we don't support them yet
|
|
return None;
|
|
}
|
|
}
|
|
|
|
Some(command)
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
pub fn load_thread(
|
|
&mut self,
|
|
id: acp::SessionId,
|
|
project: Entity<Project>,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<Entity<Thread>>> {
|
|
let database_future = ThreadsDatabase::connect(cx);
|
|
cx.spawn(async move |this, cx| {
|
|
let database = database_future.await.map_err(|err| anyhow!(err))?;
|
|
let db_thread = database
|
|
.load_thread(id.clone())
|
|
.await?
|
|
.with_context(|| format!("no thread found with ID: {id:?}"))?;
|
|
|
|
this.update(cx, |this, cx| {
|
|
let project_id = this.get_or_create_project_state(&project, cx);
|
|
let project_state = this
|
|
.projects
|
|
.get(&project_id)
|
|
.context("project state not found")?;
|
|
let summarization_model = LanguageModelRegistry::read_global(cx)
|
|
.thread_summary_model(cx)
|
|
.map(|c| c.model);
|
|
|
|
Ok(cx.new(|cx| {
|
|
let mut thread = Thread::from_db(
|
|
id.clone(),
|
|
db_thread,
|
|
project_state.project.clone(),
|
|
project_state.project_context.clone(),
|
|
project_state.context_server_registry.clone(),
|
|
this.templates.clone(),
|
|
cx,
|
|
);
|
|
thread.set_summarization_model(summarization_model, cx);
|
|
thread
|
|
}))
|
|
})?
|
|
})
|
|
}
|
|
|
|
/// Opens session `id`, loading its thread from the database if needed.
///
/// - If the session is already registered, its ref count is incremented and
///   its `AcpThread` is returned immediately.
/// - If a load for `id` is already in flight, this call joins it and
///   increments the pending ref count. That count becomes the session's ref
///   count once it is registered.
/// - Otherwise a shared load task is started. It loads the thread, registers
///   the session, replays the thread's events into the `AcpThread`, and
///   snapshots the completed plan. If the load fails, the pending entry is
///   removed.
///
/// Callers presumably balance each successful open with `close_session`
/// (as `thread_summary` does).
///
/// NOTE(review): if event replay fails after `register_session`, the session
/// stays registered with the accumulated ref count while every caller gets an
/// error, so nothing will ever close it.
pub fn open_thread(
    &mut self,
    id: acp::SessionId,
    project: Entity<Project>,
    cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
    // Already open: take another reference.
    if let Some(session) = self.sessions.get_mut(&id) {
        session.ref_count += 1;
        return Task::ready(Ok(session.acp_thread.clone()));
    }

    // Already loading: join the in-flight load.
    if let Some(pending) = self.pending_sessions.get_mut(&id) {
        pending.ref_count += 1;
        let task = pending.task.clone();
        return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
    }

    let task = self.load_thread(id.clone(), project.clone(), cx);
    let shared_task = cx
        .spawn({
            let id = id.clone();
            async move |this, cx| {
                let thread = match task.await {
                    Ok(thread) => thread,
                    Err(err) => {
                        // Forget the pending load so a later open can retry.
                        this.update(cx, |this, _cx| {
                            this.pending_sessions.remove(&id);
                        })
                        .ok();
                        return Err(Arc::new(err));
                    }
                };
                let acp_thread = this
                    .update(cx, |this, cx| {
                        let project_id = this.get_or_create_project_state(&project, cx);
                        // Carry over every reference taken while the load was pending.
                        let ref_count = this
                            .pending_sessions
                            .remove(&id)
                            .map_or(1, |pending| pending.ref_count);
                        this.register_session(thread.clone(), project_id, ref_count, cx)
                    })
                    .map_err(Arc::new)?;
                // Replay the persisted thread into the freshly created AcpThread.
                let events = thread.update(cx, |thread, cx| thread.replay(cx));
                cx.update(|cx| {
                    NativeAgentConnection::handle_thread_events(
                        events,
                        acp_thread.downgrade(),
                        cx,
                    )
                })
                .await
                .map_err(Arc::new)?;
                acp_thread.update(cx, |thread, cx| {
                    thread.snapshot_completed_plan(cx);
                });
                Ok(acp_thread)
            }
        })
        .shared();
    self.pending_sessions.insert(
        id,
        PendingSession {
            task: shared_task.clone(),
            ref_count: 1,
        },
    );

    cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
}
|
|
|
|
pub fn thread_summary(
|
|
&mut self,
|
|
id: acp::SessionId,
|
|
project: Entity<Project>,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<SharedString>> {
|
|
let thread = self.open_thread(id.clone(), project, cx);
|
|
cx.spawn(async move |this, cx| {
|
|
let acp_thread = thread.await?;
|
|
let result = this
|
|
.update(cx, |this, cx| {
|
|
this.sessions
|
|
.get(&id)
|
|
.unwrap()
|
|
.thread
|
|
.update(cx, |thread, cx| thread.summary(cx))
|
|
})?
|
|
.await
|
|
.context("Failed to generate summary")?;
|
|
|
|
this.update(cx, |this, cx| this.close_session(&id, cx))?
|
|
.await?;
|
|
drop(acp_thread);
|
|
Ok(result)
|
|
})
|
|
}
|
|
|
|
fn close_session(
|
|
&mut self,
|
|
session_id: &acp::SessionId,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<()>> {
|
|
let Some(session) = self.sessions.get_mut(session_id) else {
|
|
return Task::ready(Ok(()));
|
|
};
|
|
|
|
session.ref_count -= 1;
|
|
if session.ref_count > 0 {
|
|
return Task::ready(Ok(()));
|
|
}
|
|
|
|
let thread = session.thread.clone();
|
|
self.save_thread(thread, cx);
|
|
let Some(session) = self.sessions.remove(session_id) else {
|
|
return Task::ready(Ok(()));
|
|
};
|
|
let project_id = session.project_id;
|
|
|
|
let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
|
|
if !has_remaining {
|
|
self.projects.remove(&project_id);
|
|
}
|
|
|
|
session.pending_save
|
|
}
|
|
|
|
/// Persists `thread` to the threads database in the background.
///
/// It does nothing for:
/// - empty threads;
/// - threads without a registered session;
/// - sessions whose project state is gone.
///
/// Otherwise it copies the `AcpThread`'s draft prompt onto the thread,
/// serializes the thread, and stores the write task in `session.pending_save`.
/// That task saves the thread together with the project's visible worktree
/// paths, then reloads the thread store. Database errors are logged, not
/// returned.
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
    if thread.read(cx).is_empty() {
        return;
    }

    let id = thread.read(cx).id().clone();
    let Some(session) = self.sessions.get_mut(&id) else {
        return;
    };

    let project_id = session.project_id;
    let Some(state) = self.projects.get(&project_id) else {
        return;
    };

    // The thread is stored along with the folders it was opened in.
    let folder_paths = PathList::new(
        &state
            .project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
            .collect::<Vec<_>>(),
    );

    // The AcpThread's draft prompt is copied onto the thread so it is persisted too.
    let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
    let database_future = ThreadsDatabase::connect(cx);
    let db_thread = thread.update(cx, |thread, cx| {
        thread.set_draft_prompt(draft_prompt);
        thread.to_db(cx)
    });
    let thread_store = self.thread_store.clone();
    // Replacing `pending_save` is what lets `close_session` hand back the latest write.
    session.pending_save = cx.spawn(async move |_, cx| {
        let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
            return Ok(());
        };
        let db_thread = db_thread.await;
        database
            .save_thread(id, db_thread, folder_paths)
            .await
            .log_err();
        thread_store.update(cx, |store, cx| store.reload(cx));
        Ok(())
    });
}
|
|
|
|
fn send_mcp_prompt(
|
|
&self,
|
|
message_id: UserMessageId,
|
|
session_id: acp::SessionId,
|
|
prompt_name: String,
|
|
server_id: ContextServerId,
|
|
arguments: HashMap<String, String>,
|
|
original_content: Vec<acp::ContentBlock>,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
let Some(state) = self.session_project_state(&session_id) else {
|
|
return Task::ready(Err(anyhow!("Project state not found for session")));
|
|
};
|
|
let server_store = state
|
|
.context_server_registry
|
|
.read(cx)
|
|
.server_store()
|
|
.clone();
|
|
let path_style = state.project.read(cx).path_style(cx);
|
|
|
|
cx.spawn(async move |this, cx| {
|
|
let prompt =
|
|
crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
|
|
|
|
let (acp_thread, thread) = this.update(cx, |this, _cx| {
|
|
let session = this
|
|
.sessions
|
|
.get(&session_id)
|
|
.context("Failed to get session")?;
|
|
anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
|
|
})??;
|
|
|
|
let mut last_is_user = true;
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.push_acp_user_block(
|
|
message_id,
|
|
original_content.into_iter().skip(1),
|
|
path_style,
|
|
cx,
|
|
);
|
|
});
|
|
|
|
for message in prompt.messages {
|
|
let context_server::types::PromptMessage { role, content } = message;
|
|
let block = mcp_message_content_to_acp_content_block(content);
|
|
|
|
match role {
|
|
context_server::types::Role::User => {
|
|
let id = acp_thread::UserMessageId::new();
|
|
|
|
acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.push_user_content_block_with_indent(
|
|
Some(id.clone()),
|
|
block.clone(),
|
|
true,
|
|
cx,
|
|
);
|
|
});
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.push_acp_user_block(id, [block], path_style, cx);
|
|
});
|
|
}
|
|
context_server::types::Role::Assistant => {
|
|
acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.push_assistant_content_block_with_indent(
|
|
block.clone(),
|
|
false,
|
|
true,
|
|
cx,
|
|
);
|
|
});
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.push_acp_agent_block(block, cx);
|
|
});
|
|
}
|
|
}
|
|
|
|
last_is_user = role == context_server::types::Role::User;
|
|
}
|
|
|
|
let response_stream = thread.update(cx, |thread, cx| {
|
|
if last_is_user {
|
|
thread.send_existing(cx)
|
|
} else {
|
|
// Resume if MCP prompt did not end with a user message
|
|
thread.resume(cx)
|
|
}
|
|
})?;
|
|
|
|
cx.update(|cx| {
|
|
NativeAgentConnection::handle_thread_events(
|
|
response_stream,
|
|
acp_thread.downgrade(),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Wrapper struct that implements the AgentConnection trait
|
|
#[derive(Clone)]
|
|
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
|
|
|
impl NativeAgentConnection {
|
|
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
|
|
self.0
|
|
.read(cx)
|
|
.sessions
|
|
.get(session_id)
|
|
.map(|session| session.thread.clone())
|
|
}
|
|
|
|
pub fn load_thread(
|
|
&self,
|
|
id: acp::SessionId,
|
|
project: Entity<Project>,
|
|
cx: &mut App,
|
|
) -> Task<Result<Entity<Thread>>> {
|
|
self.0
|
|
.update(cx, |this, cx| this.load_thread(id, project, cx))
|
|
}
|
|
|
|
fn run_turn(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
cx: &mut App,
|
|
f: impl 'static
|
|
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get_mut(&session_id)
|
|
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
|
}) else {
|
|
log::error!("Session not found in run_turn: {}", session_id);
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
log::debug!("Found session for: {}", session_id);
|
|
|
|
let response_stream = match f(thread, cx) {
|
|
Ok(stream) => stream,
|
|
Err(err) => return Task::ready(Err(err)),
|
|
};
|
|
Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
|
|
}
|
|
|
|
fn handle_thread_events(
|
|
mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
|
acp_thread: WeakEntity<AcpThread>,
|
|
cx: &App,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
cx.spawn(async move |cx| {
|
|
// Handle response stream and forward to session.acp_thread
|
|
while let Some(result) = events.next().await {
|
|
match result {
|
|
Ok(event) => {
|
|
log::trace!("Received completion event: {:?}", event);
|
|
|
|
match event {
|
|
ThreadEvent::UserMessage(message) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
for content in message.content {
|
|
thread.push_user_content_block(
|
|
Some(message.id.clone()),
|
|
content.into(),
|
|
cx,
|
|
);
|
|
}
|
|
})?;
|
|
}
|
|
ThreadEvent::AgentText(text) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.push_assistant_content_block(text.into(), false, cx)
|
|
})?;
|
|
}
|
|
ThreadEvent::AgentThinking(text) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.push_assistant_content_block(text.into(), true, cx)
|
|
})?;
|
|
}
|
|
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
|
|
tool_call,
|
|
options,
|
|
response,
|
|
context: _,
|
|
}) => {
|
|
let outcome_task = acp_thread.update(cx, |thread, cx| {
|
|
thread.request_tool_call_authorization(tool_call, options, cx)
|
|
})??;
|
|
cx.background_spawn(async move {
|
|
if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
|
|
outcome_task.await
|
|
{
|
|
response
|
|
.send(outcome)
|
|
.map_err(|_| {
|
|
anyhow!("authorization receiver was dropped")
|
|
})
|
|
.log_err();
|
|
}
|
|
})
|
|
.detach();
|
|
}
|
|
ThreadEvent::ToolCall(tool_call) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.upsert_tool_call(tool_call, cx)
|
|
})??;
|
|
}
|
|
ThreadEvent::ToolCallUpdate(update) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.update_tool_call(update, cx)
|
|
})??;
|
|
}
|
|
ThreadEvent::Plan(plan) => {
|
|
acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
|
|
}
|
|
ThreadEvent::SubagentSpawned(session_id) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.subagent_spawned(session_id, cx);
|
|
})?;
|
|
}
|
|
ThreadEvent::Retry(status) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.update_retry_status(status, cx)
|
|
})?;
|
|
}
|
|
ThreadEvent::Stop(stop_reason) => {
|
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
|
return Ok(acp::PromptResponse::new(stop_reason));
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
log::error!("Error in model response stream: {:?}", e);
|
|
return Err(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
log::debug!("Response stream completed");
|
|
anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
|
|
})
|
|
}
|
|
}
|
|
|
|
/// A slash command parsed from the first text block of a prompt, of the form
/// `/name args` or `/server.name args`.
struct Command<'a> {
    /// Prompt name: the command text after the optional `server.` prefix.
    prompt_name: &'a str,
    /// Text following the first whitespace character after the command; empty if none.
    arg_value: &'a str,
    /// Server named explicitly with the `server.name` syntax, if any.
    explicit_server_id: Option<&'a str>,
}
|
|
|
|
impl<'a> Command<'a> {
|
|
fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
|
|
let acp::ContentBlock::Text(text_content) = prompt.first()? else {
|
|
return None;
|
|
};
|
|
let text = text_content.text.trim();
|
|
let command = text.strip_prefix('/')?;
|
|
let (command, arg_value) = command
|
|
.split_once(char::is_whitespace)
|
|
.unwrap_or((command, ""));
|
|
|
|
if let Some((server_id, prompt_name)) = command.split_once('.') {
|
|
Some(Self {
|
|
prompt_name,
|
|
arg_value,
|
|
explicit_server_id: Some(server_id),
|
|
})
|
|
} else {
|
|
Some(Self {
|
|
prompt_name: command,
|
|
arg_value,
|
|
explicit_server_id: None,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
struct NativeAgentModelSelector {
|
|
session_id: acp::SessionId,
|
|
connection: NativeAgentConnection,
|
|
}
|
|
|
|
impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
|
|
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
|
log::debug!("NativeAgentConnection::list_models called");
|
|
let list = self.connection.0.read(cx).models.model_list.clone();
|
|
Task::ready(if list.is_empty() {
|
|
Err(anyhow::anyhow!("No models available"))
|
|
} else {
|
|
Ok(list)
|
|
})
|
|
}
|
|
|
|
fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
|
|
log::debug!(
|
|
"Setting model for session {}: {}",
|
|
self.session_id,
|
|
model_id
|
|
);
|
|
let Some(thread) = self
|
|
.connection
|
|
.0
|
|
.read(cx)
|
|
.sessions
|
|
.get(&self.session_id)
|
|
.map(|session| session.thread.clone())
|
|
else {
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
|
|
let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
|
|
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
|
};
|
|
|
|
let favorite = agent_settings::AgentSettings::get_global(cx)
|
|
.favorite_models
|
|
.iter()
|
|
.find(|favorite| {
|
|
favorite.provider.0 == model.provider_id().0.as_ref()
|
|
&& favorite.model == model.id().0.as_ref()
|
|
})
|
|
.cloned();
|
|
|
|
let LanguageModelSelection {
|
|
enable_thinking,
|
|
effort,
|
|
speed,
|
|
..
|
|
} = agent_settings::language_model_to_selection(&model, favorite.as_ref());
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
thread.set_thinking_effort(effort.clone(), cx);
|
|
thread.set_thinking_enabled(enable_thinking, cx);
|
|
if let Some(speed) = speed {
|
|
thread.set_speed(speed, cx);
|
|
}
|
|
});
|
|
|
|
update_settings_file(
|
|
self.connection.0.read(cx).fs.clone(),
|
|
cx,
|
|
move |settings, cx| {
|
|
let provider = model.provider_id().0.to_string();
|
|
let model = model.id().0.to_string();
|
|
let enable_thinking = thread.read(cx).thinking_enabled();
|
|
let speed = thread.read(cx).speed();
|
|
settings
|
|
.agent
|
|
.get_or_insert_default()
|
|
.set_model(LanguageModelSelection {
|
|
provider: provider.into(),
|
|
model,
|
|
enable_thinking,
|
|
effort,
|
|
speed,
|
|
});
|
|
},
|
|
);
|
|
|
|
Task::ready(Ok(()))
|
|
}
|
|
|
|
fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
|
|
let Some(thread) = self
|
|
.connection
|
|
.0
|
|
.read(cx)
|
|
.sessions
|
|
.get(&self.session_id)
|
|
.map(|session| session.thread.clone())
|
|
else {
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
let Some(model) = thread.read(cx).model() else {
|
|
return Task::ready(Err(anyhow!("Model not found")));
|
|
};
|
|
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
|
else {
|
|
return Task::ready(Err(anyhow!("Provider not found")));
|
|
};
|
|
Task::ready(Ok(LanguageModels::map_language_model_to_info(
|
|
model, &provider,
|
|
)))
|
|
}
|
|
|
|
fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
|
|
Some(self.connection.0.read(cx).models.watch())
|
|
}
|
|
|
|
fn should_render_footer(&self) -> bool {
|
|
true
|
|
}
|
|
}
|
|
|
|
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
|
|
|
|
impl acp_thread::AgentConnection for NativeAgentConnection {
|
|
fn agent_id(&self) -> AgentId {
|
|
ZED_AGENT_ID.clone()
|
|
}
|
|
|
|
fn telemetry_id(&self) -> SharedString {
|
|
"zed".into()
|
|
}
|
|
|
|
fn new_session(
|
|
self: Rc<Self>,
|
|
project: Entity<Project>,
|
|
work_dirs: PathList,
|
|
cx: &mut App,
|
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
|
log::debug!("Creating new thread for project at: {work_dirs:?}");
|
|
Task::ready(Ok(self
|
|
.0
|
|
.update(cx, |agent, cx| agent.new_session(project, cx))))
|
|
}
|
|
|
|
fn supports_load_session(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn load_session(
|
|
self: Rc<Self>,
|
|
session_id: acp::SessionId,
|
|
project: Entity<Project>,
|
|
_work_dirs: PathList,
|
|
_title: Option<SharedString>,
|
|
cx: &mut App,
|
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
|
self.0
|
|
.update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
|
|
}
|
|
|
|
fn supports_close_session(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn close_session(
|
|
self: Rc<Self>,
|
|
session_id: &acp::SessionId,
|
|
cx: &mut App,
|
|
) -> Task<Result<()>> {
|
|
self.0
|
|
.update(cx, |agent, cx| agent.close_session(session_id, cx))
|
|
}
|
|
|
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
|
&[] // No auth for in-process
|
|
}
|
|
|
|
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
|
Task::ready(Ok(()))
|
|
}
|
|
|
|
fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
|
|
Some(Rc::new(NativeAgentModelSelector {
|
|
session_id: session_id.clone(),
|
|
connection: self.clone(),
|
|
}) as Rc<dyn AgentModelSelector>)
|
|
}
|
|
|
|
fn prompt(
|
|
&self,
|
|
id: acp_thread::UserMessageId,
|
|
params: acp::PromptRequest,
|
|
cx: &mut App,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
let session_id = params.session_id.clone();
|
|
log::info!("Received prompt request for session: {}", session_id);
|
|
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
|
|
|
let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
|
|
log::error!("Session not found in prompt: {}", session_id);
|
|
if self.0.read(cx).sessions.contains_key(&session_id) {
|
|
log::error!(
|
|
"Session found in sessions map, but not in project state: {}",
|
|
session_id
|
|
);
|
|
}
|
|
return Task::ready(Err(anyhow::anyhow!("Session not found")));
|
|
};
|
|
|
|
if let Some(parsed_command) = Command::parse(¶ms.prompt) {
|
|
let registry = project_state.context_server_registry.read(cx);
|
|
|
|
let explicit_server_id = parsed_command
|
|
.explicit_server_id
|
|
.map(|server_id| ContextServerId(server_id.into()));
|
|
|
|
if let Some(prompt) =
|
|
registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
|
|
{
|
|
let arguments = if !parsed_command.arg_value.is_empty()
|
|
&& let Some(arg_name) = prompt
|
|
.prompt
|
|
.arguments
|
|
.as_ref()
|
|
.and_then(|args| args.first())
|
|
.map(|arg| arg.name.clone())
|
|
{
|
|
HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
|
|
} else {
|
|
Default::default()
|
|
};
|
|
|
|
let prompt_name = prompt.prompt.name.clone();
|
|
let server_id = prompt.server_id.clone();
|
|
|
|
return self.0.update(cx, |agent, cx| {
|
|
agent.send_mcp_prompt(
|
|
id,
|
|
session_id.clone(),
|
|
prompt_name,
|
|
server_id,
|
|
arguments,
|
|
params.prompt,
|
|
cx,
|
|
)
|
|
});
|
|
}
|
|
};
|
|
|
|
let path_style = project_state.project.read(cx).path_style(cx);
|
|
|
|
self.run_turn(session_id, cx, move |thread, cx| {
|
|
let content: Vec<UserMessageContent> = params
|
|
.prompt
|
|
.into_iter()
|
|
.map(|block| UserMessageContent::from_content_block(block, path_style))
|
|
.collect::<Vec<_>>();
|
|
log::debug!("Converted prompt to message: {} chars", content.len());
|
|
log::debug!("Message id: {:?}", id);
|
|
log::debug!("Message content: {:?}", content);
|
|
|
|
thread.update(cx, |thread, cx| thread.send(id, content, cx))
|
|
})
|
|
}
|
|
|
|
fn retry(
|
|
&self,
|
|
session_id: &acp::SessionId,
|
|
_cx: &App,
|
|
) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
|
|
Some(Rc::new(NativeAgentSessionRetry {
|
|
connection: self.clone(),
|
|
session_id: session_id.clone(),
|
|
}) as _)
|
|
}
|
|
|
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
|
log::info!("Cancelling on session: {}", session_id);
|
|
self.0.update(cx, |agent, cx| {
|
|
if let Some(session) = agent.sessions.get(session_id) {
|
|
session
|
|
.thread
|
|
.update(cx, |thread, cx| thread.cancel(cx))
|
|
.detach();
|
|
}
|
|
});
|
|
}
|
|
|
|
fn truncate(
|
|
&self,
|
|
session_id: &acp::SessionId,
|
|
cx: &App,
|
|
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
|
|
self.0.read_with(cx, |agent, _cx| {
|
|
agent.sessions.get(session_id).map(|session| {
|
|
Rc::new(NativeAgentSessionTruncate {
|
|
thread: session.thread.clone(),
|
|
acp_thread: session.acp_thread.downgrade(),
|
|
}) as _
|
|
})
|
|
})
|
|
}
|
|
|
|
fn set_title(
|
|
&self,
|
|
session_id: &acp::SessionId,
|
|
cx: &App,
|
|
) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
|
|
self.0.read_with(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get(session_id)
|
|
.filter(|s| !s.thread.read(cx).is_subagent())
|
|
.map(|session| {
|
|
Rc::new(NativeAgentSessionSetTitle {
|
|
thread: session.thread.clone(),
|
|
}) as _
|
|
})
|
|
})
|
|
}
|
|
|
|
fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
|
|
let thread_store = self.0.read(cx).thread_store.clone();
|
|
Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
|
|
}
|
|
|
|
fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
|
|
Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
|
|
}
|
|
|
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
|
self
|
|
}
|
|
}
|
|
|
|
impl acp_thread::AgentTelemetry for NativeAgentConnection {
|
|
fn thread_data(
|
|
&self,
|
|
session_id: &acp::SessionId,
|
|
cx: &mut App,
|
|
) -> Task<Result<serde_json::Value>> {
|
|
let Some(session) = self.0.read(cx).sessions.get(session_id) else {
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
|
|
let task = session.thread.read(cx).to_db(cx);
|
|
cx.background_spawn(async move {
|
|
serde_json::to_value(task.await).context("Failed to serialize thread")
|
|
})
|
|
}
|
|
}
|
|
|
|
pub struct NativeAgentSessionList {
|
|
thread_store: Entity<ThreadStore>,
|
|
updates_tx: async_channel::Sender<acp_thread::SessionListUpdate>,
|
|
updates_rx: async_channel::Receiver<acp_thread::SessionListUpdate>,
|
|
_subscription: Subscription,
|
|
}
|
|
|
|
impl NativeAgentSessionList {
|
|
fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
|
|
let (tx, rx) = async_channel::unbounded();
|
|
let this_tx = tx.clone();
|
|
let subscription = cx.observe(&thread_store, move |_, _| {
|
|
this_tx
|
|
.try_send(acp_thread::SessionListUpdate::Refresh)
|
|
.ok();
|
|
});
|
|
Self {
|
|
thread_store,
|
|
updates_tx: tx,
|
|
updates_rx: rx,
|
|
_subscription: subscription,
|
|
}
|
|
}
|
|
|
|
pub fn thread_store(&self) -> &Entity<ThreadStore> {
|
|
&self.thread_store
|
|
}
|
|
}
|
|
|
|
impl AgentSessionList for NativeAgentSessionList {
|
|
fn list_sessions(
|
|
&self,
|
|
_request: AgentSessionListRequest,
|
|
cx: &mut App,
|
|
) -> Task<Result<AgentSessionListResponse>> {
|
|
let sessions = self
|
|
.thread_store
|
|
.read(cx)
|
|
.entries()
|
|
.map(|entry| AgentSessionInfo::from(&entry))
|
|
.collect();
|
|
Task::ready(Ok(AgentSessionListResponse::new(sessions)))
|
|
}
|
|
|
|
fn supports_delete(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
|
|
self.thread_store
|
|
.update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
|
|
}
|
|
|
|
fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
|
|
self.thread_store
|
|
.update(cx, |store, cx| store.delete_threads(cx))
|
|
}
|
|
|
|
fn watch(
|
|
&self,
|
|
_cx: &mut App,
|
|
) -> Option<async_channel::Receiver<acp_thread::SessionListUpdate>> {
|
|
Some(self.updates_rx.clone())
|
|
}
|
|
|
|
fn notify_refresh(&self) {
|
|
self.updates_tx
|
|
.try_send(acp_thread::SessionListUpdate::Refresh)
|
|
.ok();
|
|
}
|
|
|
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
|
self
|
|
}
|
|
}
|
|
|
|
struct NativeAgentSessionTruncate {
|
|
thread: Entity<Thread>,
|
|
acp_thread: WeakEntity<AcpThread>,
|
|
}
|
|
|
|
impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
|
|
fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
|
match self.thread.update(cx, |thread, cx| {
|
|
thread.truncate(message_id.clone(), cx)?;
|
|
Ok(thread.latest_token_usage())
|
|
}) {
|
|
Ok(usage) => {
|
|
self.acp_thread
|
|
.update(cx, |thread, cx| {
|
|
thread.update_token_usage(usage, cx);
|
|
})
|
|
.ok();
|
|
Task::ready(Ok(()))
|
|
}
|
|
Err(error) => Task::ready(Err(error)),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct NativeAgentSessionRetry {
|
|
connection: NativeAgentConnection,
|
|
session_id: acp::SessionId,
|
|
}
|
|
|
|
impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
|
|
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
|
|
self.connection
|
|
.run_turn(self.session_id.clone(), cx, |thread, cx| {
|
|
thread.update(cx, |thread, cx| thread.resume(cx))
|
|
})
|
|
}
|
|
}
|
|
|
|
struct NativeAgentSessionSetTitle {
|
|
thread: Entity<Thread>,
|
|
}
|
|
|
|
impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
|
|
fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
|
|
self.thread
|
|
.update(cx, |thread, cx| thread.set_title(title, cx));
|
|
Task::ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
pub struct NativeThreadEnvironment {
|
|
agent: WeakEntity<NativeAgent>,
|
|
thread: WeakEntity<Thread>,
|
|
acp_thread: WeakEntity<AcpThread>,
|
|
}
|
|
|
|
impl NativeThreadEnvironment {
|
|
pub(crate) fn create_subagent_thread(
|
|
&self,
|
|
label: String,
|
|
cx: &mut App,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
let Some(parent_thread_entity) = self.thread.upgrade() else {
|
|
anyhow::bail!("Parent thread no longer exists".to_string());
|
|
};
|
|
let parent_thread = parent_thread_entity.read(cx);
|
|
let current_depth = parent_thread.depth();
|
|
let parent_session_id = parent_thread.id().clone();
|
|
|
|
if current_depth >= MAX_SUBAGENT_DEPTH {
|
|
return Err(anyhow!(
|
|
"Maximum subagent depth ({}) reached",
|
|
MAX_SUBAGENT_DEPTH
|
|
));
|
|
}
|
|
|
|
let subagent_thread: Entity<Thread> = cx.new(|cx| {
|
|
let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
|
|
thread.set_title(label.into(), cx);
|
|
thread
|
|
});
|
|
|
|
let session_id = subagent_thread.read(cx).id().clone();
|
|
|
|
let acp_thread = self
|
|
.agent
|
|
.update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
|
|
let project_id = agent
|
|
.sessions
|
|
.get(&parent_session_id)
|
|
.map(|s| s.project_id)
|
|
.context("parent session not found")?;
|
|
Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
|
|
})??;
|
|
|
|
let depth = current_depth + 1;
|
|
|
|
telemetry::event!(
|
|
"Subagent Started",
|
|
session = parent_thread_entity.read(cx).id().to_string(),
|
|
subagent_session = session_id.to_string(),
|
|
depth,
|
|
is_resumed = false,
|
|
);
|
|
|
|
self.prompt_subagent(session_id, subagent_thread, acp_thread)
|
|
}
|
|
|
|
pub(crate) fn resume_subagent_thread(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
cx: &mut App,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
|
|
let session = agent
|
|
.sessions
|
|
.get(&session_id)
|
|
.ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
|
|
anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
|
|
})??;
|
|
|
|
let depth = subagent_thread.read(cx).depth();
|
|
|
|
if let Some(parent_thread_entity) = self.thread.upgrade() {
|
|
telemetry::event!(
|
|
"Subagent Started",
|
|
session = parent_thread_entity.read(cx).id().to_string(),
|
|
subagent_session = session_id.to_string(),
|
|
depth,
|
|
is_resumed = true,
|
|
);
|
|
}
|
|
|
|
self.prompt_subagent(session_id, subagent_thread, acp_thread)
|
|
}
|
|
|
|
fn prompt_subagent(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
subagent_thread: Entity<Thread>,
|
|
acp_thread: Entity<acp_thread::AcpThread>,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
let Some(parent_thread_entity) = self.thread.upgrade() else {
|
|
anyhow::bail!("Parent thread no longer exists".to_string());
|
|
};
|
|
Ok(Rc::new(NativeSubagentHandle::new(
|
|
session_id,
|
|
subagent_thread,
|
|
acp_thread,
|
|
parent_thread_entity,
|
|
)) as _)
|
|
}
|
|
}
|
|
|
|
impl ThreadEnvironment for NativeThreadEnvironment {
|
|
fn create_terminal(
|
|
&self,
|
|
command: String,
|
|
cwd: Option<PathBuf>,
|
|
output_byte_limit: Option<u64>,
|
|
cx: &mut AsyncApp,
|
|
) -> Task<Result<Rc<dyn TerminalHandle>>> {
|
|
let task = self.acp_thread.update(cx, |thread, cx| {
|
|
thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
|
|
});
|
|
|
|
let acp_thread = self.acp_thread.clone();
|
|
cx.spawn(async move |cx| {
|
|
let terminal = task?.await?;
|
|
|
|
let (drop_tx, drop_rx) = oneshot::channel();
|
|
let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
|
|
|
|
cx.spawn(async move |cx| {
|
|
drop_rx.await.ok();
|
|
acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
|
|
})
|
|
.detach();
|
|
|
|
let handle = AcpTerminalHandle {
|
|
terminal,
|
|
_drop_tx: Some(drop_tx),
|
|
};
|
|
|
|
Ok(Rc::new(handle) as _)
|
|
})
|
|
}
|
|
|
|
fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
|
|
self.create_subagent_thread(label, cx)
|
|
}
|
|
|
|
fn resume_subagent(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
cx: &mut App,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
self.resume_subagent_thread(session_id, cx)
|
|
}
|
|
}
|
|
|
|
/// Outcome of one prompt sent to a subagent.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn finished normally.
    Completed,
    /// The turn was cancelled.
    Cancelled,
    /// Token usage moved from normal into the warning range during the turn.
    ContextWindowWarning,
    /// The turn failed; the string is the message returned to the caller.
    Error(String),
}
|
|
|
|
pub struct NativeSubagentHandle {
|
|
session_id: acp::SessionId,
|
|
parent_thread: WeakEntity<Thread>,
|
|
subagent_thread: Entity<Thread>,
|
|
acp_thread: Entity<acp_thread::AcpThread>,
|
|
}
|
|
|
|
impl NativeSubagentHandle {
|
|
fn new(
|
|
session_id: acp::SessionId,
|
|
subagent_thread: Entity<Thread>,
|
|
acp_thread: Entity<acp_thread::AcpThread>,
|
|
parent_thread_entity: Entity<Thread>,
|
|
) -> Self {
|
|
NativeSubagentHandle {
|
|
session_id,
|
|
subagent_thread,
|
|
parent_thread: parent_thread_entity.downgrade(),
|
|
acp_thread,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl SubagentHandle for NativeSubagentHandle {
|
|
fn id(&self) -> acp::SessionId {
|
|
self.session_id.clone()
|
|
}
|
|
|
|
fn num_entries(&self, cx: &App) -> usize {
|
|
self.acp_thread.read(cx).entries().len()
|
|
}
|
|
|
|
fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
|
|
let thread = self.subagent_thread.clone();
|
|
let acp_thread = self.acp_thread.clone();
|
|
let subagent_session_id = self.session_id.clone();
|
|
let parent_thread = self.parent_thread.clone();
|
|
|
|
cx.spawn(async move |cx| {
|
|
let (task, _subscription) = cx.update(|cx| {
|
|
let ratio_before_prompt = thread
|
|
.read(cx)
|
|
.latest_token_usage()
|
|
.map(|usage| usage.ratio());
|
|
|
|
parent_thread
|
|
.update(cx, |parent_thread, _cx| {
|
|
parent_thread.register_running_subagent(thread.downgrade())
|
|
})
|
|
.ok();
|
|
|
|
let task = acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.send(vec![message.into()], cx)
|
|
});
|
|
|
|
let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
|
|
let mut token_limit_tx = Some(token_limit_tx);
|
|
|
|
let subscription = cx.subscribe(
|
|
&thread,
|
|
move |_thread, event: &TokenUsageUpdated, _cx| {
|
|
if let Some(usage) = &event.0 {
|
|
let old_ratio = ratio_before_prompt
|
|
.clone()
|
|
.unwrap_or(TokenUsageRatio::Normal);
|
|
let new_ratio = usage.ratio();
|
|
if old_ratio == TokenUsageRatio::Normal
|
|
&& new_ratio == TokenUsageRatio::Warning
|
|
{
|
|
if let Some(tx) = token_limit_tx.take() {
|
|
tx.send(()).ok();
|
|
}
|
|
}
|
|
}
|
|
},
|
|
);
|
|
|
|
let wait_for_prompt = cx
|
|
.background_spawn(async move {
|
|
futures::select! {
|
|
response = task.fuse() => match response {
|
|
Ok(Some(response)) => {
|
|
match response.stop_reason {
|
|
acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
|
|
acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
|
|
acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
|
|
acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
|
|
acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
|
|
}
|
|
}
|
|
Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
|
|
Err(error) => SubagentPromptResult::Error(error.to_string()),
|
|
},
|
|
_ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
|
|
}
|
|
});
|
|
|
|
(wait_for_prompt, subscription)
|
|
});
|
|
|
|
let result = match task.await {
|
|
SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
|
|
thread
|
|
.last_message()
|
|
.and_then(|message| {
|
|
let content = message.as_agent_message()?
|
|
.content
|
|
.iter()
|
|
.filter_map(|c| match c {
|
|
AgentMessageContent::Text(text) => Some(text.as_str()),
|
|
_ => None,
|
|
})
|
|
.join("\n\n");
|
|
if content.is_empty() {
|
|
None
|
|
} else {
|
|
Some( content)
|
|
}
|
|
})
|
|
.context("No response from subagent")
|
|
}),
|
|
SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
|
|
SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
|
|
SubagentPromptResult::ContextWindowWarning => {
|
|
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
|
Err(anyhow!(
|
|
"The agent is nearing the end of its context window and has been \
|
|
stopped. You can prompt the thread again to have the agent wrap up \
|
|
or hand off its work."
|
|
))
|
|
}
|
|
};
|
|
|
|
parent_thread
|
|
.update(cx, |parent_thread, cx| {
|
|
parent_thread.unregister_running_subagent(&subagent_session_id, cx)
|
|
})
|
|
.ok();
|
|
|
|
result
|
|
})
|
|
}
|
|
}
|
|
|
|
pub struct AcpTerminalHandle {
|
|
terminal: Entity<acp_thread::Terminal>,
|
|
_drop_tx: Option<oneshot::Sender<()>>,
|
|
}
|
|
|
|
impl TerminalHandle for AcpTerminalHandle {
|
|
fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
|
|
Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
|
|
}
|
|
|
|
fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
|
|
Ok(self
|
|
.terminal
|
|
.read_with(cx, |term, _cx| term.wait_for_exit()))
|
|
}
|
|
|
|
fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
|
|
Ok(self
|
|
.terminal
|
|
.read_with(cx, |term, cx| term.current_output(cx)))
|
|
}
|
|
|
|
fn kill(&self, cx: &AsyncApp) -> Result<()> {
|
|
cx.update(|cx| {
|
|
self.terminal.update(cx, |terminal, cx| {
|
|
terminal.kill(cx);
|
|
});
|
|
});
|
|
Ok(())
|
|
}
|
|
|
|
fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
|
|
Ok(self
|
|
.terminal
|
|
.read_with(cx, |term, _cx| term.was_stopped_by_user()))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod internal_tests {
|
|
use std::path::Path;
|
|
|
|
use super::*;
|
|
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
|
|
use fs::FakeFs;
|
|
use gpui::TestAppContext;
|
|
use indoc::formatdoc;
|
|
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
|
use language_model::{
|
|
LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
|
|
};
|
|
use serde_json::json;
|
|
use settings::SettingsStore;
|
|
use util::{path, rel_path::rel_path};
|
|
|
|
#[gpui::test]
|
|
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent =
|
|
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
|
|
|
|
// Creating a session registers the project and triggers context building.
|
|
let connection = NativeAgentConnection(agent.clone());
|
|
let _acp_thread = cx
|
|
.update(|cx| {
|
|
Rc::new(connection).new_session(
|
|
project.clone(),
|
|
PathList::new(&[Path::new("/")]),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let thread = agent.read_with(cx, |agent, _cx| {
|
|
agent.sessions.values().next().unwrap().thread.clone()
|
|
});
|
|
|
|
agent.read_with(cx, |agent, cx| {
|
|
let project_id = project.entity_id();
|
|
let state = agent.projects.get(&project_id).unwrap();
|
|
assert_eq!(state.project_context.read(cx).worktrees, vec![]);
|
|
assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
|
|
});
|
|
|
|
let worktree = project
|
|
.update(cx, |project, cx| project.create_worktree("/a", true, cx))
|
|
.await
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
agent.read_with(cx, |agent, cx| {
|
|
let project_id = project.entity_id();
|
|
let state = agent.projects.get(&project_id).unwrap();
|
|
let expected_worktrees = vec![WorktreeContext {
|
|
root_name: "a".into(),
|
|
abs_path: Path::new("/a").into(),
|
|
rules_file: None,
|
|
}];
|
|
assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
|
|
assert_eq!(
|
|
thread.read(cx).project_context().read(cx).worktrees,
|
|
expected_worktrees
|
|
);
|
|
});
|
|
|
|
// Creating `/a/.rules` updates the project context.
|
|
fs.insert_file("/a/.rules", Vec::new()).await;
|
|
cx.run_until_parked();
|
|
agent.read_with(cx, |agent, cx| {
|
|
let project_id = project.entity_id();
|
|
let state = agent.projects.get(&project_id).unwrap();
|
|
let rules_entry = worktree
|
|
.read(cx)
|
|
.entry_for_path(rel_path(".rules"))
|
|
.unwrap();
|
|
let expected_worktrees = vec![WorktreeContext {
|
|
root_name: "a".into(),
|
|
abs_path: Path::new("/a").into(),
|
|
rules_file: Some(RulesFileContext {
|
|
path_in_worktree: rel_path(".rules").into(),
|
|
text: "".into(),
|
|
project_entry_id: rules_entry.id.to_usize(),
|
|
}),
|
|
}];
|
|
assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
|
|
assert_eq!(
|
|
thread.read(cx).project_context().read(cx).worktrees,
|
|
expected_worktrees
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_listing_models(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let connection =
|
|
NativeAgentConnection(cx.update(|cx| {
|
|
NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
|
|
}));
|
|
|
|
// Create a thread/session
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
Rc::new(connection.clone()).new_session(
|
|
project.clone(),
|
|
PathList::new(&[Path::new("/a")]),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
|
|
|
let models = cx
|
|
.update(|cx| {
|
|
connection
|
|
.model_selector(&session_id)
|
|
.unwrap()
|
|
.list_models(cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
let acp_thread::AgentModelList::Grouped(models) = models else {
|
|
panic!("Unexpected model group");
|
|
};
|
|
assert_eq!(
|
|
models,
|
|
IndexMap::from_iter([(
|
|
AgentModelGroupName("Fake".into()),
|
|
vec![AgentModelInfo {
|
|
id: acp::ModelId::new("fake/fake"),
|
|
name: "Fake".into(),
|
|
description: None,
|
|
icon: Some(acp_thread::AgentModelIcon::Named(
|
|
ui::IconName::ZedAssistant
|
|
)),
|
|
is_latest: false,
|
|
cost: None,
|
|
}]
|
|
)])
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.create_dir(paths::settings_file().parent().unwrap())
|
|
.await
|
|
.unwrap();
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"default_model": {
|
|
"provider": "foo",
|
|
"model": "bar"
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
|
|
// Create the agent and connection
|
|
let agent =
|
|
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
|
|
let connection = NativeAgentConnection(agent.clone());
|
|
|
|
// Create a thread/session
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
Rc::new(connection.clone()).new_session(
|
|
project.clone(),
|
|
PathList::new(&[Path::new("/a")]),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
|
|
|
// Select a model
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
let model_id = acp::ModelId::new("fake/fake");
|
|
cx.update(|cx| selector.select_model(model_id.clone(), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// Verify the thread has the selected model
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.model().unwrap().id().0, "fake");
|
|
});
|
|
});
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Verify settings file was updated
|
|
let settings_content = fs.load(paths::settings_file()).await.unwrap();
|
|
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
|
|
|
|
// Check that the agent settings contain the selected model
|
|
assert_eq!(
|
|
settings_json["agent"]["default_model"]["model"],
|
|
json!("fake")
|
|
);
|
|
assert_eq!(
|
|
settings_json["agent"]["default_model"]["provider"],
|
|
json!("fake")
|
|
);
|
|
|
|
// Register a thinking model and select it.
|
|
cx.update(|cx| {
|
|
let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"fake-thinking",
|
|
"Fake Thinking",
|
|
true,
|
|
));
|
|
let thinking_provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![thinking_model]),
|
|
);
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(thinking_provider, cx);
|
|
});
|
|
});
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
|
|
.await
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Verify enable_thinking was written to settings as true.
|
|
let settings_content = fs.load(paths::settings_file()).await.unwrap();
|
|
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
|
|
assert_eq!(
|
|
settings_json["agent"]["default_model"]["enable_thinking"],
|
|
json!(true),
|
|
"selecting a thinking model should persist enable_thinking: true to settings"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.create_dir(paths::settings_file().parent().unwrap())
|
|
.await
|
|
.unwrap();
|
|
fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent =
|
|
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
|
|
let connection = NativeAgentConnection(agent.clone());
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
Rc::new(connection.clone()).new_session(
|
|
project.clone(),
|
|
PathList::new(&[Path::new("/a")]),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
|
|
|
// Register a second provider with a thinking model.
|
|
cx.update(|cx| {
|
|
let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"fake-thinking",
|
|
"Fake Thinking",
|
|
true,
|
|
));
|
|
let thinking_provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![thinking_model]),
|
|
);
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(thinking_provider, cx);
|
|
});
|
|
});
|
|
// Refresh the agent's model list so it picks up the new provider.
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
// Thread starts with thinking_enabled = false (the default).
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert!(!thread.thinking_enabled(), "thinking defaults to false");
|
|
});
|
|
});
|
|
|
|
// Select the thinking model via select_model.
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// select_model should have enabled thinking based on the model's supports_thinking().
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.thinking_enabled(),
|
|
"select_model should enable thinking when model supports it"
|
|
);
|
|
});
|
|
});
|
|
|
|
// Switch back to the non-thinking model.
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// select_model should have disabled thinking.
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
!thread.thinking_enabled(),
|
|
"select_model should disable thinking when model does not support it"
|
|
);
|
|
});
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_summarization_model_survives_transient_registry_clearing(
|
|
cx: &mut TestAppContext,
|
|
) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent =
|
|
cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection.clone().new_session(
|
|
project.clone(),
|
|
PathList::new(&[Path::new("/a")]),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
|
|
thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.summarization_model().is_some(),
|
|
"session should have a summarization model from the test registry"
|
|
);
|
|
});
|
|
|
|
// Simulate what happens during a provider blip:
|
|
// update_active_language_model_from_settings calls set_default_model(None)
|
|
// when it can't resolve the model, clearing all fallbacks.
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.set_default_model(None, cx);
|
|
});
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.summarization_model().is_some(),
|
|
"summarization model should survive a transient default model clearing"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
// Register a thinking model.
|
|
let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"fake-thinking",
|
|
"Fake Thinking",
|
|
true,
|
|
));
|
|
let thinking_provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![thinking_model.clone()]),
|
|
);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(thinking_provider, cx);
|
|
});
|
|
});
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
// Create a thread and select the thinking model.
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection.clone().new_session(
|
|
project.clone(),
|
|
PathList::new(&[Path::new("/a")]),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// Verify thinking is enabled after selecting the thinking model.
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.thinking_enabled(),
|
|
"thinking should be enabled after selecting thinking model"
|
|
);
|
|
});
|
|
|
|
// Send a message so the thread gets persisted.
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
thinking_model.send_last_completion_stream_text_chunk("Response.");
|
|
thinking_model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Close the session so it can be reloaded from disk.
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(thread);
|
|
drop(acp_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(agent.sessions.is_empty());
|
|
});
|
|
|
|
// Reload the thread and verify thinking_enabled is still true.
|
|
let reloaded_acp_thread = agent
|
|
.update(cx, |agent, cx| {
|
|
agent.open_thread(session_id.clone(), project.clone(), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let reloaded_thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
reloaded_thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.thinking_enabled(),
|
|
"thinking_enabled should be preserved when reloading a thread with a thinking model"
|
|
);
|
|
});
|
|
|
|
drop(reloaded_acp_thread);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
// Register a model where id() != name(), like real Anthropic models
|
|
// (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
|
|
let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"custom-model-id",
|
|
"Custom Model Display Name",
|
|
false,
|
|
));
|
|
let provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![model.clone()]),
|
|
);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(provider, cx);
|
|
});
|
|
});
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
// Create a thread and select the model.
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection.clone().new_session(
|
|
project.clone(),
|
|
PathList::new(&[Path::new("/a")]),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.model().unwrap().id().0.as_ref(),
|
|
"custom-model-id",
|
|
"model should be set before persisting"
|
|
);
|
|
});
|
|
|
|
// Send a message so the thread gets persisted.
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
model.send_last_completion_stream_text_chunk("Response.");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Close the session so it can be reloaded from disk.
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(thread);
|
|
drop(acp_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(agent.sessions.is_empty());
|
|
});
|
|
|
|
// Reload the thread and verify the model was preserved.
|
|
let reloaded_acp_thread = agent
|
|
.update(cx, |agent, cx| {
|
|
agent.open_thread(session_id.clone(), project.clone(), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let reloaded_thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
reloaded_thread.read_with(cx, |thread, _| {
|
|
let reloaded_model = thread
|
|
.model()
|
|
.expect("model should be present after reload");
|
|
assert_eq!(
|
|
reloaded_model.id().0.as_ref(),
|
|
"custom-model-id",
|
|
"reloaded thread should have the same model, not fall back to the default"
|
|
);
|
|
});
|
|
|
|
drop(reloaded_acp_thread);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_save_load_thread(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
|
|
// Ensure empty threads are not saved, even if they get mutated.
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
let summary_model = Arc::new(FakeLanguageModel::default());
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
thread.set_summarization_model(Some(summary_model.clone()), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
assert_eq!(thread_entries(&thread_store, cx), vec![]);
|
|
|
|
let send = acp_thread.update(cx, |thread, cx| {
|
|
thread.send(
|
|
vec![
|
|
"What does ".into(),
|
|
acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
|
|
"b.md",
|
|
MentionUri::File {
|
|
abs_path: path!("/a/b.md").into(),
|
|
}
|
|
.to_uri()
|
|
.to_string(),
|
|
)),
|
|
" mean?".into(),
|
|
],
|
|
cx,
|
|
)
|
|
});
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
model.send_last_completion_stream_text_chunk("Lorem.");
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
language_model::TokenUsage {
|
|
input_tokens: 150,
|
|
output_tokens: 75,
|
|
..Default::default()
|
|
},
|
|
));
|
|
model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
summary_model
|
|
.send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
|
|
summary_model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
let uri = MentionUri::File {
|
|
abs_path: path!("/a/b.md").into(),
|
|
}
|
|
.to_uri();
|
|
acp_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
formatdoc! {"
|
|
## User
|
|
|
|
What does [@b.md]({uri}) mean?
|
|
|
|
## Assistant
|
|
|
|
Lorem.
|
|
|
|
"}
|
|
)
|
|
});
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Set a draft prompt with rich content blocks and scroll position
|
|
// AFTER run_until_parked, so the only save that captures these
|
|
// changes is the one performed by close_session itself.
|
|
let draft_blocks = vec![
|
|
acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
|
|
acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
|
|
acp::ContentBlock::Text(acp::TextContent::new(" please")),
|
|
];
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
|
|
});
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.set_ui_scroll_position(Some(gpui::ListOffset {
|
|
item_ix: 5,
|
|
offset_in_item: gpui::px(12.5),
|
|
}));
|
|
});
|
|
|
|
// Close the session so it can be reloaded from disk.
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(thread);
|
|
drop(acp_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
|
|
});
|
|
|
|
// Ensure the thread can be reloaded from disk.
|
|
assert_eq!(
|
|
thread_entries(&thread_store, cx),
|
|
vec![(
|
|
session_id.clone(),
|
|
format!("Explaining {}", path!("/a/b.md"))
|
|
)]
|
|
);
|
|
let acp_thread = agent
|
|
.update(cx, |agent, cx| {
|
|
agent.open_thread(session_id.clone(), project.clone(), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
acp_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
formatdoc! {"
|
|
## User
|
|
|
|
What does [@b.md]({uri}) mean?
|
|
|
|
## Assistant
|
|
|
|
Lorem.
|
|
|
|
"}
|
|
)
|
|
});
|
|
|
|
// Ensure the draft prompt with rich content blocks survived the round-trip.
|
|
acp_thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
|
|
});
|
|
|
|
// Ensure token usage survived the round-trip.
|
|
acp_thread.read_with(cx, |thread, _| {
|
|
let usage = thread
|
|
.token_usage()
|
|
.expect("token usage should be restored after reload");
|
|
assert_eq!(usage.input_tokens, 150);
|
|
assert_eq!(usage.output_tokens, 75);
|
|
});
|
|
|
|
// Ensure scroll position survived the round-trip.
|
|
acp_thread.read_with(cx, |thread, _| {
|
|
let scroll = thread
|
|
.ui_scroll_position()
|
|
.expect("scroll position should be restored after reload");
|
|
assert_eq!(scroll.item_ix, 5);
|
|
assert_eq!(scroll.offset_in_item, gpui::px(12.5));
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"file.txt": "hello"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
|
|
// Send a message so the thread is non-empty (empty threads aren't saved).
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
model.send_last_completion_stream_text_chunk("world");
|
|
model.end_last_completion_stream();
|
|
send.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Set a draft prompt WITHOUT calling run_until_parked afterwards.
|
|
// This means no observe-triggered save has run for this change.
|
|
// The only way this data gets persisted is if close_session
|
|
// itself performs the save.
|
|
let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
|
|
"unsaved draft",
|
|
))];
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
|
|
});
|
|
|
|
// Close the session immediately — no run_until_parked in between.
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Reopen and verify the draft prompt was saved.
|
|
let reloaded = agent
|
|
.update(cx, |agent, cx| {
|
|
agent.open_thread(session_id.clone(), project.clone(), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
reloaded.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.draft_prompt(),
|
|
Some(draft_blocks.as_slice()),
|
|
"close_session must save the thread; draft prompt was lost"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"file.txt": "hello"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
let summary_model = Arc::new(FakeLanguageModel::default());
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
thread.set_summarization_model(Some(summary_model.clone()), cx);
|
|
});
|
|
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
model.send_last_completion_stream_text_chunk("world");
|
|
model.end_last_completion_stream();
|
|
send.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let summary = agent.update(cx, |agent, cx| {
|
|
agent.thread_summary(session_id.clone(), project.clone(), cx)
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
summary_model.send_last_completion_stream_text_chunk("summary");
|
|
summary_model.end_last_completion_stream();
|
|
|
|
assert_eq!(summary.await.unwrap(), "summary");
|
|
cx.run_until_parked();
|
|
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent
|
|
.sessions
|
|
.get(&session_id)
|
|
.expect("thread_summary should not close the active session");
|
|
assert_eq!(
|
|
session.ref_count, 1,
|
|
"thread_summary should release its temporary session reference"
|
|
);
|
|
});
|
|
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(
|
|
agent.sessions.is_empty(),
|
|
"closing the active session after thread_summary should unload it"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"file.txt": "hello"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
|
|
let model = cx.update(|cx| {
|
|
LanguageModelRegistry::read_global(cx)
|
|
.default_model()
|
|
.map(|default_model| default_model.model)
|
|
.expect("default test model should be available")
|
|
});
|
|
let fake_model = model.as_fake();
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("world");
|
|
fake_model.end_last_completion_stream();
|
|
send.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(thread);
|
|
drop(acp_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(agent.sessions.is_empty());
|
|
});
|
|
|
|
let first_loaded_thread = cx.update(|cx| {
|
|
connection.clone().load_session(
|
|
session_id.clone(),
|
|
project.clone(),
|
|
PathList::new(&[Path::new("")]),
|
|
None,
|
|
cx,
|
|
)
|
|
});
|
|
let second_loaded_thread = cx.update(|cx| {
|
|
connection.clone().load_session(
|
|
session_id.clone(),
|
|
project.clone(),
|
|
PathList::new(&[Path::new("")]),
|
|
None,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let first_loaded_thread = first_loaded_thread.await.unwrap();
|
|
let second_loaded_thread = second_loaded_thread.await.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
assert_eq!(
|
|
first_loaded_thread.entity_id(),
|
|
second_loaded_thread.entity_id(),
|
|
"concurrent loads for the same session should share one AcpThread"
|
|
);
|
|
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(
|
|
agent.sessions.contains_key(&session_id),
|
|
"closing one loaded session should not drop shared session state"
|
|
);
|
|
});
|
|
|
|
let follow_up = second_loaded_thread.update(cx, |thread, cx| {
|
|
thread.send(vec!["still there?".into()], cx)
|
|
});
|
|
let follow_up = cx.foreground_executor().spawn(follow_up);
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("yes");
|
|
fake_model.end_last_completion_stream();
|
|
follow_up.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
second_loaded_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
formatdoc! {"
|
|
## User
|
|
|
|
hello
|
|
|
|
## Assistant
|
|
|
|
world
|
|
|
|
## User
|
|
|
|
still there?
|
|
|
|
## Assistant
|
|
|
|
yes
|
|
|
|
"}
|
|
);
|
|
});
|
|
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
drop(first_loaded_thread);
|
|
drop(second_loaded_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(agent.sessions.is_empty());
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
|
|
// Regression test: rapid title changes must not cause a propagation loop
|
|
// between Thread and AcpThread via handle_thread_title_updated.
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
|
|
let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
|
|
cx.update(|cx| {
|
|
let count = title_updated_count.clone();
|
|
cx.subscribe(
|
|
&thread,
|
|
move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
|
|
let new_count = {
|
|
let mut count = count.borrow_mut();
|
|
*count += 1;
|
|
*count
|
|
};
|
|
assert!(
|
|
new_count <= 2,
|
|
"TitleUpdated fired {new_count} times; \
|
|
title updates are looping"
|
|
);
|
|
},
|
|
)
|
|
.detach();
|
|
});
|
|
|
|
thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
|
|
thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));
|
|
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.title(), Some("second".into()));
|
|
});
|
|
acp_thread.read_with(cx, |acp_thread, _| {
|
|
assert_eq!(acp_thread.title(), Some("second".into()));
|
|
});
|
|
|
|
assert_eq!(*title_updated_count.borrow(), 2);
|
|
}
|
|
|
|
fn thread_entries(
|
|
thread_store: &Entity<ThreadStore>,
|
|
cx: &mut TestAppContext,
|
|
) -> Vec<(acp::SessionId, String)> {
|
|
thread_store.read_with(cx, |store, _| {
|
|
store
|
|
.entries()
|
|
.map(|entry| (entry.id.clone(), entry.title.to_string()))
|
|
.collect::<Vec<_>>()
|
|
})
|
|
}
|
|
|
|
fn init_test(cx: &mut TestAppContext) {
|
|
env_logger::try_init().ok();
|
|
cx.update(|cx| {
|
|
let settings_store = SettingsStore::test(cx);
|
|
cx.set_global(settings_store);
|
|
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
}
|
|
}
|
|
|
|
fn mcp_message_content_to_acp_content_block(
|
|
content: context_server::types::MessageContent,
|
|
) -> acp::ContentBlock {
|
|
match content {
|
|
context_server::types::MessageContent::Text {
|
|
text,
|
|
annotations: _,
|
|
} => text.into(),
|
|
context_server::types::MessageContent::Image {
|
|
data,
|
|
mime_type,
|
|
annotations: _,
|
|
} => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
|
|
context_server::types::MessageContent::Audio {
|
|
data,
|
|
mime_type,
|
|
annotations: _,
|
|
} => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
|
|
context_server::types::MessageContent::Resource {
|
|
resource,
|
|
annotations: _,
|
|
} => {
|
|
let mut link =
|
|
acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
|
|
if let Some(mime_type) = resource.mime_type {
|
|
link = link.mime_type(mime_type);
|
|
}
|
|
acp::ContentBlock::ResourceLink(link)
|
|
}
|
|
}
|
|
}
|