agent_ui: Persist external agent selections as defaults (#57511)

Matches behavior from selectors between Zed + external agents.

Also means they will persist across worktree creation 🎉

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- N/A
This commit is contained in:
Ben Brandt 2026-05-26 17:11:19 +02:00 committed by GitHub
parent b20cd411ec
commit 4bee412118
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 621 additions and 190 deletions

View file

@ -15,10 +15,13 @@ use futures::channel::mpsc;
use futures::future::Shared;
use futures::io::BufReader;
use futures::{AsyncBufReadExt as _, Future, FutureExt as _, StreamExt as _};
use project::agent_server_store::{AgentServerCommand, AgentServerStore};
use project::agent_server_store::{
AgentServerCommand, AgentServerStore, AllAgentServersSettings, CustomAgentServerSettings,
};
use project::{AgentId, Project};
use remote::remote_client::Interactive;
use serde::Deserialize;
use settings::SettingsStore;
use std::path::PathBuf;
use std::process::{ExitStatus, Stdio};
use std::rc::Rc;
@ -32,7 +35,7 @@ use util::path_list::PathList;
use util::process::Child;
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Subscription, Task, WeakEntity};
use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent};
use terminal::TerminalBuilder;
@ -421,18 +424,101 @@ pub struct AcpConnection {
auth_methods: Vec<acp::AuthMethod>,
agent_server_store: WeakEntity<AgentServerStore>,
agent_capabilities: acp::AgentCapabilities,
default_mode: Option<acp::SessionModeId>,
default_model: Option<acp::ModelId>,
default_config_options: HashMap<String, String>,
defaults: AcpConnectionDefaults,
child: Option<Child>,
session_list: Option<Rc<AcpSessionList>>,
debug_log: AcpDebugLog,
_settings_subscription: Subscription,
_io_task: Task<()>,
_dispatch_task: Task<()>,
_wait_task: Task<Result<()>>,
_stderr_task: Task<Result<()>>,
}
#[derive(Clone, Default)]
struct AcpConnectionDefaults {
mode: Rc<RefCell<Option<acp::SessionModeId>>>,
model: Rc<RefCell<Option<acp::ModelId>>>,
config_options: Rc<RefCell<HashMap<String, String>>>,
}
impl AcpConnectionDefaults {
fn new(
mode: Option<acp::SessionModeId>,
model: Option<acp::ModelId>,
config_options: HashMap<String, String>,
) -> Self {
Self {
mode: Rc::new(RefCell::new(mode)),
model: Rc::new(RefCell::new(model)),
config_options: Rc::new(RefCell::new(config_options)),
}
}
fn mode(&self) -> Option<acp::SessionModeId> {
self.mode.borrow().clone()
}
fn model(&self) -> Option<acp::ModelId> {
self.model.borrow().clone()
}
fn config_option(&self, config_id: &str) -> Option<String> {
self.config_options.borrow().get(config_id).cloned()
}
fn set(
&self,
mode: Option<acp::SessionModeId>,
model: Option<acp::ModelId>,
config_options: HashMap<String, String>,
) {
*self.mode.borrow_mut() = mode;
*self.model.borrow_mut() = model;
*self.config_options.borrow_mut() = config_options;
}
fn refresh_from_settings(&self, agent_id: &AgentId, cx: &App) {
let Some(settings_store) = cx.try_global::<SettingsStore>() else {
self.set(None, None, HashMap::default());
return;
};
let settings = settings_store.get::<AllAgentServersSettings>(None);
let Some(agent_settings) = settings.get(agent_id.as_ref()) else {
self.set(None, None, HashMap::default());
return;
};
let default_config_options = match agent_settings {
CustomAgentServerSettings::Custom {
default_config_options,
..
}
| CustomAgentServerSettings::Registry {
default_config_options,
..
} => default_config_options.clone(),
};
self.set(
agent_settings.default_mode().map(acp::SessionModeId::new),
agent_settings.default_model().map(acp::ModelId::new),
default_config_options,
);
}
fn observe_settings(&self, agent_id: AgentId, cx: &mut App) -> Subscription {
if cx.try_global::<SettingsStore>().is_none() {
return Subscription::new(|| {});
}
self.refresh_from_settings(&agent_id, cx);
let defaults = self.clone();
cx.observe_global::<SettingsStore>(move |cx| {
defaults.refresh_from_settings(&agent_id, cx);
})
}
}
struct PendingAcpSession {
task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
ref_count: usize,
@ -996,6 +1082,14 @@ impl AcpConnection {
} else {
response.auth_methods
};
let defaults =
AcpConnectionDefaults::new(default_mode, default_model, default_config_options);
let settings_subscription = cx.update({
let agent_id = agent_id.clone();
let defaults = defaults.clone();
move |cx| defaults.observe_settings(agent_id, cx)
});
Ok(Self {
id: agent_id,
auth_methods,
@ -1006,11 +1100,10 @@ impl AcpConnection {
sessions,
pending_sessions: Rc::new(RefCell::new(HashMap::default())),
agent_capabilities: response.agent_capabilities,
default_mode,
default_model,
default_config_options,
defaults,
session_list,
debug_log,
_settings_subscription: settings_subscription,
_io_task: io_task,
_dispatch_task: dispatch_task,
_wait_task: wait_task,
@ -1031,10 +1124,14 @@ impl AcpConnection {
agent_server_store: WeakEntity<AgentServerStore>,
io_task: Task<()>,
dispatch_task: Task<()>,
_cx: &mut App,
cx: &mut App,
) -> Self {
let agent_id = AgentId::new("test");
let defaults = AcpConnectionDefaults::default();
let settings_subscription = defaults.observe_settings(agent_id.clone(), cx);
Self {
id: AgentId::new("test"),
id: agent_id,
telemetry_id: "test".into(),
agent_version: None,
connection,
@ -1043,12 +1140,11 @@ impl AcpConnection {
auth_methods: vec![],
agent_server_store,
agent_capabilities,
default_mode: None,
default_model: None,
default_config_options: HashMap::default(),
defaults,
child: None,
session_list: None,
debug_log: AcpDebugLog::default(),
_settings_subscription: settings_subscription,
_io_task: io_task,
_dispatch_task: dispatch_task,
_wait_task: Task::ready(Ok(())),
@ -1215,7 +1311,7 @@ impl AcpConnection {
config_opts_ref
.iter()
.filter_map(|config_option| {
let default_value = self.default_config_options.get(&*config_option.id.0)?;
let default_value = self.defaults.config_option(config_option.id.0.as_ref())?;
let is_valid = match &config_option.kind {
acp::SessionConfigKind::Select(select) => match &select.options {
@ -1241,11 +1337,7 @@ impl AcpConnection {
}
_ => None,
};
Some((
config_option.id.clone(),
default_value.clone(),
initial_value,
))
Some((config_option.id.clone(), default_value, initial_value))
} else {
log::warn!(
"`{}` is not a valid value for config option `{}` in {}",
@ -1488,7 +1580,8 @@ impl AgentConnection for AcpConnection {
let (modes, models, config_options) =
config_state(response.modes, response.models, response.config_options);
if let Some(default_mode) = self.default_mode.clone() {
let default_mode = self.defaults.mode();
if let Some(default_mode) = default_mode {
if let Some(modes) = modes.as_ref() {
let mut modes_ref = modes.borrow_mut();
let has_mode = modes_ref
@ -1537,7 +1630,8 @@ impl AgentConnection for AcpConnection {
}
}
if let Some(default_model) = self.default_model.clone() {
let default_model = self.defaults.model();
if let Some(default_model) = default_model {
if let Some(models) = models.as_ref() {
let mut models_ref = models.borrow_mut();
let has_model = models_ref
@ -2501,6 +2595,7 @@ mod tests {
use super::*;
use gpui::UpdateGlobal as _;
use settings::Settings as _;
#[test]
fn terminal_auth_task_builds_spawn_from_prebuilt_command() {
@ -2970,6 +3065,68 @@ mod tests {
.expect("failed to receive ACP connection")
}
#[gpui::test]
async fn settings_changes_refresh_active_connection_defaults(cx: &mut gpui::TestAppContext) {
cx.update(|cx| {
let store = settings::SettingsStore::test(cx);
cx.set_global(store);
});
let fs = fs::FakeFs::new(cx.executor());
fs.insert_tree("/", serde_json::json!({ "a": {} })).await;
let project = project::Project::test(fs, [std::path::Path::new("/a")], cx).await;
let harness = test_support::connect_fake_acp_connection(project, cx).await;
cx.update(|cx| {
AllAgentServersSettings::override_global(
AllAgentServersSettings(HashMap::from_iter([(
"test".to_string(),
settings::CustomAgentServerSettings::Custom {
path: PathBuf::from("test-agent"),
args: Vec::new(),
env: HashMap::default(),
default_mode: Some("manual".to_string()),
default_model: Some("claude-sonnet-4".to_string()),
favorite_models: Vec::new(),
default_config_options: HashMap::from_iter([(
"mode".to_string(),
"manual".to_string(),
)]),
favorite_config_option_values: HashMap::default(),
}
.into(),
)])),
cx,
);
});
cx.run_until_parked();
assert_eq!(
harness.connection.defaults.mode(),
Some(acp::SessionModeId::new("manual"))
);
assert_eq!(
harness.connection.defaults.model(),
Some(acp::ModelId::new("claude-sonnet-4"))
);
assert_eq!(
harness.connection.defaults.config_option("mode").as_deref(),
Some("manual")
);
cx.update(|cx| {
AllAgentServersSettings::override_global(
AllAgentServersSettings(HashMap::default()),
cx,
);
});
cx.run_until_parked();
assert_eq!(harness.connection.defaults.mode(), None);
assert_eq!(harness.connection.defaults.model(), None);
assert_eq!(harness.connection.defaults.config_option("mode"), None);
}
#[gpui::test]
async fn session_list_delete_sends_session_delete_when_supported(
cx: &mut gpui::TestAppContext,

View file

@ -19,7 +19,7 @@ use ui::{
};
use util::ResultExt as _;
use crate::ui::{HoldForDefault, documentation_aside_side};
use crate::ui::documentation_aside_side;
const PICKER_THRESHOLD: usize = 5;
@ -101,6 +101,13 @@ impl ConfigOptionsView {
return false;
};
self.agent_server.set_default_config_option(
config_id.0.as_ref(),
Some(next_value.0.as_ref()),
self.fs.clone(),
cx,
);
let task = self
.config_options
.set_config_option(config_id, next_value, cx);
@ -412,7 +419,7 @@ struct ConfigOptionPickerDelegate {
filtered_entries: Vec<ConfigOptionPickerEntry>,
all_options: Vec<ConfigOptionValue>,
selected_index: usize,
selected_description: Option<(usize, SharedString, bool)>,
selected_description: Option<(usize, SharedString)>,
favorites: HashSet<acp::SessionConfigValueId>,
_settings_subscription: Subscription,
}
@ -544,28 +551,16 @@ impl PickerDelegate for ConfigOptionPickerDelegate {
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
if let Some(ConfigOptionPickerEntry::Option(option)) =
self.filtered_entries.get(self.selected_index)
{
if window.modifiers().secondary() {
let default_value = self
.agent_server
.default_config_option(self.config_id.0.as_ref(), cx);
let is_default = default_value.as_deref() == Some(&*option.value.0);
self.agent_server.set_default_config_option(
self.config_id.0.as_ref(),
if is_default {
None
} else {
Some(option.value.0.as_ref())
},
self.fs.clone(),
cx,
);
}
self.agent_server.set_default_config_option(
self.config_id.0.as_ref(),
Some(option.value.0.as_ref()),
self.fs.clone(),
cx,
);
let task = self.config_options.set_config_option(
self.config_id.clone(),
option.value.clone(),
@ -614,11 +609,6 @@ impl PickerDelegate for ConfigOptionPickerDelegate {
let current_value = self.current_value();
let is_selected = current_value.as_ref() == Some(&option.value);
let default_value = self
.agent_server
.default_config_option(self.config_id.0.as_ref(), cx);
let is_default = default_value.as_deref() == Some(&*option.value.0);
let is_favorite = self.favorites.contains(&option.value);
let option_name = option.name.clone();
@ -631,9 +621,8 @@ impl PickerDelegate for ConfigOptionPickerDelegate {
let desc: SharedString = desc.into();
this.on_hover(cx.listener(move |menu, hovered, _, cx| {
if *hovered {
menu.delegate.selected_description =
Some((ix, desc.clone(), is_default));
} else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix)
menu.delegate.selected_description = Some((ix, desc.clone()));
} else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix)
{
menu.delegate.selected_description = None;
}
@ -688,29 +677,20 @@ impl PickerDelegate for ConfigOptionPickerDelegate {
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<ui::DocumentationAside> {
self.selected_description
.as_ref()
.map(|(_, description, is_default)| {
let description = description.clone();
let is_default = *is_default;
self.selected_description.as_ref().map(|(_, description)| {
let description = description.clone();
let side = documentation_aside_side(cx);
let side = documentation_aside_side(cx);
ui::DocumentationAside::new(
side,
Rc::new(move |_| {
v_flex()
.gap_1()
.child(Label::new(description.clone()))
.child(HoldForDefault::new(is_default))
.into_any_element()
}),
)
})
ui::DocumentationAside::new(
side,
Rc::new(move |_| Label::new(description.clone()).into_any_element()),
)
})
}
fn documentation_aside_index(&self) -> Option<usize> {
self.selected_description.as_ref().map(|(ix, _, _)| *ix)
self.selected_description.as_ref().map(|(ix, _)| *ix)
}
}
@ -878,3 +858,143 @@ fn count_config_options(option: &acp::SessionConfigOption) -> usize {
_ => 0,
}
}
#[cfg(test)]
mod tests {
use super::*;
use acp_thread::AgentConnection;
use fs::FakeFs;
use gpui::TestAppContext;
use parking_lot::Mutex;
use project::{AgentId, Project};
use std::{any::Any, cell::RefCell};
#[gpui::test]
fn cycling_config_option_saves_selected_value_as_default(cx: &mut TestAppContext) {
let agent_server = Rc::new(TestAgentServer::default());
let config_options = Rc::new(TestSessionConfigOptions::new(vec![
acp::SessionConfigOption::select(
"mode",
"Mode",
"auto",
vec![
acp::SessionConfigSelectOption::new("auto", "Auto"),
acp::SessionConfigSelectOption::new("manual", "Manual"),
],
)
.category(acp::SessionConfigOptionCategory::Mode),
]));
let fs: Arc<dyn Fs> = FakeFs::new(cx.executor());
cx.update(|cx| {
let config_options: Rc<dyn AgentSessionConfigOptions> = config_options.clone();
let agent_server: Rc<dyn AgentServer> = agent_server.clone();
let fs = fs.clone();
let view = cx.new(|_| ConfigOptionsView {
config_option_ids: ConfigOptionsView::config_option_ids(&config_options),
config_options,
selectors: Vec::new(),
agent_server,
fs,
_refresh_task: Task::ready(()),
});
assert!(view.update(cx, |view, cx| {
view.cycle_category_option(acp::SessionConfigOptionCategory::Mode, false, cx)
}));
});
assert_eq!(
agent_server.saved_defaults.lock().as_slice(),
&[("mode".to_string(), Some("manual".to_string()))]
);
assert_eq!(
config_options.set_values.borrow().as_slice(),
&[("mode".to_string(), "manual".to_string())]
);
}
#[derive(Default)]
struct TestAgentServer {
saved_defaults: Arc<Mutex<Vec<(String, Option<String>)>>>,
}
impl AgentServer for TestAgentServer {
fn logo(&self) -> IconName {
IconName::ZedAssistant
}
fn agent_id(&self) -> AgentId {
AgentId::new("test-agent")
}
fn connect(
&self,
_delegate: agent_servers::AgentServerDelegate,
_project: Entity<Project>,
_cx: &mut App,
) -> Task<anyhow::Result<Rc<dyn AgentConnection>>> {
Task::ready(Err(anyhow::anyhow!("test agent server cannot connect")))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
fn set_default_config_option(
&self,
config_id: &str,
value_id: Option<&str>,
_fs: Arc<dyn Fs>,
_cx: &mut App,
) {
self.saved_defaults.lock().push((
config_id.to_string(),
value_id.map(|value| value.to_string()),
));
}
}
struct TestSessionConfigOptions {
options: RefCell<Vec<acp::SessionConfigOption>>,
set_values: RefCell<Vec<(String, String)>>,
}
impl TestSessionConfigOptions {
fn new(options: Vec<acp::SessionConfigOption>) -> Self {
Self {
options: RefCell::new(options),
set_values: RefCell::new(Vec::new()),
}
}
}
impl AgentSessionConfigOptions for TestSessionConfigOptions {
fn config_options(&self) -> Vec<acp::SessionConfigOption> {
self.options.borrow().clone()
}
fn set_config_option(
&self,
config_id: acp::SessionConfigId,
value: acp::SessionConfigValueId,
_cx: &mut App,
) -> Task<anyhow::Result<Vec<acp::SessionConfigOption>>> {
self.set_values
.borrow_mut()
.push((config_id.0.to_string(), value.0.to_string()));
let options = {
let mut options = self.options.borrow_mut();
if let Some(option) = options.iter_mut().find(|option| option.id == config_id)
&& let acp::SessionConfigKind::Select(select) = &mut option.kind
{
select.current_value = value;
}
options.clone()
};
Task::ready(Ok(options))
}
}
}

View file

@ -11,10 +11,7 @@ use ui::{
prelude::*,
};
use crate::{
CycleModeSelector, ToggleProfileSelector,
ui::{HoldForDefault, documentation_aside_side},
};
use crate::{CycleModeSelector, ToggleProfileSelector, ui::documentation_aside_side};
pub struct ModeSelector {
connection: Rc<dyn AgentSessionModes>,
@ -45,6 +42,10 @@ impl ModeSelector {
pub fn cycle_mode(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
let all_modes = self.connection.all_modes();
if all_modes.is_empty() {
return;
}
let current_mode = self.connection.current_mode();
let current_index = all_modes
@ -52,8 +53,9 @@ impl ModeSelector {
.position(|mode| mode.id.0 == current_mode.0)
.unwrap_or(0);
let next_index = (current_index + 1) % all_modes.len();
self.set_mode(all_modes[next_index].id.clone(), cx);
if let Some(next_mode) = all_modes.get((current_index + 1) % all_modes.len()) {
self.set_mode(next_mode.id.clone(), cx);
}
}
pub fn mode(&self) -> acp::SessionModeId {
@ -61,6 +63,9 @@ impl ModeSelector {
}
pub fn set_mode(&mut self, mode: acp::SessionModeId, cx: &mut Context<Self>) {
self.agent_server
.set_default_mode(Some(mode.clone()), self.fs.clone(), cx);
let task = self.connection.set_mode(mode, cx);
self.setting_mode = true;
cx.notify();
@ -88,13 +93,11 @@ impl ModeSelector {
ContextMenu::build(window, cx, move |mut menu, _window, cx| {
let all_modes = self.connection.all_modes();
let current_mode = self.connection.current_mode();
let default_mode = self.agent_server.default_mode(cx);
let side = documentation_aside_side(cx);
for mode in all_modes {
let is_selected = &mode.id == &current_mode;
let is_default = Some(&mode.id) == default_mode.as_ref();
let entry = ContextMenuEntry::new(mode.name.clone())
.toggleable(IconPosition::End, is_selected);
@ -102,13 +105,7 @@ impl ModeSelector {
entry.documentation_aside(side, {
let description = description.clone();
move |_| {
v_flex()
.gap_1()
.child(Label::new(description.clone()))
.child(HoldForDefault::new(is_default))
.into_any_element()
}
move |_| Label::new(description.clone()).into_any_element()
})
} else {
entry
@ -117,21 +114,9 @@ impl ModeSelector {
menu.push_item(entry.handler({
let mode_id = mode.id.clone();
let weak_self = weak_self.clone();
move |window, cx| {
move |_window, cx| {
weak_self
.update(cx, |this, cx| {
if window.modifiers().secondary() {
this.agent_server.set_default_mode(
if is_default {
None
} else {
Some(mode_id.clone())
},
this.fs.clone(),
cx,
);
}
this.set_mode(mode_id.clone(), cx);
})
.ok();
@ -209,3 +194,110 @@ impl Render for ModeSelector {
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use acp_thread::AgentConnection;
use fs::FakeFs;
use gpui::{App, Task, TestAppContext};
use parking_lot::Mutex;
use project::{AgentId, Project};
use std::{any::Any, cell::RefCell};
#[gpui::test]
fn setting_mode_saves_selected_mode_as_default(cx: &mut TestAppContext) {
let agent_server = Rc::new(TestAgentServer::default());
let session_modes = Rc::new(TestSessionModes::new());
let fs: Arc<dyn Fs> = FakeFs::new(cx.executor());
cx.update(|cx| {
let session_modes: Rc<dyn AgentSessionModes> = session_modes.clone();
let agent_server: Rc<dyn AgentServer> = agent_server.clone();
let selector = cx.new(|_| ModeSelector::new(session_modes, agent_server, fs));
selector.update(cx, |selector, cx| {
selector.set_mode(acp::SessionModeId::new("manual"), cx);
});
});
assert_eq!(
agent_server.saved_defaults.lock().as_slice(),
&[Some(acp::SessionModeId::new("manual"))]
);
assert_eq!(
session_modes.set_modes.borrow().as_slice(),
&[acp::SessionModeId::new("manual")]
);
}
#[derive(Default)]
struct TestAgentServer {
saved_defaults: Arc<Mutex<Vec<Option<acp::SessionModeId>>>>,
}
impl AgentServer for TestAgentServer {
fn logo(&self) -> IconName {
IconName::ZedAssistant
}
fn agent_id(&self) -> AgentId {
AgentId::new("test-agent")
}
fn connect(
&self,
_delegate: agent_servers::AgentServerDelegate,
_project: Entity<Project>,
_cx: &mut App,
) -> Task<anyhow::Result<Rc<dyn AgentConnection>>> {
Task::ready(Err(anyhow::anyhow!("test agent server cannot connect")))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
fn set_default_mode(
&self,
mode_id: Option<acp::SessionModeId>,
_fs: Arc<dyn Fs>,
_cx: &mut App,
) {
self.saved_defaults.lock().push(mode_id);
}
}
struct TestSessionModes {
current_mode: RefCell<acp::SessionModeId>,
set_modes: RefCell<Vec<acp::SessionModeId>>,
}
impl TestSessionModes {
fn new() -> Self {
Self {
current_mode: RefCell::new(acp::SessionModeId::new("auto")),
set_modes: RefCell::new(Vec::new()),
}
}
}
impl AgentSessionModes for TestSessionModes {
fn current_mode(&self) -> acp::SessionModeId {
self.current_mode.borrow().clone()
}
fn all_modes(&self) -> Vec<acp::SessionMode> {
vec![
acp::SessionMode::new("auto", "Auto"),
acp::SessionMode::new("manual", "Manual"),
]
}
fn set_mode(&self, mode: acp::SessionModeId, _cx: &mut App) -> Task<anyhow::Result<()>> {
*self.current_mode.borrow_mut() = mode.clone();
self.set_modes.borrow_mut().push(mode);
Task::ready(Ok(()))
}
}
}

View file

@ -22,8 +22,7 @@ use util::ResultExt;
use zed_actions::agent::OpenSettings;
use crate::ui::{
HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem,
documentation_aside_side,
ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem, documentation_aside_side,
};
pub type ModelSelector = Picker<ModelPickerDelegate>;
@ -55,7 +54,7 @@ pub struct ModelPickerDelegate {
filtered_entries: Vec<ModelPickerEntry>,
models: Option<AgentModelList>,
selected_index: usize,
selected_description: Option<(usize, SharedString, bool)>,
selected_description: Option<(usize, SharedString)>,
selected_model: Option<AgentModelInfo>,
favorites: HashSet<acp::ModelId>,
_refresh_models_task: Task<()>,
@ -182,6 +181,9 @@ impl ModelPickerDelegate {
let next_model = favorite_models[next_index].clone();
self.agent_server
.set_default_model(Some(next_model.id.clone()), self.fs.clone(), cx);
self.selector
.select_model(next_model.id.clone(), cx)
.detach_and_log_err(cx);
@ -277,20 +279,8 @@ impl PickerDelegate for ModelPickerDelegate {
if let Some(ModelPickerEntry::Model(model_info, _)) =
self.filtered_entries.get(self.selected_index)
{
if window.modifiers().secondary() {
let default_model = self.agent_server.default_model(cx);
let is_default = default_model.as_ref() == Some(&model_info.id);
self.agent_server.set_default_model(
if is_default {
None
} else {
Some(model_info.id.clone())
},
self.fs.clone(),
cx,
);
}
self.agent_server
.set_default_model(Some(model_info.id.clone()), self.fs.clone(), cx);
self.selector
.select_model(model_info.id.clone(), cx)
@ -322,8 +312,6 @@ impl PickerDelegate for ModelPickerDelegate {
}
ModelPickerEntry::Model(model_info, is_favorite) => {
let is_selected = Some(model_info) == self.selected_model.as_ref();
let default_model = self.agent_server.default_model(cx);
let is_default = default_model.as_ref() == Some(&model_info.id);
let is_favorite = *is_favorite;
let handle_action_click = {
@ -350,8 +338,8 @@ impl PickerDelegate for ModelPickerDelegate {
this.on_hover(cx.listener(move |menu, hovered, _, cx| {
if *hovered {
menu.delegate.selected_description =
Some((ix, description.clone(), is_default));
} else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
Some((ix, description.clone()));
} else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) {
menu.delegate.selected_description = None;
}
cx.notify();
@ -382,29 +370,20 @@ impl PickerDelegate for ModelPickerDelegate {
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<ui::DocumentationAside> {
self.selected_description
.as_ref()
.map(|(_, description, is_default)| {
let description = description.clone();
let is_default = *is_default;
self.selected_description.as_ref().map(|(_, description)| {
let description = description.clone();
let side = documentation_aside_side(cx);
let side = documentation_aside_side(cx);
DocumentationAside::new(
side,
Rc::new(move |_| {
v_flex()
.gap_1()
.child(Label::new(description.clone()))
.child(HoldForDefault::new(is_default))
.into_any_element()
}),
)
})
DocumentationAside::new(
side,
Rc::new(move |_| Label::new(description.clone()).into_any_element()),
)
})
}
fn documentation_aside_index(&self) -> Option<usize> {
self.selected_description.as_ref().map(|(ix, _, _)| *ix)
self.selected_description.as_ref().map(|(ix, _)| *ix)
}
fn render_footer(
@ -530,7 +509,12 @@ async fn fuzzy_search(
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
use acp_thread::AgentConnection;
use fs::FakeFs;
use gpui::{App, Entity, TestAppContext, VisualTestContext};
use parking_lot::Mutex;
use project::{AgentId, Project};
use std::{any::Any, cell::RefCell};
use super::*;
@ -608,6 +592,138 @@ mod tests {
.collect()
}
#[gpui::test]
fn confirming_model_saves_selected_model_as_default(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = settings::SettingsStore::test(cx);
cx.set_global(settings_store);
theme_settings::init(theme::LoadThemes::JustBase, cx);
editor::init(cx);
});
let agent_server = Rc::new(TestAgentServer::default());
let model_selector = Rc::new(TestModelSelector::new());
let fs: Arc<dyn Fs> = FakeFs::new(cx.executor());
let window_handle = cx.add_window({
let agent_server = agent_server.clone();
let model_selector = model_selector.clone();
move |window, cx| {
let selector: Rc<dyn AgentModelSelector> = model_selector.clone();
let agent_server: Rc<dyn AgentServer> = agent_server.clone();
acp_model_selector(selector, agent_server, fs, cx.focus_handle(), window, cx)
}
});
cx.run_until_parked();
let mut cx = VisualTestContext::from_window(window_handle.into(), cx);
window_handle
.update(&mut cx, |picker, window, cx| {
picker.delegate.set_selected_index(1, window, cx);
picker.delegate.confirm(false, window, cx);
})
.unwrap();
assert_eq!(
agent_server.saved_defaults.lock().as_slice(),
&[Some(acp::ModelId::new("manual"))]
);
assert_eq!(
model_selector.selected_models.borrow().as_slice(),
&[acp::ModelId::new("manual")]
);
}
#[derive(Default)]
struct TestAgentServer {
saved_defaults: Arc<Mutex<Vec<Option<acp::ModelId>>>>,
}
impl AgentServer for TestAgentServer {
fn logo(&self) -> IconName {
IconName::ZedAssistant
}
fn agent_id(&self) -> AgentId {
AgentId::new("test-agent")
}
fn connect(
&self,
_delegate: agent_servers::AgentServerDelegate,
_project: Entity<Project>,
_cx: &mut App,
) -> Task<anyhow::Result<Rc<dyn AgentConnection>>> {
Task::ready(Err(anyhow::anyhow!("test agent server cannot connect")))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
fn set_default_model(
&self,
model_id: Option<acp::ModelId>,
_fs: Arc<dyn Fs>,
_cx: &mut App,
) {
self.saved_defaults.lock().push(model_id);
}
}
struct TestModelSelector {
models: Vec<AgentModelInfo>,
selected_model: RefCell<AgentModelInfo>,
selected_models: RefCell<Vec<acp::ModelId>>,
}
impl TestModelSelector {
fn new() -> Self {
let models = vec![
AgentModelInfo {
id: acp::ModelId::new("auto"),
name: "Auto".into(),
description: None,
icon: None,
is_latest: false,
cost: None,
},
AgentModelInfo {
id: acp::ModelId::new("manual"),
name: "Manual".into(),
description: None,
icon: None,
is_latest: false,
cost: None,
},
];
Self {
selected_model: RefCell::new(models[0].clone()),
models,
selected_models: RefCell::new(Vec::new()),
}
}
}
impl AgentModelSelector for TestModelSelector {
fn list_models(&self, _cx: &mut App) -> Task<Result<AgentModelList>> {
Task::ready(Ok(AgentModelList::Flat(self.models.clone())))
}
fn select_model(&self, model_id: acp::ModelId, _cx: &mut App) -> Task<Result<()>> {
self.selected_models.borrow_mut().push(model_id.clone());
if let Some(model) = self.models.iter().find(|model| model.id == model_id) {
*self.selected_model.borrow_mut() = model.clone();
}
Task::ready(Ok(()))
}
fn selected_model(&self, _cx: &mut App) -> Task<Result<AgentModelInfo>> {
Task::ready(Ok(self.selected_model.borrow().clone()))
}
}
fn get_entry_labels(entries: &[ModelPickerEntry]) -> Vec<&str> {
entries
.iter()

View file

@ -1,13 +1,11 @@
mod agent_notification;
mod end_trial_upsell;
mod hold_for_default;
mod mention_crease;
mod model_selector_components;
mod undo_reject_toast;
pub use agent_notification::*;
pub use end_trial_upsell::*;
pub use hold_for_default::*;
pub use mention_crease::*;
pub use model_selector_components::*;
pub use undo_reject_toast::*;

View file

@ -1,52 +0,0 @@
use gpui::{App, IntoElement, Modifiers, RenderOnce, Window};
use ui::{prelude::*, render_modifiers};
#[derive(IntoElement)]
pub struct HoldForDefault {
is_default: bool,
more_content: bool,
}
impl HoldForDefault {
pub fn new(is_default: bool) -> Self {
Self {
is_default,
more_content: true,
}
}
#[allow(dead_code)]
pub fn more_content(mut self, more_content: bool) -> Self {
self.more_content = more_content;
self
}
}
impl RenderOnce for HoldForDefault {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
h_flex()
.when(self.more_content, |this| {
this.pt_1()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
})
.gap_0p5()
.text_sm()
.text_color(Color::Muted.color(cx))
.child("Hold")
.child(h_flex().flex_shrink_0().children(render_modifiers(
&Modifiers::secondary_key(),
PlatformStyle::platform(),
None,
Some(TextSize::Default.rems(cx).into()),
false,
)))
.child(div().map(|this| {
if self.is_default {
this.child("to unset as default")
} else {
this.child("to set as default")
}
}))
}
}