mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
2662 lines
93 KiB
Rust
2662 lines
93 KiB
Rust
mod db;
|
|
mod edit_agent;
|
|
mod legacy_thread;
|
|
mod native_agent_server;
|
|
pub mod outline;
|
|
mod pattern_extraction;
|
|
mod templates;
|
|
#[cfg(test)]
|
|
mod tests;
|
|
mod thread;
|
|
mod thread_store;
|
|
mod tool_permissions;
|
|
mod tools;
|
|
|
|
use context_server::ContextServerId;
|
|
pub use db::*;
|
|
use itertools::Itertools;
|
|
pub use native_agent_server::NativeAgentServer;
|
|
pub use pattern_extraction::*;
|
|
pub use shell_command_parser::extract_commands;
|
|
pub use templates::*;
|
|
pub use thread::*;
|
|
pub use thread_store::*;
|
|
pub use tool_permissions::*;
|
|
pub use tools::*;
|
|
|
|
use acp_thread::{
|
|
AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
|
|
AgentSessionListResponse, TokenUsageRatio, UserMessageId,
|
|
};
|
|
use agent_client_protocol as acp;
|
|
use anyhow::{Context as _, Result, anyhow};
|
|
use chrono::{DateTime, Utc};
|
|
use collections::{HashMap, HashSet, IndexMap};
|
|
use fs::Fs;
|
|
use futures::channel::{mpsc, oneshot};
|
|
use futures::future::Shared;
|
|
use futures::{FutureExt as _, StreamExt as _, future};
|
|
use gpui::{
|
|
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
|
};
|
|
use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
|
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
|
use prompt_store::{
|
|
ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
|
|
WorktreeContext,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use settings::{LanguageModelSelection, update_settings_file};
|
|
use std::any::Any;
|
|
use std::path::{Path, PathBuf};
|
|
use std::rc::Rc;
|
|
use std::sync::Arc;
|
|
use util::ResultExt;
|
|
use util::path_list::PathList;
|
|
use util::rel_path::RelPath;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
pub struct ProjectSnapshot {
|
|
pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
|
|
pub timestamp: DateTime<Utc>,
|
|
}
|
|
|
|
pub struct RulesLoadingError {
|
|
pub message: SharedString,
|
|
}
|
|
|
|
/// Holds both the internal Thread and the AcpThread for a session
|
|
struct Session {
|
|
/// The internal thread that processes messages
|
|
thread: Entity<Thread>,
|
|
/// The ACP thread that handles protocol communication
|
|
acp_thread: Entity<acp_thread::AcpThread>,
|
|
pending_save: Task<()>,
|
|
_subscriptions: Vec<Subscription>,
|
|
}
|
|
|
|
pub struct LanguageModels {
|
|
/// Access language model by ID
|
|
models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
|
|
/// Cached list for returning language model information
|
|
model_list: acp_thread::AgentModelList,
|
|
refresh_models_rx: watch::Receiver<()>,
|
|
refresh_models_tx: watch::Sender<()>,
|
|
_authenticate_all_providers_task: Task<()>,
|
|
}
|
|
|
|
impl LanguageModels {
|
|
fn new(cx: &mut App) -> Self {
|
|
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
|
|
|
|
let mut this = Self {
|
|
models: HashMap::default(),
|
|
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
|
|
refresh_models_rx,
|
|
refresh_models_tx,
|
|
_authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
|
|
};
|
|
this.refresh_list(cx);
|
|
this
|
|
}
|
|
|
|
fn refresh_list(&mut self, cx: &App) {
|
|
let providers = LanguageModelRegistry::global(cx)
|
|
.read(cx)
|
|
.visible_providers()
|
|
.into_iter()
|
|
.filter(|provider| provider.is_authenticated(cx))
|
|
.collect::<Vec<_>>();
|
|
|
|
let mut language_model_list = IndexMap::default();
|
|
let mut recommended_models = HashSet::default();
|
|
|
|
let mut recommended = Vec::new();
|
|
for provider in &providers {
|
|
for model in provider.recommended_models(cx) {
|
|
recommended_models.insert((model.provider_id(), model.id()));
|
|
recommended.push(Self::map_language_model_to_info(&model, provider));
|
|
}
|
|
}
|
|
if !recommended.is_empty() {
|
|
language_model_list.insert(
|
|
acp_thread::AgentModelGroupName("Recommended".into()),
|
|
recommended,
|
|
);
|
|
}
|
|
|
|
let mut models = HashMap::default();
|
|
for provider in providers {
|
|
let mut provider_models = Vec::new();
|
|
for model in provider.provided_models(cx) {
|
|
let model_info = Self::map_language_model_to_info(&model, &provider);
|
|
let model_id = model_info.id.clone();
|
|
provider_models.push(model_info);
|
|
models.insert(model_id, model);
|
|
}
|
|
if !provider_models.is_empty() {
|
|
language_model_list.insert(
|
|
acp_thread::AgentModelGroupName(provider.name().0.clone()),
|
|
provider_models,
|
|
);
|
|
}
|
|
}
|
|
|
|
self.models = models;
|
|
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
|
|
self.refresh_models_tx.send(()).ok();
|
|
}
|
|
|
|
/// Returns a fresh receiver that is signalled each time the model list is refreshed.
fn watch(&self) -> watch::Receiver<()> {
    let receiver = &self.refresh_models_rx;
    receiver.clone()
}
|
|
|
|
/// Looks up a cached model by its ACP model ID, returning `None` when unknown.
pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
    let model = self.models.get(model_id)?;
    Some(Arc::clone(model))
}
|
|
|
|
fn map_language_model_to_info(
|
|
model: &Arc<dyn LanguageModel>,
|
|
provider: &Arc<dyn LanguageModelProvider>,
|
|
) -> acp_thread::AgentModelInfo {
|
|
acp_thread::AgentModelInfo {
|
|
id: Self::model_id(model),
|
|
name: model.name().0,
|
|
description: None,
|
|
icon: Some(match provider.icon() {
|
|
IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
|
|
IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
|
|
}),
|
|
is_latest: model.is_latest(),
|
|
cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
|
|
}
|
|
}
|
|
|
|
/// Derives the ACP model ID for a model, formatted as `<provider id>/<model id>`.
fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
    let provider_id = model.provider_id().0;
    let model_id = model.id().0;
    acp::ModelId::new(format!("{provider_id}/{model_id}"))
}
|
|
|
|
fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
|
|
let authenticate_all_providers = LanguageModelRegistry::global(cx)
|
|
.read(cx)
|
|
.visible_providers()
|
|
.iter()
|
|
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
|
|
.collect::<Vec<_>>();
|
|
|
|
cx.background_spawn(async move {
|
|
for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
|
|
if let Err(err) = authenticate_task.await {
|
|
match err {
|
|
language_model::AuthenticateError::CredentialsNotFound => {
|
|
// Since we're authenticating these providers in the
|
|
// background for the purposes of populating the
|
|
// language selector, we don't care about providers
|
|
// where the credentials are not found.
|
|
}
|
|
language_model::AuthenticateError::ConnectionRefused => {
|
|
// Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
|
|
// LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
|
|
// TODO: Better manage LM Studio auth logic to avoid these noisy failures.
|
|
}
|
|
_ => {
|
|
// Some providers have noisy failure states that we
|
|
// don't want to spam the logs with every time the
|
|
// language model selector is initialized.
|
|
//
|
|
// Ideally these should have more clear failure modes
|
|
// that we know are safe to ignore here, like what we do
|
|
// with `CredentialsNotFound` above.
|
|
match provider_id.0.as_ref() {
|
|
"lmstudio" | "ollama" => {
|
|
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
|
|
//
|
|
// These fail noisily, so we don't log them.
|
|
}
|
|
"copilot_chat" => {
|
|
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
|
|
}
|
|
_ => {
|
|
log::error!(
|
|
"Failed to authenticate provider: {}: {err:#}",
|
|
provider_name.0
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
pub struct NativeAgent {
|
|
/// Session ID -> Session mapping
|
|
sessions: HashMap<acp::SessionId, Session>,
|
|
thread_store: Entity<ThreadStore>,
|
|
/// Shared project context for all threads
|
|
project_context: Entity<ProjectContext>,
|
|
project_context_needs_refresh: watch::Sender<()>,
|
|
_maintain_project_context: Task<Result<()>>,
|
|
context_server_registry: Entity<ContextServerRegistry>,
|
|
/// Shared templates for all threads
|
|
templates: Arc<Templates>,
|
|
/// Cached model information
|
|
models: LanguageModels,
|
|
project: Entity<Project>,
|
|
prompt_store: Option<Entity<PromptStore>>,
|
|
fs: Arc<dyn Fs>,
|
|
_subscriptions: Vec<Subscription>,
|
|
}
|
|
|
|
impl NativeAgent {
|
|
pub async fn new(
|
|
project: Entity<Project>,
|
|
thread_store: Entity<ThreadStore>,
|
|
templates: Arc<Templates>,
|
|
prompt_store: Option<Entity<PromptStore>>,
|
|
fs: Arc<dyn Fs>,
|
|
cx: &mut AsyncApp,
|
|
) -> Result<Entity<NativeAgent>> {
|
|
log::debug!("Creating new NativeAgent");
|
|
|
|
let project_context = cx
|
|
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
|
|
.await;
|
|
|
|
Ok(cx.new(|cx| {
|
|
let context_server_store = project.read(cx).context_server_store();
|
|
let context_server_registry =
|
|
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
|
|
|
let mut subscriptions = vec![
|
|
cx.subscribe(&project, Self::handle_project_event),
|
|
cx.subscribe(
|
|
&LanguageModelRegistry::global(cx),
|
|
Self::handle_models_updated_event,
|
|
),
|
|
cx.subscribe(
|
|
&context_server_store,
|
|
Self::handle_context_server_store_updated,
|
|
),
|
|
cx.subscribe(
|
|
&context_server_registry,
|
|
Self::handle_context_server_registry_event,
|
|
),
|
|
];
|
|
if let Some(prompt_store) = prompt_store.as_ref() {
|
|
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
|
|
}
|
|
|
|
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
|
|
watch::channel(());
|
|
Self {
|
|
sessions: HashMap::default(),
|
|
thread_store,
|
|
project_context: cx.new(|_| project_context),
|
|
project_context_needs_refresh: project_context_needs_refresh_tx,
|
|
_maintain_project_context: cx.spawn(async move |this, cx| {
|
|
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
|
}),
|
|
context_server_registry,
|
|
templates,
|
|
models: LanguageModels::new(cx),
|
|
project,
|
|
prompt_store,
|
|
fs,
|
|
_subscriptions: subscriptions,
|
|
}
|
|
}))
|
|
}
|
|
|
|
fn new_session(
|
|
&mut self,
|
|
project: Entity<Project>,
|
|
cx: &mut Context<Self>,
|
|
) -> Entity<AcpThread> {
|
|
// Create Thread
|
|
// Fetch default model from registry settings
|
|
let registry = LanguageModelRegistry::read_global(cx);
|
|
// Log available models for debugging
|
|
let available_count = registry.available_models(cx).count();
|
|
log::debug!("Total available models: {}", available_count);
|
|
|
|
let default_model = registry.default_model().and_then(|default_model| {
|
|
self.models
|
|
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
|
});
|
|
let thread = cx.new(|cx| {
|
|
Thread::new(
|
|
project.clone(),
|
|
self.project_context.clone(),
|
|
self.context_server_registry.clone(),
|
|
self.templates.clone(),
|
|
default_model,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
self.register_session(thread, cx)
|
|
}
|
|
|
|
fn register_session(
|
|
&mut self,
|
|
thread_handle: Entity<Thread>,
|
|
cx: &mut Context<Self>,
|
|
) -> Entity<AcpThread> {
|
|
let connection = Rc::new(NativeAgentConnection(cx.entity()));
|
|
|
|
let thread = thread_handle.read(cx);
|
|
let session_id = thread.id().clone();
|
|
let parent_session_id = thread.parent_thread_id();
|
|
let title = thread.title();
|
|
let project = thread.project.clone();
|
|
let action_log = thread.action_log.clone();
|
|
let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
|
|
let acp_thread = cx.new(|cx| {
|
|
acp_thread::AcpThread::new(
|
|
parent_session_id,
|
|
title,
|
|
connection,
|
|
project.clone(),
|
|
action_log.clone(),
|
|
session_id.clone(),
|
|
prompt_capabilities_rx,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let registry = LanguageModelRegistry::read_global(cx);
|
|
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
|
|
|
let weak = cx.weak_entity();
|
|
let weak_thread = thread_handle.downgrade();
|
|
thread_handle.update(cx, |thread, cx| {
|
|
thread.set_summarization_model(summarization_model, cx);
|
|
thread.add_default_tools(
|
|
Rc::new(NativeThreadEnvironment {
|
|
acp_thread: acp_thread.downgrade(),
|
|
thread: weak_thread,
|
|
agent: weak,
|
|
}) as _,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let subscriptions = vec![
|
|
cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
|
|
cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
|
|
cx.observe(&thread_handle, move |this, thread, cx| {
|
|
this.save_thread(thread, cx)
|
|
}),
|
|
];
|
|
|
|
self.sessions.insert(
|
|
session_id,
|
|
Session {
|
|
thread: thread_handle,
|
|
acp_thread: acp_thread.clone(),
|
|
_subscriptions: subscriptions,
|
|
pending_save: Task::ready(()),
|
|
},
|
|
);
|
|
|
|
self.update_available_commands(cx);
|
|
|
|
acp_thread
|
|
}
|
|
|
|
pub fn models(&self) -> &LanguageModels {
|
|
&self.models
|
|
}
|
|
|
|
async fn maintain_project_context(
|
|
this: WeakEntity<Self>,
|
|
mut needs_refresh: watch::Receiver<()>,
|
|
cx: &mut AsyncApp,
|
|
) -> Result<()> {
|
|
while needs_refresh.changed().await.is_ok() {
|
|
let project_context = this
|
|
.update(cx, |this, cx| {
|
|
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
|
|
})?
|
|
.await;
|
|
this.update(cx, |this, cx| {
|
|
this.project_context = cx.new(|_| project_context);
|
|
})?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn build_project_context(
|
|
project: &Entity<Project>,
|
|
prompt_store: Option<&Entity<PromptStore>>,
|
|
cx: &mut App,
|
|
) -> Task<ProjectContext> {
|
|
let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
|
|
let worktree_tasks = worktrees
|
|
.into_iter()
|
|
.map(|worktree| {
|
|
Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
|
|
prompt_store.read_with(cx, |prompt_store, cx| {
|
|
let prompts = prompt_store.default_prompt_metadata();
|
|
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
|
|
let contents = prompt_store.load(prompt_metadata.id, cx);
|
|
async move { (contents.await, prompt_metadata) }
|
|
});
|
|
cx.background_spawn(future::join_all(load_tasks))
|
|
})
|
|
} else {
|
|
Task::ready(vec![])
|
|
};
|
|
|
|
cx.spawn(async move |_cx| {
|
|
let (worktrees, default_user_rules) =
|
|
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
|
|
|
|
let worktrees = worktrees
|
|
.into_iter()
|
|
.map(|(worktree, _rules_error)| {
|
|
// TODO: show error message
|
|
// if let Some(rules_error) = rules_error {
|
|
// this.update(cx, |_, cx| cx.emit(rules_error)).ok();
|
|
// }
|
|
worktree
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let default_user_rules = default_user_rules
|
|
.into_iter()
|
|
.flat_map(|(contents, prompt_metadata)| match contents {
|
|
Ok(contents) => Some(UserRulesContext {
|
|
uuid: prompt_metadata.id.as_user()?,
|
|
title: prompt_metadata.title.map(|title| title.to_string()),
|
|
contents,
|
|
}),
|
|
Err(_err) => {
|
|
// TODO: show error message
|
|
// this.update(cx, |_, cx| {
|
|
// cx.emit(RulesLoadingError {
|
|
// message: format!("{err:?}").into(),
|
|
// });
|
|
// })
|
|
// .ok();
|
|
None
|
|
}
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
ProjectContext::new(worktrees, default_user_rules)
|
|
})
|
|
}
|
|
|
|
fn load_worktree_info_for_system_prompt(
|
|
worktree: Entity<Worktree>,
|
|
project: Entity<Project>,
|
|
cx: &mut App,
|
|
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
|
|
let tree = worktree.read(cx);
|
|
let root_name = tree.root_name_str().into();
|
|
let abs_path = tree.abs_path();
|
|
|
|
let mut context = WorktreeContext {
|
|
root_name,
|
|
abs_path,
|
|
rules_file: None,
|
|
};
|
|
|
|
let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
|
|
let Some(rules_task) = rules_task else {
|
|
return Task::ready((context, None));
|
|
};
|
|
|
|
cx.spawn(async move |_| {
|
|
let (rules_file, rules_file_error) = match rules_task.await {
|
|
Ok(rules_file) => (Some(rules_file), None),
|
|
Err(err) => (
|
|
None,
|
|
Some(RulesLoadingError {
|
|
message: format!("{err}").into(),
|
|
}),
|
|
),
|
|
};
|
|
context.rules_file = rules_file;
|
|
(context, rules_file_error)
|
|
})
|
|
}
|
|
|
|
fn load_worktree_rules_file(
|
|
worktree: Entity<Worktree>,
|
|
project: Entity<Project>,
|
|
cx: &mut App,
|
|
) -> Option<Task<Result<RulesFileContext>>> {
|
|
let worktree = worktree.read(cx);
|
|
let worktree_id = worktree.id();
|
|
let selected_rules_file = RULES_FILE_NAMES
|
|
.into_iter()
|
|
.filter_map(|name| {
|
|
worktree
|
|
.entry_for_path(RelPath::unix(name).unwrap())
|
|
.filter(|entry| entry.is_file())
|
|
.map(|entry| entry.path.clone())
|
|
})
|
|
.next();
|
|
|
|
// Note that Cline supports `.clinerules` being a directory, but that is not currently
|
|
// supported. This doesn't seem to occur often in GitHub repositories.
|
|
selected_rules_file.map(|path_in_worktree| {
|
|
let project_path = ProjectPath {
|
|
worktree_id,
|
|
path: path_in_worktree.clone(),
|
|
};
|
|
let buffer_task =
|
|
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
|
let rope_task = cx.spawn(async move |cx| {
|
|
let buffer = buffer_task.await?;
|
|
let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
|
|
let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
|
|
anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
|
|
})?;
|
|
anyhow::Ok((project_entry_id, rope))
|
|
});
|
|
// Build a string from the rope on a background thread.
|
|
cx.background_spawn(async move {
|
|
let (project_entry_id, rope) = rope_task.await?;
|
|
anyhow::Ok(RulesFileContext {
|
|
path_in_worktree,
|
|
text: rope.to_string().trim().to_string(),
|
|
project_entry_id: project_entry_id.to_usize(),
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
fn handle_thread_title_updated(
|
|
&mut self,
|
|
thread: Entity<Thread>,
|
|
_: &TitleUpdated,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
let session_id = thread.read(cx).id();
|
|
let Some(session) = self.sessions.get(session_id) else {
|
|
return;
|
|
};
|
|
let thread = thread.downgrade();
|
|
let acp_thread = session.acp_thread.downgrade();
|
|
cx.spawn(async move |_, cx| {
|
|
let title = thread.read_with(cx, |thread, _| thread.title())?;
|
|
let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
|
|
task.await
|
|
})
|
|
.detach_and_log_err(cx);
|
|
}
|
|
|
|
fn handle_thread_token_usage_updated(
|
|
&mut self,
|
|
thread: Entity<Thread>,
|
|
usage: &TokenUsageUpdated,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
let Some(session) = self.sessions.get(thread.read(cx).id()) else {
|
|
return;
|
|
};
|
|
session.acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.update_token_usage(usage.0.clone(), cx);
|
|
});
|
|
}
|
|
|
|
fn handle_project_event(
|
|
&mut self,
|
|
_project: Entity<Project>,
|
|
event: &project::Event,
|
|
_cx: &mut Context<Self>,
|
|
) {
|
|
match event {
|
|
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
|
|
self.project_context_needs_refresh.send(()).ok();
|
|
}
|
|
project::Event::WorktreeUpdatedEntries(_, items) => {
|
|
if items.iter().any(|(path, _, _)| {
|
|
RULES_FILE_NAMES
|
|
.iter()
|
|
.any(|name| path.as_ref() == RelPath::unix(name).unwrap())
|
|
}) {
|
|
self.project_context_needs_refresh.send(()).ok();
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
fn handle_prompts_updated_event(
|
|
&mut self,
|
|
_prompt_store: Entity<PromptStore>,
|
|
_event: &prompt_store::PromptsUpdatedEvent,
|
|
_cx: &mut Context<Self>,
|
|
) {
|
|
self.project_context_needs_refresh.send(()).ok();
|
|
}
|
|
|
|
fn handle_models_updated_event(
|
|
&mut self,
|
|
_registry: Entity<LanguageModelRegistry>,
|
|
_event: &language_model::Event,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
self.models.refresh_list(cx);
|
|
|
|
let registry = LanguageModelRegistry::read_global(cx);
|
|
let default_model = registry.default_model().map(|m| m.model);
|
|
let summarization_model = registry.thread_summary_model().map(|m| m.model);
|
|
|
|
for session in self.sessions.values_mut() {
|
|
session.thread.update(cx, |thread, cx| {
|
|
if thread.model().is_none()
|
|
&& let Some(model) = default_model.clone()
|
|
{
|
|
thread.set_model(model, cx);
|
|
cx.notify();
|
|
}
|
|
thread.set_summarization_model(summarization_model.clone(), cx);
|
|
});
|
|
}
|
|
}
|
|
|
|
fn handle_context_server_store_updated(
|
|
&mut self,
|
|
_store: Entity<project::context_server_store::ContextServerStore>,
|
|
_event: &project::context_server_store::ServerStatusChangedEvent,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
self.update_available_commands(cx);
|
|
}
|
|
|
|
fn handle_context_server_registry_event(
|
|
&mut self,
|
|
_registry: Entity<ContextServerRegistry>,
|
|
event: &ContextServerRegistryEvent,
|
|
cx: &mut Context<Self>,
|
|
) {
|
|
match event {
|
|
ContextServerRegistryEvent::ToolsChanged => {}
|
|
ContextServerRegistryEvent::PromptsChanged => {
|
|
self.update_available_commands(cx);
|
|
}
|
|
}
|
|
}
|
|
|
|
fn update_available_commands(&self, cx: &mut Context<Self>) {
|
|
let available_commands = self.build_available_commands(cx);
|
|
for session in self.sessions.values() {
|
|
session.acp_thread.update(cx, |thread, cx| {
|
|
thread
|
|
.handle_session_update(
|
|
acp::SessionUpdate::AvailableCommandsUpdate(
|
|
acp::AvailableCommandsUpdate::new(available_commands.clone()),
|
|
),
|
|
cx,
|
|
)
|
|
.log_err();
|
|
});
|
|
}
|
|
}
|
|
|
|
fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
|
|
let registry = self.context_server_registry.read(cx);
|
|
|
|
let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
|
|
for context_server_prompt in registry.prompts() {
|
|
*prompt_name_counts
|
|
.entry(context_server_prompt.prompt.name.as_str())
|
|
.or_insert(0) += 1;
|
|
}
|
|
|
|
registry
|
|
.prompts()
|
|
.flat_map(|context_server_prompt| {
|
|
let prompt = &context_server_prompt.prompt;
|
|
|
|
let should_prefix = prompt_name_counts
|
|
.get(prompt.name.as_str())
|
|
.copied()
|
|
.unwrap_or(0)
|
|
> 1;
|
|
|
|
let name = if should_prefix {
|
|
format!("{}.{}", context_server_prompt.server_id, prompt.name)
|
|
} else {
|
|
prompt.name.clone()
|
|
};
|
|
|
|
let mut command = acp::AvailableCommand::new(
|
|
name,
|
|
prompt.description.clone().unwrap_or_default(),
|
|
);
|
|
|
|
match prompt.arguments.as_deref() {
|
|
Some([arg]) => {
|
|
let hint = format!("<{}>", arg.name);
|
|
|
|
command = command.input(acp::AvailableCommandInput::Unstructured(
|
|
acp::UnstructuredCommandInput::new(hint),
|
|
));
|
|
}
|
|
Some([]) | None => {}
|
|
Some(_) => {
|
|
// skip >1 argument commands since we don't support them yet
|
|
return None;
|
|
}
|
|
}
|
|
|
|
Some(command)
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
pub fn load_thread(
|
|
&mut self,
|
|
id: acp::SessionId,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<Entity<Thread>>> {
|
|
let database_future = ThreadsDatabase::connect(cx);
|
|
cx.spawn(async move |this, cx| {
|
|
let database = database_future.await.map_err(|err| anyhow!(err))?;
|
|
let db_thread = database
|
|
.load_thread(id.clone())
|
|
.await?
|
|
.with_context(|| format!("no thread found with ID: {id:?}"))?;
|
|
|
|
this.update(cx, |this, cx| {
|
|
let summarization_model = LanguageModelRegistry::read_global(cx)
|
|
.thread_summary_model()
|
|
.map(|c| c.model);
|
|
|
|
cx.new(|cx| {
|
|
let mut thread = Thread::from_db(
|
|
id.clone(),
|
|
db_thread,
|
|
this.project.clone(),
|
|
this.project_context.clone(),
|
|
this.context_server_registry.clone(),
|
|
this.templates.clone(),
|
|
cx,
|
|
);
|
|
thread.set_summarization_model(summarization_model, cx);
|
|
thread
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
pub fn open_thread(
|
|
&mut self,
|
|
id: acp::SessionId,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<Entity<AcpThread>>> {
|
|
if let Some(session) = self.sessions.get(&id) {
|
|
return Task::ready(Ok(session.acp_thread.clone()));
|
|
}
|
|
|
|
let task = self.load_thread(id, cx);
|
|
cx.spawn(async move |this, cx| {
|
|
let thread = task.await?;
|
|
let acp_thread =
|
|
this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
|
|
let events = thread.update(cx, |thread, cx| thread.replay(cx));
|
|
cx.update(|cx| {
|
|
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
|
|
})
|
|
.await?;
|
|
Ok(acp_thread)
|
|
})
|
|
}
|
|
|
|
pub fn thread_summary(
|
|
&mut self,
|
|
id: acp::SessionId,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<SharedString>> {
|
|
let thread = self.open_thread(id.clone(), cx);
|
|
cx.spawn(async move |this, cx| {
|
|
let acp_thread = thread.await?;
|
|
let result = this
|
|
.update(cx, |this, cx| {
|
|
this.sessions
|
|
.get(&id)
|
|
.unwrap()
|
|
.thread
|
|
.update(cx, |thread, cx| thread.summary(cx))
|
|
})?
|
|
.await
|
|
.context("Failed to generate summary")?;
|
|
drop(acp_thread);
|
|
Ok(result)
|
|
})
|
|
}
|
|
|
|
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
|
|
if thread.read(cx).is_empty() {
|
|
return;
|
|
}
|
|
|
|
let database_future = ThreadsDatabase::connect(cx);
|
|
let (id, db_thread) =
|
|
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
|
|
let Some(session) = self.sessions.get_mut(&id) else {
|
|
return;
|
|
};
|
|
|
|
let folder_paths = PathList::new(
|
|
&self
|
|
.project
|
|
.read(cx)
|
|
.visible_worktrees(cx)
|
|
.map(|worktree| worktree.read(cx).abs_path().to_path_buf())
|
|
.collect::<Vec<_>>(),
|
|
);
|
|
|
|
let thread_store = self.thread_store.clone();
|
|
session.pending_save = cx.spawn(async move |_, cx| {
|
|
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
|
|
return;
|
|
};
|
|
let db_thread = db_thread.await;
|
|
database
|
|
.save_thread(id, db_thread, folder_paths)
|
|
.await
|
|
.log_err();
|
|
thread_store.update(cx, |store, cx| store.reload(cx));
|
|
});
|
|
}
|
|
|
|
fn send_mcp_prompt(
|
|
&self,
|
|
message_id: UserMessageId,
|
|
session_id: agent_client_protocol::SessionId,
|
|
prompt_name: String,
|
|
server_id: ContextServerId,
|
|
arguments: HashMap<String, String>,
|
|
original_content: Vec<acp::ContentBlock>,
|
|
cx: &mut Context<Self>,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
let server_store = self.context_server_registry.read(cx).server_store().clone();
|
|
let path_style = self.project.read(cx).path_style(cx);
|
|
|
|
cx.spawn(async move |this, cx| {
|
|
let prompt =
|
|
crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
|
|
|
|
let (acp_thread, thread) = this.update(cx, |this, _cx| {
|
|
let session = this
|
|
.sessions
|
|
.get(&session_id)
|
|
.context("Failed to get session")?;
|
|
anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
|
|
})??;
|
|
|
|
let mut last_is_user = true;
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.push_acp_user_block(
|
|
message_id,
|
|
original_content.into_iter().skip(1),
|
|
path_style,
|
|
cx,
|
|
);
|
|
});
|
|
|
|
for message in prompt.messages {
|
|
let context_server::types::PromptMessage { role, content } = message;
|
|
let block = mcp_message_content_to_acp_content_block(content);
|
|
|
|
match role {
|
|
context_server::types::Role::User => {
|
|
let id = acp_thread::UserMessageId::new();
|
|
|
|
acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.push_user_content_block_with_indent(
|
|
Some(id.clone()),
|
|
block.clone(),
|
|
true,
|
|
cx,
|
|
);
|
|
});
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.push_acp_user_block(id, [block], path_style, cx);
|
|
});
|
|
}
|
|
context_server::types::Role::Assistant => {
|
|
acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.push_assistant_content_block_with_indent(
|
|
block.clone(),
|
|
false,
|
|
true,
|
|
cx,
|
|
);
|
|
});
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.push_acp_agent_block(block, cx);
|
|
});
|
|
}
|
|
}
|
|
|
|
last_is_user = role == context_server::types::Role::User;
|
|
}
|
|
|
|
let response_stream = thread.update(cx, |thread, cx| {
|
|
if last_is_user {
|
|
thread.send_existing(cx)
|
|
} else {
|
|
// Resume if MCP prompt did not end with a user message
|
|
thread.resume(cx)
|
|
}
|
|
})?;
|
|
|
|
cx.update(|cx| {
|
|
NativeAgentConnection::handle_thread_events(
|
|
response_stream,
|
|
acp_thread.downgrade(),
|
|
cx,
|
|
)
|
|
})
|
|
.await
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Wrapper struct that implements the AgentConnection trait
|
|
#[derive(Clone)]
|
|
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
|
|
|
impl NativeAgentConnection {
|
|
/// Returns the internal `Thread` of the session `session_id`, or `None` when
/// no such session is registered.
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
    let agent = self.0.read(cx);
    let session = agent.sessions.get(session_id)?;
    Some(session.thread.clone())
}
|
|
|
|
/// Delegates to `NativeAgent::load_thread` on the wrapped agent.
pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
    let agent = &self.0;
    agent.update(cx, |agent, cx| agent.load_thread(id, cx))
}
|
|
|
|
fn run_turn(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
cx: &mut App,
|
|
f: impl 'static
|
|
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get_mut(&session_id)
|
|
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
|
}) else {
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
log::debug!("Found session for: {}", session_id);
|
|
|
|
let response_stream = match f(thread, cx) {
|
|
Ok(stream) => stream,
|
|
Err(err) => return Task::ready(Err(err)),
|
|
};
|
|
Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
|
|
}
|
|
|
|
fn handle_thread_events(
|
|
mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
|
acp_thread: WeakEntity<AcpThread>,
|
|
cx: &App,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
cx.spawn(async move |cx| {
|
|
// Handle response stream and forward to session.acp_thread
|
|
while let Some(result) = events.next().await {
|
|
match result {
|
|
Ok(event) => {
|
|
log::trace!("Received completion event: {:?}", event);
|
|
|
|
match event {
|
|
ThreadEvent::UserMessage(message) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
for content in message.content {
|
|
thread.push_user_content_block(
|
|
Some(message.id.clone()),
|
|
content.into(),
|
|
cx,
|
|
);
|
|
}
|
|
})?;
|
|
}
|
|
ThreadEvent::AgentText(text) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.push_assistant_content_block(text.into(), false, cx)
|
|
})?;
|
|
}
|
|
ThreadEvent::AgentThinking(text) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.push_assistant_content_block(text.into(), true, cx)
|
|
})?;
|
|
}
|
|
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
|
|
tool_call,
|
|
options,
|
|
response,
|
|
context: _,
|
|
}) => {
|
|
let outcome_task = acp_thread.update(cx, |thread, cx| {
|
|
thread.request_tool_call_authorization(tool_call, options, cx)
|
|
})??;
|
|
cx.background_spawn(async move {
|
|
if let acp::RequestPermissionOutcome::Selected(
|
|
acp::SelectedPermissionOutcome { option_id, .. },
|
|
) = outcome_task.await
|
|
{
|
|
response
|
|
.send(option_id)
|
|
.map(|_| anyhow!("authorization receiver was dropped"))
|
|
.log_err();
|
|
}
|
|
})
|
|
.detach();
|
|
}
|
|
ThreadEvent::ToolCall(tool_call) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.upsert_tool_call(tool_call, cx)
|
|
})??;
|
|
}
|
|
ThreadEvent::ToolCallUpdate(update) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.update_tool_call(update, cx)
|
|
})??;
|
|
}
|
|
ThreadEvent::SubagentSpawned(session_id) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.subagent_spawned(session_id, cx);
|
|
})?;
|
|
}
|
|
ThreadEvent::Retry(status) => {
|
|
acp_thread.update(cx, |thread, cx| {
|
|
thread.update_retry_status(status, cx)
|
|
})?;
|
|
}
|
|
ThreadEvent::Stop(stop_reason) => {
|
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
|
return Ok(acp::PromptResponse::new(stop_reason));
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
log::error!("Error in model response stream: {:?}", e);
|
|
return Err(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
log::debug!("Response stream completed");
|
|
anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
|
|
})
|
|
}
|
|
}
|
|
|
|
struct Command<'a> {
|
|
prompt_name: &'a str,
|
|
arg_value: &'a str,
|
|
explicit_server_id: Option<&'a str>,
|
|
}
|
|
|
|
impl<'a> Command<'a> {
|
|
fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
|
|
let acp::ContentBlock::Text(text_content) = prompt.first()? else {
|
|
return None;
|
|
};
|
|
let text = text_content.text.trim();
|
|
let command = text.strip_prefix('/')?;
|
|
let (command, arg_value) = command
|
|
.split_once(char::is_whitespace)
|
|
.unwrap_or((command, ""));
|
|
|
|
if let Some((server_id, prompt_name)) = command.split_once('.') {
|
|
Some(Self {
|
|
prompt_name,
|
|
arg_value,
|
|
explicit_server_id: Some(server_id),
|
|
})
|
|
} else {
|
|
Some(Self {
|
|
prompt_name: command,
|
|
arg_value,
|
|
explicit_server_id: None,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
struct NativeAgentModelSelector {
|
|
session_id: acp::SessionId,
|
|
connection: NativeAgentConnection,
|
|
}
|
|
|
|
impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
|
|
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
|
log::debug!("NativeAgentConnection::list_models called");
|
|
let list = self.connection.0.read(cx).models.model_list.clone();
|
|
Task::ready(if list.is_empty() {
|
|
Err(anyhow::anyhow!("No models available"))
|
|
} else {
|
|
Ok(list)
|
|
})
|
|
}
|
|
|
|
fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
|
|
log::debug!(
|
|
"Setting model for session {}: {}",
|
|
self.session_id,
|
|
model_id
|
|
);
|
|
let Some(thread) = self
|
|
.connection
|
|
.0
|
|
.read(cx)
|
|
.sessions
|
|
.get(&self.session_id)
|
|
.map(|session| session.thread.clone())
|
|
else {
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
|
|
let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
|
|
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
|
};
|
|
|
|
// We want to reset the effort level when switching models, as the currently-selected effort level may
|
|
// not be compatible.
|
|
let effort = model
|
|
.default_effort_level()
|
|
.map(|effort_level| effort_level.value.to_string());
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
thread.set_thinking_effort(effort.clone(), cx);
|
|
thread.set_thinking_enabled(model.supports_thinking(), cx);
|
|
});
|
|
|
|
update_settings_file(
|
|
self.connection.0.read(cx).fs.clone(),
|
|
cx,
|
|
move |settings, cx| {
|
|
let provider = model.provider_id().0.to_string();
|
|
let model = model.id().0.to_string();
|
|
let enable_thinking = thread.read(cx).thinking_enabled();
|
|
settings
|
|
.agent
|
|
.get_or_insert_default()
|
|
.set_model(LanguageModelSelection {
|
|
provider: provider.into(),
|
|
model,
|
|
enable_thinking,
|
|
effort,
|
|
});
|
|
},
|
|
);
|
|
|
|
Task::ready(Ok(()))
|
|
}
|
|
|
|
fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
|
|
let Some(thread) = self
|
|
.connection
|
|
.0
|
|
.read(cx)
|
|
.sessions
|
|
.get(&self.session_id)
|
|
.map(|session| session.thread.clone())
|
|
else {
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
let Some(model) = thread.read(cx).model() else {
|
|
return Task::ready(Err(anyhow!("Model not found")));
|
|
};
|
|
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
|
else {
|
|
return Task::ready(Err(anyhow!("Provider not found")));
|
|
};
|
|
Task::ready(Ok(LanguageModels::map_language_model_to_info(
|
|
model, &provider,
|
|
)))
|
|
}
|
|
|
|
fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
|
|
Some(self.connection.0.read(cx).models.watch())
|
|
}
|
|
|
|
fn should_render_footer(&self) -> bool {
|
|
true
|
|
}
|
|
}
|
|
|
|
impl acp_thread::AgentConnection for NativeAgentConnection {
|
|
fn telemetry_id(&self) -> SharedString {
|
|
"zed".into()
|
|
}
|
|
|
|
fn new_session(
|
|
self: Rc<Self>,
|
|
project: Entity<Project>,
|
|
cwd: &Path,
|
|
cx: &mut App,
|
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
|
log::debug!("Creating new thread for project at: {cwd:?}");
|
|
Task::ready(Ok(self
|
|
.0
|
|
.update(cx, |agent, cx| agent.new_session(project, cx))))
|
|
}
|
|
|
|
fn supports_load_session(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn load_session(
|
|
self: Rc<Self>,
|
|
session: AgentSessionInfo,
|
|
_project: Entity<Project>,
|
|
_cwd: &Path,
|
|
cx: &mut App,
|
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
|
self.0
|
|
.update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
|
|
}
|
|
|
|
fn supports_close_session(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
|
|
self.0.update(cx, |agent, _cx| {
|
|
agent.sessions.remove(session_id);
|
|
});
|
|
Task::ready(Ok(()))
|
|
}
|
|
|
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
|
&[] // No auth for in-process
|
|
}
|
|
|
|
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
|
Task::ready(Ok(()))
|
|
}
|
|
|
|
fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
|
|
Some(Rc::new(NativeAgentModelSelector {
|
|
session_id: session_id.clone(),
|
|
connection: self.clone(),
|
|
}) as Rc<dyn AgentModelSelector>)
|
|
}
|
|
|
|
fn prompt(
|
|
&self,
|
|
id: Option<acp_thread::UserMessageId>,
|
|
params: acp::PromptRequest,
|
|
cx: &mut App,
|
|
) -> Task<Result<acp::PromptResponse>> {
|
|
let id = id.expect("UserMessageId is required");
|
|
let session_id = params.session_id.clone();
|
|
log::info!("Received prompt request for session: {}", session_id);
|
|
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
|
|
|
if let Some(parsed_command) = Command::parse(¶ms.prompt) {
|
|
let registry = self.0.read(cx).context_server_registry.read(cx);
|
|
|
|
let explicit_server_id = parsed_command
|
|
.explicit_server_id
|
|
.map(|server_id| ContextServerId(server_id.into()));
|
|
|
|
if let Some(prompt) =
|
|
registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
|
|
{
|
|
let arguments = if !parsed_command.arg_value.is_empty()
|
|
&& let Some(arg_name) = prompt
|
|
.prompt
|
|
.arguments
|
|
.as_ref()
|
|
.and_then(|args| args.first())
|
|
.map(|arg| arg.name.clone())
|
|
{
|
|
HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
|
|
} else {
|
|
Default::default()
|
|
};
|
|
|
|
let prompt_name = prompt.prompt.name.clone();
|
|
let server_id = prompt.server_id.clone();
|
|
|
|
return self.0.update(cx, |agent, cx| {
|
|
agent.send_mcp_prompt(
|
|
id,
|
|
session_id.clone(),
|
|
prompt_name,
|
|
server_id,
|
|
arguments,
|
|
params.prompt,
|
|
cx,
|
|
)
|
|
});
|
|
};
|
|
};
|
|
|
|
let path_style = self.0.read(cx).project.read(cx).path_style(cx);
|
|
|
|
self.run_turn(session_id, cx, move |thread, cx| {
|
|
let content: Vec<UserMessageContent> = params
|
|
.prompt
|
|
.into_iter()
|
|
.map(|block| UserMessageContent::from_content_block(block, path_style))
|
|
.collect::<Vec<_>>();
|
|
log::debug!("Converted prompt to message: {} chars", content.len());
|
|
log::debug!("Message id: {:?}", id);
|
|
log::debug!("Message content: {:?}", content);
|
|
|
|
thread.update(cx, |thread, cx| thread.send(id, content, cx))
|
|
})
|
|
}
|
|
|
|
fn retry(
|
|
&self,
|
|
session_id: &acp::SessionId,
|
|
_cx: &App,
|
|
) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
|
|
Some(Rc::new(NativeAgentSessionRetry {
|
|
connection: self.clone(),
|
|
session_id: session_id.clone(),
|
|
}) as _)
|
|
}
|
|
|
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
|
log::info!("Cancelling on session: {}", session_id);
|
|
self.0.update(cx, |agent, cx| {
|
|
if let Some(session) = agent.sessions.get(session_id) {
|
|
session
|
|
.thread
|
|
.update(cx, |thread, cx| thread.cancel(cx))
|
|
.detach();
|
|
}
|
|
});
|
|
}
|
|
|
|
fn truncate(
|
|
&self,
|
|
session_id: &agent_client_protocol::SessionId,
|
|
cx: &App,
|
|
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
|
|
self.0.read_with(cx, |agent, _cx| {
|
|
agent.sessions.get(session_id).map(|session| {
|
|
Rc::new(NativeAgentSessionTruncate {
|
|
thread: session.thread.clone(),
|
|
acp_thread: session.acp_thread.downgrade(),
|
|
}) as _
|
|
})
|
|
})
|
|
}
|
|
|
|
fn set_title(
|
|
&self,
|
|
session_id: &acp::SessionId,
|
|
cx: &App,
|
|
) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
|
|
self.0.read_with(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get(session_id)
|
|
.filter(|s| !s.thread.read(cx).is_subagent())
|
|
.map(|session| {
|
|
Rc::new(NativeAgentSessionSetTitle {
|
|
thread: session.thread.clone(),
|
|
}) as _
|
|
})
|
|
})
|
|
}
|
|
|
|
fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
|
|
let thread_store = self.0.read(cx).thread_store.clone();
|
|
Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
|
|
}
|
|
|
|
fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
|
|
Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
|
|
}
|
|
|
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
|
self
|
|
}
|
|
}
|
|
|
|
impl acp_thread::AgentTelemetry for NativeAgentConnection {
|
|
fn thread_data(
|
|
&self,
|
|
session_id: &acp::SessionId,
|
|
cx: &mut App,
|
|
) -> Task<Result<serde_json::Value>> {
|
|
let Some(session) = self.0.read(cx).sessions.get(session_id) else {
|
|
return Task::ready(Err(anyhow!("Session not found")));
|
|
};
|
|
|
|
let task = session.thread.read(cx).to_db(cx);
|
|
cx.background_spawn(async move {
|
|
serde_json::to_value(task.await).context("Failed to serialize thread")
|
|
})
|
|
}
|
|
}
|
|
|
|
pub struct NativeAgentSessionList {
|
|
thread_store: Entity<ThreadStore>,
|
|
updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
|
|
updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
|
|
_subscription: Subscription,
|
|
}
|
|
|
|
impl NativeAgentSessionList {
|
|
fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
|
|
let (tx, rx) = smol::channel::unbounded();
|
|
let this_tx = tx.clone();
|
|
let subscription = cx.observe(&thread_store, move |_, _| {
|
|
this_tx
|
|
.try_send(acp_thread::SessionListUpdate::Refresh)
|
|
.ok();
|
|
});
|
|
Self {
|
|
thread_store,
|
|
updates_tx: tx,
|
|
updates_rx: rx,
|
|
_subscription: subscription,
|
|
}
|
|
}
|
|
|
|
fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
|
|
AgentSessionInfo {
|
|
session_id: entry.id,
|
|
cwd: None,
|
|
title: Some(entry.title),
|
|
updated_at: Some(entry.updated_at),
|
|
meta: None,
|
|
}
|
|
}
|
|
|
|
pub fn thread_store(&self) -> &Entity<ThreadStore> {
|
|
&self.thread_store
|
|
}
|
|
}
|
|
|
|
impl AgentSessionList for NativeAgentSessionList {
|
|
fn list_sessions(
|
|
&self,
|
|
_request: AgentSessionListRequest,
|
|
cx: &mut App,
|
|
) -> Task<Result<AgentSessionListResponse>> {
|
|
let sessions = self
|
|
.thread_store
|
|
.read(cx)
|
|
.entries()
|
|
.map(Self::to_session_info)
|
|
.collect();
|
|
Task::ready(Ok(AgentSessionListResponse::new(sessions)))
|
|
}
|
|
|
|
fn supports_delete(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
|
|
self.thread_store
|
|
.update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
|
|
}
|
|
|
|
fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
|
|
self.thread_store
|
|
.update(cx, |store, cx| store.delete_threads(cx))
|
|
}
|
|
|
|
fn watch(
|
|
&self,
|
|
_cx: &mut App,
|
|
) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
|
|
Some(self.updates_rx.clone())
|
|
}
|
|
|
|
fn notify_refresh(&self) {
|
|
self.updates_tx
|
|
.try_send(acp_thread::SessionListUpdate::Refresh)
|
|
.ok();
|
|
}
|
|
|
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
|
self
|
|
}
|
|
}
|
|
|
|
struct NativeAgentSessionTruncate {
|
|
thread: Entity<Thread>,
|
|
acp_thread: WeakEntity<AcpThread>,
|
|
}
|
|
|
|
impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
|
|
fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
|
match self.thread.update(cx, |thread, cx| {
|
|
thread.truncate(message_id.clone(), cx)?;
|
|
Ok(thread.latest_token_usage())
|
|
}) {
|
|
Ok(usage) => {
|
|
self.acp_thread
|
|
.update(cx, |thread, cx| {
|
|
thread.update_token_usage(usage, cx);
|
|
})
|
|
.ok();
|
|
Task::ready(Ok(()))
|
|
}
|
|
Err(error) => Task::ready(Err(error)),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct NativeAgentSessionRetry {
|
|
connection: NativeAgentConnection,
|
|
session_id: acp::SessionId,
|
|
}
|
|
|
|
impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
|
|
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
|
|
self.connection
|
|
.run_turn(self.session_id.clone(), cx, |thread, cx| {
|
|
thread.update(cx, |thread, cx| thread.resume(cx))
|
|
})
|
|
}
|
|
}
|
|
|
|
struct NativeAgentSessionSetTitle {
|
|
thread: Entity<Thread>,
|
|
}
|
|
|
|
impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
|
|
fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
|
|
self.thread
|
|
.update(cx, |thread, cx| thread.set_title(title, cx));
|
|
Task::ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
pub struct NativeThreadEnvironment {
|
|
agent: WeakEntity<NativeAgent>,
|
|
thread: WeakEntity<Thread>,
|
|
acp_thread: WeakEntity<AcpThread>,
|
|
}
|
|
|
|
impl NativeThreadEnvironment {
|
|
pub(crate) fn create_subagent_thread(
|
|
&self,
|
|
label: String,
|
|
cx: &mut App,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
let Some(parent_thread_entity) = self.thread.upgrade() else {
|
|
anyhow::bail!("Parent thread no longer exists".to_string());
|
|
};
|
|
let parent_thread = parent_thread_entity.read(cx);
|
|
let current_depth = parent_thread.depth();
|
|
|
|
if current_depth >= MAX_SUBAGENT_DEPTH {
|
|
return Err(anyhow!(
|
|
"Maximum subagent depth ({}) reached",
|
|
MAX_SUBAGENT_DEPTH
|
|
));
|
|
}
|
|
|
|
let subagent_thread: Entity<Thread> = cx.new(|cx| {
|
|
let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
|
|
thread.set_title(label.into(), cx);
|
|
thread
|
|
});
|
|
|
|
let session_id = subagent_thread.read(cx).id().clone();
|
|
|
|
let acp_thread = self.agent.update(cx, |agent, cx| {
|
|
agent.register_session(subagent_thread.clone(), cx)
|
|
})?;
|
|
|
|
self.prompt_subagent(session_id, subagent_thread, acp_thread)
|
|
}
|
|
|
|
pub(crate) fn resume_subagent_thread(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
cx: &mut App,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
|
|
let session = agent
|
|
.sessions
|
|
.get(&session_id)
|
|
.ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
|
|
anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
|
|
})??;
|
|
|
|
self.prompt_subagent(session_id, subagent_thread, acp_thread)
|
|
}
|
|
|
|
fn prompt_subagent(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
subagent_thread: Entity<Thread>,
|
|
acp_thread: Entity<acp_thread::AcpThread>,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
let Some(parent_thread_entity) = self.thread.upgrade() else {
|
|
anyhow::bail!("Parent thread no longer exists".to_string());
|
|
};
|
|
Ok(Rc::new(NativeSubagentHandle::new(
|
|
session_id,
|
|
subagent_thread,
|
|
acp_thread,
|
|
parent_thread_entity,
|
|
)) as _)
|
|
}
|
|
}
|
|
|
|
impl ThreadEnvironment for NativeThreadEnvironment {
|
|
fn create_terminal(
|
|
&self,
|
|
command: String,
|
|
cwd: Option<PathBuf>,
|
|
output_byte_limit: Option<u64>,
|
|
cx: &mut AsyncApp,
|
|
) -> Task<Result<Rc<dyn TerminalHandle>>> {
|
|
let task = self.acp_thread.update(cx, |thread, cx| {
|
|
thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
|
|
});
|
|
|
|
let acp_thread = self.acp_thread.clone();
|
|
cx.spawn(async move |cx| {
|
|
let terminal = task?.await?;
|
|
|
|
let (drop_tx, drop_rx) = oneshot::channel();
|
|
let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
|
|
|
|
cx.spawn(async move |cx| {
|
|
drop_rx.await.ok();
|
|
acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
|
|
})
|
|
.detach();
|
|
|
|
let handle = AcpTerminalHandle {
|
|
terminal,
|
|
_drop_tx: Some(drop_tx),
|
|
};
|
|
|
|
Ok(Rc::new(handle) as _)
|
|
})
|
|
}
|
|
|
|
fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
|
|
self.create_subagent_thread(label, cx)
|
|
}
|
|
|
|
fn resume_subagent(
|
|
&self,
|
|
session_id: acp::SessionId,
|
|
cx: &mut App,
|
|
) -> Result<Rc<dyn SubagentHandle>> {
|
|
self.resume_subagent_thread(session_id, cx)
|
|
}
|
|
}
|
|
|
|
/// Outcome of waiting on a prompt sent to a subagent.
#[derive(Clone, Debug)]
enum SubagentPromptResult {
    /// The prompt finished without being cancelled or hitting an error.
    Completed,
    /// The prompt was cancelled.
    Cancelled,
    /// The subagent's token usage moved into the warning range while the
    /// prompt was running.
    ContextWindowWarning,
    /// The prompt failed; the payload is the message to report.
    Error(String),
}
|
|
|
|
pub struct NativeSubagentHandle {
|
|
session_id: acp::SessionId,
|
|
parent_thread: WeakEntity<Thread>,
|
|
subagent_thread: Entity<Thread>,
|
|
acp_thread: Entity<acp_thread::AcpThread>,
|
|
}
|
|
|
|
impl NativeSubagentHandle {
|
|
fn new(
|
|
session_id: acp::SessionId,
|
|
subagent_thread: Entity<Thread>,
|
|
acp_thread: Entity<acp_thread::AcpThread>,
|
|
parent_thread_entity: Entity<Thread>,
|
|
) -> Self {
|
|
NativeSubagentHandle {
|
|
session_id,
|
|
subagent_thread,
|
|
parent_thread: parent_thread_entity.downgrade(),
|
|
acp_thread,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl SubagentHandle for NativeSubagentHandle {
|
|
fn id(&self) -> acp::SessionId {
|
|
self.session_id.clone()
|
|
}
|
|
|
|
fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
|
|
let thread = self.subagent_thread.clone();
|
|
let acp_thread = self.acp_thread.clone();
|
|
let subagent_session_id = self.session_id.clone();
|
|
let parent_thread = self.parent_thread.clone();
|
|
|
|
cx.spawn(async move |cx| {
|
|
let (task, _subscription) = cx.update(|cx| {
|
|
let ratio_before_prompt = thread
|
|
.read(cx)
|
|
.latest_token_usage()
|
|
.map(|usage| usage.ratio());
|
|
|
|
parent_thread
|
|
.update(cx, |parent_thread, _cx| {
|
|
parent_thread.register_running_subagent(thread.downgrade())
|
|
})
|
|
.ok();
|
|
|
|
let task = acp_thread.update(cx, |acp_thread, cx| {
|
|
acp_thread.send(vec![message.into()], cx)
|
|
});
|
|
|
|
let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
|
|
let mut token_limit_tx = Some(token_limit_tx);
|
|
|
|
let subscription = cx.subscribe(
|
|
&thread,
|
|
move |_thread, event: &TokenUsageUpdated, _cx| {
|
|
if let Some(usage) = &event.0 {
|
|
let old_ratio = ratio_before_prompt
|
|
.clone()
|
|
.unwrap_or(TokenUsageRatio::Normal);
|
|
let new_ratio = usage.ratio();
|
|
if old_ratio == TokenUsageRatio::Normal
|
|
&& new_ratio == TokenUsageRatio::Warning
|
|
{
|
|
if let Some(tx) = token_limit_tx.take() {
|
|
tx.send(()).ok();
|
|
}
|
|
}
|
|
}
|
|
},
|
|
);
|
|
|
|
let wait_for_prompt = cx
|
|
.background_spawn(async move {
|
|
futures::select! {
|
|
response = task.fuse() => match response {
|
|
Ok(Some(response)) => {
|
|
match response.stop_reason {
|
|
acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
|
|
acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
|
|
acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
|
|
acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
|
|
acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
|
|
}
|
|
}
|
|
Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
|
|
Err(error) => SubagentPromptResult::Error(error.to_string()),
|
|
},
|
|
_ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
|
|
}
|
|
});
|
|
|
|
(wait_for_prompt, subscription)
|
|
});
|
|
|
|
let result = match task.await {
|
|
SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
|
|
thread
|
|
.last_message()
|
|
.and_then(|message| {
|
|
let content = message.as_agent_message()?
|
|
.content
|
|
.iter()
|
|
.filter_map(|c| match c {
|
|
AgentMessageContent::Text(text) => Some(text.as_str()),
|
|
_ => None,
|
|
})
|
|
.join("\n\n");
|
|
if content.is_empty() {
|
|
None
|
|
} else {
|
|
Some(content)
|
|
}
|
|
})
|
|
.context("No response from subagent")
|
|
}),
|
|
SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
|
|
SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
|
|
SubagentPromptResult::ContextWindowWarning => {
|
|
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
|
Err(anyhow!(
|
|
"The agent is nearing the end of its context window and has been \
|
|
stopped. You can prompt the thread again to have the agent wrap up \
|
|
or hand off its work."
|
|
))
|
|
}
|
|
};
|
|
|
|
parent_thread
|
|
.update(cx, |parent_thread, cx| {
|
|
parent_thread.unregister_running_subagent(&subagent_session_id, cx)
|
|
})
|
|
.ok();
|
|
|
|
result
|
|
})
|
|
}
|
|
}
|
|
|
|
pub struct AcpTerminalHandle {
|
|
terminal: Entity<acp_thread::Terminal>,
|
|
_drop_tx: Option<oneshot::Sender<()>>,
|
|
}
|
|
|
|
impl TerminalHandle for AcpTerminalHandle {
|
|
fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
|
|
Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
|
|
}
|
|
|
|
fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
|
|
Ok(self
|
|
.terminal
|
|
.read_with(cx, |term, _cx| term.wait_for_exit()))
|
|
}
|
|
|
|
fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
|
|
Ok(self
|
|
.terminal
|
|
.read_with(cx, |term, cx| term.current_output(cx)))
|
|
}
|
|
|
|
fn kill(&self, cx: &AsyncApp) -> Result<()> {
|
|
cx.update(|cx| {
|
|
self.terminal.update(cx, |terminal, cx| {
|
|
terminal.kill(cx);
|
|
});
|
|
});
|
|
Ok(())
|
|
}
|
|
|
|
fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
|
|
Ok(self
|
|
.terminal
|
|
.read_with(cx, |term, _cx| term.was_stopped_by_user()))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod internal_tests {
|
|
use super::*;
|
|
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
|
|
use fs::FakeFs;
|
|
use gpui::TestAppContext;
|
|
use indoc::formatdoc;
|
|
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
|
use language_model::{LanguageModelProviderId, LanguageModelProviderName};
|
|
use serde_json::json;
|
|
use settings::SettingsStore;
|
|
use util::{path, rel_path::rel_path};
|
|
|
|
#[gpui::test]
|
|
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = NativeAgent::new(
|
|
project.clone(),
|
|
thread_store,
|
|
Templates::new(),
|
|
None,
|
|
fs.clone(),
|
|
&mut cx.to_async(),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
agent.read_with(cx, |agent, cx| {
|
|
assert_eq!(agent.project_context.read(cx).worktrees, vec![])
|
|
});
|
|
|
|
let worktree = project
|
|
.update(cx, |project, cx| project.create_worktree("/a", true, cx))
|
|
.await
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
agent.read_with(cx, |agent, cx| {
|
|
assert_eq!(
|
|
agent.project_context.read(cx).worktrees,
|
|
vec![WorktreeContext {
|
|
root_name: "a".into(),
|
|
abs_path: Path::new("/a").into(),
|
|
rules_file: None
|
|
}]
|
|
)
|
|
});
|
|
|
|
// Creating `/a/.rules` updates the project context.
|
|
fs.insert_file("/a/.rules", Vec::new()).await;
|
|
cx.run_until_parked();
|
|
agent.read_with(cx, |agent, cx| {
|
|
let rules_entry = worktree
|
|
.read(cx)
|
|
.entry_for_path(rel_path(".rules"))
|
|
.unwrap();
|
|
assert_eq!(
|
|
agent.project_context.read(cx).worktrees,
|
|
vec![WorktreeContext {
|
|
root_name: "a".into(),
|
|
abs_path: Path::new("/a").into(),
|
|
rules_file: Some(RulesFileContext {
|
|
path_in_worktree: rel_path(".rules").into(),
|
|
text: "".into(),
|
|
project_entry_id: rules_entry.id.to_usize()
|
|
})
|
|
}]
|
|
)
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_listing_models(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let connection = NativeAgentConnection(
|
|
NativeAgent::new(
|
|
project.clone(),
|
|
thread_store,
|
|
Templates::new(),
|
|
None,
|
|
fs.clone(),
|
|
&mut cx.to_async(),
|
|
)
|
|
.await
|
|
.unwrap(),
|
|
);
|
|
|
|
// Create a thread/session
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
|
|
|
let models = cx
|
|
.update(|cx| {
|
|
connection
|
|
.model_selector(&session_id)
|
|
.unwrap()
|
|
.list_models(cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
let acp_thread::AgentModelList::Grouped(models) = models else {
|
|
panic!("Unexpected model group");
|
|
};
|
|
assert_eq!(
|
|
models,
|
|
IndexMap::from_iter([(
|
|
AgentModelGroupName("Fake".into()),
|
|
vec![AgentModelInfo {
|
|
id: acp::ModelId::new("fake/fake"),
|
|
name: "Fake".into(),
|
|
description: None,
|
|
icon: Some(acp_thread::AgentModelIcon::Named(
|
|
ui::IconName::ZedAssistant
|
|
)),
|
|
is_latest: false,
|
|
cost: None,
|
|
}]
|
|
)])
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.create_dir(paths::settings_file().parent().unwrap())
|
|
.await
|
|
.unwrap();
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"default_model": {
|
|
"provider": "foo",
|
|
"model": "bar"
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
|
|
// Create the agent and connection
|
|
let agent = NativeAgent::new(
|
|
project.clone(),
|
|
thread_store,
|
|
Templates::new(),
|
|
None,
|
|
fs.clone(),
|
|
&mut cx.to_async(),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
let connection = NativeAgentConnection(agent.clone());
|
|
|
|
// Create a thread/session
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
|
|
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
|
|
|
// Select a model
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
let model_id = acp::ModelId::new("fake/fake");
|
|
cx.update(|cx| selector.select_model(model_id.clone(), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// Verify the thread has the selected model
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.model().unwrap().id().0, "fake");
|
|
});
|
|
});
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Verify settings file was updated
|
|
let settings_content = fs.load(paths::settings_file()).await.unwrap();
|
|
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
|
|
|
|
// Check that the agent settings contain the selected model
|
|
assert_eq!(
|
|
settings_json["agent"]["default_model"]["model"],
|
|
json!("fake")
|
|
);
|
|
assert_eq!(
|
|
settings_json["agent"]["default_model"]["provider"],
|
|
json!("fake")
|
|
);
|
|
|
|
// Register a thinking model and select it.
|
|
cx.update(|cx| {
|
|
let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"fake-thinking",
|
|
"Fake Thinking",
|
|
true,
|
|
));
|
|
let thinking_provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![thinking_model]),
|
|
);
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(thinking_provider, cx);
|
|
});
|
|
});
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
|
|
.await
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Verify enable_thinking was written to settings as true.
|
|
let settings_content = fs.load(paths::settings_file()).await.unwrap();
|
|
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
|
|
assert_eq!(
|
|
settings_json["agent"]["default_model"]["enable_thinking"],
|
|
json!(true),
|
|
"selecting a thinking model should persist enable_thinking: true to settings"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.create_dir(paths::settings_file().parent().unwrap())
|
|
.await
|
|
.unwrap();
|
|
fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
|
|
let project = Project::test(fs.clone(), [], cx).await;
|
|
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = NativeAgent::new(
|
|
project.clone(),
|
|
thread_store,
|
|
Templates::new(),
|
|
None,
|
|
fs.clone(),
|
|
&mut cx.to_async(),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
let connection = NativeAgentConnection(agent.clone());
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
|
|
|
// Register a second provider with a thinking model.
|
|
cx.update(|cx| {
|
|
let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"fake-thinking",
|
|
"Fake Thinking",
|
|
true,
|
|
));
|
|
let thinking_provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![thinking_model]),
|
|
);
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(thinking_provider, cx);
|
|
});
|
|
});
|
|
// Refresh the agent's model list so it picks up the new provider.
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
// Thread starts with thinking_enabled = false (the default).
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert!(!thread.thinking_enabled(), "thinking defaults to false");
|
|
});
|
|
});
|
|
|
|
// Select the thinking model via select_model.
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// select_model should have enabled thinking based on the model's supports_thinking().
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.thinking_enabled(),
|
|
"select_model should enable thinking when model supports it"
|
|
);
|
|
});
|
|
});
|
|
|
|
// Switch back to the non-thinking model.
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// select_model should have disabled thinking.
|
|
agent.read_with(cx, |agent, _| {
|
|
let session = agent.sessions.get(&session_id).unwrap();
|
|
session.thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
!thread.thinking_enabled(),
|
|
"select_model should disable thinking when model does not support it"
|
|
);
|
|
});
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = NativeAgent::new(
|
|
project.clone(),
|
|
thread_store.clone(),
|
|
Templates::new(),
|
|
None,
|
|
fs.clone(),
|
|
&mut cx.to_async(),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
// Register a thinking model.
|
|
let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"fake-thinking",
|
|
"Fake Thinking",
|
|
true,
|
|
));
|
|
let thinking_provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![thinking_model.clone()]),
|
|
);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(thinking_provider, cx);
|
|
});
|
|
});
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
// Create a thread and select the thinking model.
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), Path::new("/a"), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
// Verify thinking is enabled after selecting the thinking model.
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.thinking_enabled(),
|
|
"thinking should be enabled after selecting thinking model"
|
|
);
|
|
});
|
|
|
|
// Send a message so the thread gets persisted.
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
thinking_model.send_last_completion_stream_text_chunk("Response.");
|
|
thinking_model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Close the session so it can be reloaded from disk.
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(thread);
|
|
drop(acp_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(agent.sessions.is_empty());
|
|
});
|
|
|
|
// Reload the thread and verify thinking_enabled is still true.
|
|
let reloaded_acp_thread = agent
|
|
.update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
|
|
.await
|
|
.unwrap();
|
|
let reloaded_thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
reloaded_thread.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.thinking_enabled(),
|
|
"thinking_enabled should be preserved when reloading a thread with a thinking model"
|
|
);
|
|
});
|
|
|
|
drop(reloaded_acp_thread);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/", json!({ "a": {} })).await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = NativeAgent::new(
|
|
project.clone(),
|
|
thread_store.clone(),
|
|
Templates::new(),
|
|
None,
|
|
fs.clone(),
|
|
&mut cx.to_async(),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
// Register a model where id() != name(), like real Anthropic models
|
|
// (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
|
|
let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake-corp",
|
|
"custom-model-id",
|
|
"Custom Model Display Name",
|
|
false,
|
|
));
|
|
let provider = Arc::new(
|
|
FakeLanguageModelProvider::new(
|
|
LanguageModelProviderId::from("fake-corp".to_string()),
|
|
LanguageModelProviderName::from("Fake Corp".to_string()),
|
|
)
|
|
.with_models(vec![model.clone()]),
|
|
);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
registry.register_provider(provider, cx);
|
|
});
|
|
});
|
|
agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
|
|
|
|
// Create a thread and select the model.
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), Path::new("/a"), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
|
|
let selector = connection.model_selector(&session_id).unwrap();
|
|
cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
|
|
.await
|
|
.unwrap();
|
|
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.model().unwrap().id().0.as_ref(),
|
|
"custom-model-id",
|
|
"model should be set before persisting"
|
|
);
|
|
});
|
|
|
|
// Send a message so the thread gets persisted.
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
model.send_last_completion_stream_text_chunk("Response.");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Close the session so it can be reloaded from disk.
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(thread);
|
|
drop(acp_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert!(agent.sessions.is_empty());
|
|
});
|
|
|
|
// Reload the thread and verify the model was preserved.
|
|
let reloaded_acp_thread = agent
|
|
.update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
|
|
.await
|
|
.unwrap();
|
|
let reloaded_thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
reloaded_thread.read_with(cx, |thread, _| {
|
|
let reloaded_model = thread
|
|
.model()
|
|
.expect("model should be present after reload");
|
|
assert_eq!(
|
|
reloaded_model.id().0.as_ref(),
|
|
"custom-model-id",
|
|
"reloaded thread should have the same model, not fall back to the default"
|
|
);
|
|
});
|
|
|
|
drop(reloaded_acp_thread);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_save_load_thread(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = NativeAgent::new(
|
|
project.clone(),
|
|
thread_store.clone(),
|
|
Templates::new(),
|
|
None,
|
|
fs.clone(),
|
|
&mut cx.to_async(),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), Path::new(""), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
|
|
// Ensure empty threads are not saved, even if they get mutated.
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
let summary_model = Arc::new(FakeLanguageModel::default());
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
thread.set_summarization_model(Some(summary_model.clone()), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
assert_eq!(thread_entries(&thread_store, cx), vec![]);
|
|
|
|
let send = acp_thread.update(cx, |thread, cx| {
|
|
thread.send(
|
|
vec![
|
|
"What does ".into(),
|
|
acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
|
|
"b.md",
|
|
MentionUri::File {
|
|
abs_path: path!("/a/b.md").into(),
|
|
}
|
|
.to_uri()
|
|
.to_string(),
|
|
)),
|
|
" mean?".into(),
|
|
],
|
|
cx,
|
|
)
|
|
});
|
|
let send = cx.foreground_executor().spawn(send);
|
|
cx.run_until_parked();
|
|
|
|
model.send_last_completion_stream_text_chunk("Lorem.");
|
|
model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
summary_model
|
|
.send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
|
|
summary_model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
let uri = MentionUri::File {
|
|
abs_path: path!("/a/b.md").into(),
|
|
}
|
|
.to_uri();
|
|
acp_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
formatdoc! {"
|
|
## User
|
|
|
|
What does [@b.md]({uri}) mean?
|
|
|
|
## Assistant
|
|
|
|
Lorem.
|
|
|
|
"}
|
|
)
|
|
});
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Close the session so it can be reloaded from disk.
|
|
cx.update(|cx| connection.clone().close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(thread);
|
|
drop(acp_thread);
|
|
agent.read_with(cx, |agent, _| {
|
|
assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
|
|
});
|
|
|
|
// Ensure the thread can be reloaded from disk.
|
|
assert_eq!(
|
|
thread_entries(&thread_store, cx),
|
|
vec![(
|
|
session_id.clone(),
|
|
format!("Explaining {}", path!("/a/b.md"))
|
|
)]
|
|
);
|
|
let acp_thread = agent
|
|
.update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
|
|
.await
|
|
.unwrap();
|
|
acp_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
formatdoc! {"
|
|
## User
|
|
|
|
What does [@b.md]({uri}) mean?
|
|
|
|
## Assistant
|
|
|
|
Lorem.
|
|
|
|
"}
|
|
)
|
|
});
|
|
}
|
|
|
|
fn thread_entries(
|
|
thread_store: &Entity<ThreadStore>,
|
|
cx: &mut TestAppContext,
|
|
) -> Vec<(acp::SessionId, String)> {
|
|
thread_store.read_with(cx, |store, _| {
|
|
store
|
|
.entries()
|
|
.map(|entry| (entry.id.clone(), entry.title.to_string()))
|
|
.collect::<Vec<_>>()
|
|
})
|
|
}
|
|
|
|
fn init_test(cx: &mut TestAppContext) {
|
|
env_logger::try_init().ok();
|
|
cx.update(|cx| {
|
|
let settings_store = SettingsStore::test(cx);
|
|
cx.set_global(settings_store);
|
|
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
}
|
|
}
|
|
|
|
fn mcp_message_content_to_acp_content_block(
|
|
content: context_server::types::MessageContent,
|
|
) -> acp::ContentBlock {
|
|
match content {
|
|
context_server::types::MessageContent::Text {
|
|
text,
|
|
annotations: _,
|
|
} => text.into(),
|
|
context_server::types::MessageContent::Image {
|
|
data,
|
|
mime_type,
|
|
annotations: _,
|
|
} => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
|
|
context_server::types::MessageContent::Audio {
|
|
data,
|
|
mime_type,
|
|
annotations: _,
|
|
} => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
|
|
context_server::types::MessageContent::Resource {
|
|
resource,
|
|
annotations: _,
|
|
} => {
|
|
let mut link =
|
|
acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
|
|
if let Some(mime_type) = resource.mime_type {
|
|
link = link.mime_type(mime_type);
|
|
}
|
|
acp::ContentBlock::ResourceLink(link)
|
|
}
|
|
}
|
|
}
|