From 0c91f061c360468da86bd1cb88768ceae6f71308 Mon Sep 17 00:00:00 2001 From: "Oleksii (Alexey) Orlenko" Date: Tue, 16 Dec 2025 20:22:30 +0100 Subject: [PATCH] agent_ui: Implement favorite models selection (#44297) This PR solves my main pain point with Zed agent: I have a long list of available models from different providers, and I switch between a few of them depending on the context and the project. In particular, I use the same models from different providers depending on whether I'm working on a personal project or at my day job. Since I only care about a few models (none of which are in "recommended") that are scattered all over the list, switching between them is bothersome, even using search. This change adds a new option in `settings.json` (`agent.favorite_models`) and the UI to manipulate it directly from the list of available models. When any models are marked as favorites, they appear in a dedicated section at the very top of the list. Each model has a small icon button that appears on hover and allows to toggle whether it's marked as favorite. I implemented this on the UI level (i.e. there's no first-party knowledge about favorite models in the agent itself; in theory it could return favorite models as a group but it would make it harder to implement bespoke UI for the favorite models section and it also wouldn't work for text threads which don't use the ACP infrastructure). The feature is only enabled for the native agent but disabled for external agents because we can't easily map their model IDs to settings and there could be weird collisions between them. https://github.com/user-attachments/assets/cf23afe4-3883-45cb-9906-f55de3ea2a97 Closes https://github.com/zed-industries/zed/issues/31507 Release Notes: - Added the ability to mark language models as favorites and pin them to the top of the list. This feature is available in the native Zed agent (including text threads and the inline assistant), but not in external agents via ACP. --------- Co-authored-by: Danilo Leal Co-authored-by: Bennet Bo Fenner --- Cargo.lock | 1 + crates/acp_thread/src/connection.rs | 10 + crates/agent/src/agent.rs | 4 + crates/agent_settings/Cargo.toml | 1 + crates/agent_settings/src/agent_settings.rs | 12 +- crates/agent_ui/src/acp/model_selector.rs | 290 ++++++++++++++++-- .../manage_profiles_modal.rs | 43 ++- crates/agent_ui/src/agent_model_selector.rs | 37 ++- crates/agent_ui/src/agent_ui.rs | 2 + crates/agent_ui/src/favorite_models.rs | 57 ++++ .../agent_ui/src/language_model_selector.rs | 216 +++++++++++-- crates/agent_ui/src/text_thread_editor.rs | 36 ++- .../src/ui/model_selector_components.rs | 35 ++- crates/settings/src/settings_content/agent.rs | 13 + 14 files changed, 656 insertions(+), 101 deletions(-) create mode 100644 crates/agent_ui/src/favorite_models.rs diff --git a/Cargo.lock b/Cargo.lock index 2d0cb8235d5..6908a8ed518 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -301,6 +301,7 @@ dependencies = [ name = "agent_settings" version = "0.1.0" dependencies = [ + "agent-client-protocol", "anyhow", "cloud_llm_client", "collections", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 3c8c56b2c02..a670ba60115 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -202,6 +202,12 @@ pub trait AgentModelSelector: 'static { fn should_render_footer(&self) -> bool { false } + + /// Whether this selector supports the favorites feature. + /// Only the native agent uses the model ID format that maps to settings. + fn supports_favorites(&self) -> bool { + false + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -239,6 +245,10 @@ impl AgentModelList { AgentModelList::Grouped(groups) => groups.is_empty(), } } + + pub fn is_flat(&self) -> bool { + matches!(self, AgentModelList::Flat(_)) + } } #[cfg(feature = "test-support")] diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 693d3abd449..5e16f74682e 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -1164,6 +1164,10 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector { fn should_render_footer(&self) -> bool { true } + + fn supports_favorites(&self) -> bool { + true + } } impl acp_thread::AgentConnection for NativeAgentConnection { diff --git a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index 8ddcac24fe0..0d7163549f0 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -12,6 +12,7 @@ workspace = true path = "src/agent_settings.rs" [dependencies] +agent-client-protocol.workspace = true anyhow.workspace = true cloud_llm_client.workspace = true collections.workspace = true diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 25ca5c78d6b..b513ec1a70b 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -2,7 +2,8 @@ mod agent_profile; use std::sync::Arc; -use collections::IndexMap; +use agent_client_protocol::ModelId; +use collections::{HashSet, IndexMap}; use gpui::{App, Pixels, px}; use language_model::LanguageModel; use project::DisableAiSettings; @@ -33,6 +34,7 @@ pub struct AgentSettings { pub commit_message_model: Option, pub thread_summary_model: Option, pub inline_alternatives: Vec, + pub favorite_models: Vec, pub default_profile: AgentProfileId, pub default_view: DefaultAgentView, pub profiles: IndexMap, @@ -96,6 +98,13 @@ impl AgentSettings { pub fn set_message_editor_max_lines(&self) -> usize { self.message_editor_min_lines * 2 } + + pub fn favorite_model_ids(&self) -> HashSet { + self.favorite_models + .iter() + .map(|sel| ModelId::new(format!("{}/{}", sel.provider.0, sel.model))) + .collect() + } } #[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)] @@ -164,6 +173,7 @@ impl Settings for AgentSettings { commit_message_model: agent.commit_message_model, thread_summary_model: agent.thread_summary_model, inline_alternatives: agent.inline_alternatives.unwrap_or_default(), + favorite_models: agent.favorite_models, default_profile: AgentProfileId(agent.default_profile.unwrap()), default_view: agent.default_view.unwrap(), profiles: agent diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index 658b88e0c2a..f885ff12e59 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -1,18 +1,22 @@ use std::{cmp::Reverse, rc::Rc, sync::Arc}; use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use agent_client_protocol::ModelId; use agent_servers::AgentServer; +use agent_settings::AgentSettings; use anyhow::Result; -use collections::IndexMap; +use collections::{HashSet, IndexMap}; use fs::Fs; use futures::FutureExt; use fuzzy::{StringMatchCandidate, match_strings}; use gpui::{ Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity, }; +use itertools::Itertools; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; -use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, prelude::*}; +use settings::Settings; +use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*}; use util::ResultExt; use zed_actions::agent::OpenSettings; @@ -38,7 +42,7 @@ pub fn acp_model_selector( enum AcpModelPickerEntry { Separator(SharedString), - Model(AgentModelInfo), + Model(AgentModelInfo, bool), } pub struct AcpModelPickerDelegate { @@ -140,7 +144,7 @@ impl PickerDelegate for AcpModelPickerDelegate { _cx: &mut Context>, ) -> bool { match self.filtered_entries.get(ix) { - Some(AcpModelPickerEntry::Model(_)) => true, + Some(AcpModelPickerEntry::Model(_, _)) => true, Some(AcpModelPickerEntry::Separator(_)) | None => false, } } @@ -155,6 +159,12 @@ impl PickerDelegate for AcpModelPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { + let favorites = if self.selector.supports_favorites() { + Arc::new(AgentSettings::get_global(cx).favorite_model_ids()) + } else { + Default::default() + }; + cx.spawn_in(window, async move |this, cx| { let filtered_models = match this .read_with(cx, |this, cx| { @@ -171,7 +181,7 @@ impl PickerDelegate for AcpModelPickerDelegate { this.update_in(cx, |this, window, cx| { this.delegate.filtered_entries = - info_list_to_picker_entries(filtered_models).collect(); + info_list_to_picker_entries(filtered_models, favorites); // Finds the currently selected model in the list let new_index = this .delegate @@ -179,7 +189,7 @@ impl PickerDelegate for AcpModelPickerDelegate { .as_ref() .and_then(|selected| { this.delegate.filtered_entries.iter().position(|entry| { - if let AcpModelPickerEntry::Model(model_info) = entry { + if let AcpModelPickerEntry::Model(model_info, _) = entry { model_info.id == selected.id } else { false @@ -195,7 +205,7 @@ impl PickerDelegate for AcpModelPickerDelegate { } fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { - if let Some(AcpModelPickerEntry::Model(model_info)) = + if let Some(AcpModelPickerEntry::Model(model_info, _)) = self.filtered_entries.get(self.selected_index) { if window.modifiers().secondary() { @@ -233,7 +243,7 @@ impl PickerDelegate for AcpModelPickerDelegate { fn render_match( &self, ix: usize, - is_focused: bool, + selected: bool, _: &mut Window, cx: &mut Context>, ) -> Option { @@ -241,32 +251,53 @@ impl PickerDelegate for AcpModelPickerDelegate { AcpModelPickerEntry::Separator(title) => { Some(ModelSelectorHeader::new(title, ix > 1).into_any_element()) } - AcpModelPickerEntry::Model(model_info) => { + AcpModelPickerEntry::Model(model_info, is_favorite) => { let is_selected = Some(model_info) == self.selected_model.as_ref(); let default_model = self.agent_server.default_model(cx); let is_default = default_model.as_ref() == Some(&model_info.id); + let supports_favorites = self.selector.supports_favorites(); + + let is_favorite = *is_favorite; + let handle_action_click = { + let model_id = model_info.id.clone(); + let fs = self.fs.clone(); + + move |cx: &App| { + crate::favorite_models::toggle_model_id_in_settings( + model_id.clone(), + !is_favorite, + fs.clone(), + cx, + ); + } + }; + Some( div() .id(("model-picker-menu-child", ix)) .when_some(model_info.description.clone(), |this, description| { - this - .on_hover(cx.listener(move |menu, hovered, _, cx| { - if *hovered { - menu.delegate.selected_description = Some((ix, description.clone(), is_default)); - } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) { - menu.delegate.selected_description = None; - } - cx.notify(); - })) + this.on_hover(cx.listener(move |menu, hovered, _, cx| { + if *hovered { + menu.delegate.selected_description = + Some((ix, description.clone(), is_default)); + } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) { + menu.delegate.selected_description = None; + } + cx.notify(); + })) }) .child( ModelSelectorListItem::new(ix, model_info.name.clone()) - .is_focused(is_focused) + .when_some(model_info.icon, |this, icon| this.icon(icon)) .is_selected(is_selected) - .when_some(model_info.icon, |this, icon| this.icon(icon)), + .is_focused(selected) + .when(supports_favorites, |this| { + this.is_favorite(is_favorite) + .on_toggle_favorite(handle_action_click) + }), ) - .into_any_element() + .into_any_element(), ) } } @@ -314,18 +345,51 @@ impl PickerDelegate for AcpModelPickerDelegate { fn info_list_to_picker_entries( model_list: AgentModelList, -) -> impl Iterator { - match model_list { - AgentModelList::Flat(list) => { - itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model)) - } - AgentModelList::Grouped(index_map) => { - itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| { - std::iter::once(AcpModelPickerEntry::Separator(group_name.0)) - .chain(models.into_iter().map(AcpModelPickerEntry::Model)) - })) + favorites: Arc>, +) -> Vec { + let mut entries = Vec::new(); + + let all_models: Vec<_> = match &model_list { + AgentModelList::Flat(list) => list.iter().collect(), + AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(), + }; + + let favorite_models: Vec<_> = all_models + .iter() + .filter(|m| favorites.contains(&m.id)) + .unique_by(|m| &m.id) + .collect(); + + let has_favorites = !favorite_models.is_empty(); + if has_favorites { + entries.push(AcpModelPickerEntry::Separator("Favorite".into())); + for model in favorite_models { + entries.push(AcpModelPickerEntry::Model((*model).clone(), true)); } } + + match model_list { + AgentModelList::Flat(list) => { + if has_favorites { + entries.push(AcpModelPickerEntry::Separator("All".into())); + } + for model in list { + let is_favorite = favorites.contains(&model.id); + entries.push(AcpModelPickerEntry::Model(model, is_favorite)); + } + } + AgentModelList::Grouped(index_map) => { + for (group_name, models) in index_map { + entries.push(AcpModelPickerEntry::Separator(group_name.0)); + for model in models { + let is_favorite = favorites.contains(&model.id); + entries.push(AcpModelPickerEntry::Model(model, is_favorite)); + } + } + } + } + + entries } async fn fuzzy_search( @@ -447,6 +511,170 @@ mod tests { } } + fn create_favorites(models: Vec<&str>) -> Arc> { + Arc::new( + models + .into_iter() + .map(|m| ModelId::new(m.to_string())) + .collect(), + ) + } + + fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> { + entries + .iter() + .filter_map(|entry| match entry { + AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()), + _ => None, + }) + .collect() + } + + fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> { + entries + .iter() + .map(|entry| match entry { + AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(), + AcpModelPickerEntry::Separator(s) => &s, + }) + .collect() + } + + #[gpui::test] + fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("zed", vec!["zed/claude", "zed/gemini"]), + ("openai", vec!["openai/gpt-5"]), + ]); + let favorites = create_favorites(vec!["zed/gemini"]); + + let entries = info_list_to_picker_entries(models, favorites); + + assert!(matches!( + entries.first(), + Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite" + )); + + let model_ids = get_entry_model_ids(&entries); + assert_eq!(model_ids[0], "zed/gemini"); + } + + #[gpui::test] + fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) { + let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]); + let favorites = create_favorites(vec![]); + + let entries = info_list_to_picker_entries(models, favorites); + + assert!(matches!( + entries.first(), + Some(AcpModelPickerEntry::Separator(s)) if s == "zed" + )); + } + + #[gpui::test] + fn test_models_have_correct_actions(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("zed", vec!["zed/claude", "zed/gemini"]), + ("openai", vec!["openai/gpt-5"]), + ]); + let favorites = create_favorites(vec!["zed/claude"]); + + let entries = info_list_to_picker_entries(models, favorites); + + for entry in &entries { + if let AcpModelPickerEntry::Model(info, is_favorite) = entry { + if info.id.0.as_ref() == "zed/claude" { + assert!(is_favorite, "zed/claude should be a favorite"); + } else { + assert!(!is_favorite, "{} should not be a favorite", info.id.0); + } + } + } + } + + #[gpui::test] + fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("zed", vec!["zed/claude", "zed/gemini"]), + ("openai", vec!["openai/gpt-5", "openai/gpt-4"]), + ]); + let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]); + + let entries = info_list_to_picker_entries(models, favorites); + let model_ids = get_entry_model_ids(&entries); + + assert_eq!(model_ids[0], "zed/gemini"); + assert_eq!(model_ids[1], "openai/gpt-5"); + + assert!(model_ids[2..].contains(&"zed/gemini")); + assert!(model_ids[2..].contains(&"openai/gpt-5")); + } + + #[gpui::test] + fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("Recommended", vec!["zed/claude", "anthropic/claude"]), + ("Zed", vec!["zed/claude", "zed/gpt-5"]), + ("Antropic", vec!["anthropic/claude"]), + ("OpenAI", vec!["openai/gpt-5"]), + ]); + + let favorites = create_favorites(vec!["zed/claude"]); + + let entries = info_list_to_picker_entries(models, favorites); + let labels = get_entry_labels(&entries); + + assert_eq!( + labels, + vec![ + "Favorite", + "zed/claude", + "Recommended", + "zed/claude", + "anthropic/claude", + "Zed", + "zed/claude", + "zed/gpt-5", + "Antropic", + "anthropic/claude", + "OpenAI", + "openai/gpt-5" + ] + ); + } + + #[gpui::test] + fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) { + let models = AgentModelList::Flat(vec![ + acp_thread::AgentModelInfo { + id: acp::ModelId::new("zed/claude".to_string()), + name: "Claude".into(), + description: None, + icon: None, + }, + acp_thread::AgentModelInfo { + id: acp::ModelId::new("zed/gemini".to_string()), + name: "Gemini".into(), + description: None, + icon: None, + }, + ]); + let favorites = create_favorites(vec!["zed/gemini"]); + + let entries = info_list_to_picker_entries(models, favorites); + + assert!(matches!( + entries.first(), + Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite" + )); + + assert!(entries.iter().any(|e| matches!( + e, + AcpModelPickerEntry::Separator(s) if s == "All" + ))); + } + #[gpui::test] async fn test_fuzzy_match(cx: &mut TestAppContext) { let models = create_model_list(vec![ diff --git a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index ed00b2b5c71..127852fd50e 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -222,7 +222,6 @@ impl ManageProfilesModal { let profile_id_for_closure = profile_id.clone(); let model_picker = cx.new(|cx| { - let fs = fs.clone(); let profile_id = profile_id_for_closure.clone(); language_model_selector( @@ -250,22 +249,36 @@ impl ManageProfilesModal { }) } }, - move |model, cx| { - let provider = model.provider_id().0.to_string(); - let model_id = model.id().0.to_string(); - let profile_id = profile_id.clone(); + { + let fs = fs.clone(); + move |model, cx| { + let provider = model.provider_id().0.to_string(); + let model_id = model.id().0.to_string(); + let profile_id = profile_id.clone(); - update_settings_file(fs.clone(), cx, move |settings, _cx| { - let agent_settings = settings.agent.get_or_insert_default(); - if let Some(profiles) = agent_settings.profiles.as_mut() { - if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) { - profile.default_model = Some(LanguageModelSelection { - provider: LanguageModelProviderSetting(provider.clone()), - model: model_id.clone(), - }); + update_settings_file(fs.clone(), cx, move |settings, _cx| { + let agent_settings = settings.agent.get_or_insert_default(); + if let Some(profiles) = agent_settings.profiles.as_mut() { + if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) { + profile.default_model = Some(LanguageModelSelection { + provider: LanguageModelProviderSetting(provider.clone()), + model: model_id.clone(), + }); + } } - } - }); + }); + } + }, + { + let fs = fs.clone(); + move |model, should_be_favorite, cx| { + crate::favorite_models::toggle_in_settings( + model, + should_be_favorite, + fs.clone(), + cx, + ); + } }, false, // Do not use popover styles for the model picker self.focus_handle.clone(), diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index 9c263414309..ac57ed575d9 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -29,26 +29,39 @@ impl AgentModelSelector { Self { selector: cx.new(move |cx| { - let fs = fs.clone(); language_model_selector( { let model_context = model_usage_context.clone(); move |cx| model_context.configured_model(cx) }, - move |model, cx| { - let provider = model.provider_id().0.to_string(); - let model_id = model.id().0.to_string(); - match &model_usage_context { - ModelUsageContext::InlineAssistant => { - update_settings_file(fs.clone(), cx, move |settings, _cx| { - settings - .agent - .get_or_insert_default() - .set_inline_assistant_model(provider.clone(), model_id); - }); + { + let fs = fs.clone(); + move |model, cx| { + let provider = model.provider_id().0.to_string(); + let model_id = model.id().0.to_string(); + match &model_usage_context { + ModelUsageContext::InlineAssistant => { + update_settings_file(fs.clone(), cx, move |settings, _cx| { + settings + .agent + .get_or_insert_default() + .set_inline_assistant_model(provider.clone(), model_id); + }); + } } } }, + { + let fs = fs.clone(); + move |model, should_be_favorite, cx| { + crate::favorite_models::toggle_in_settings( + model, + should_be_favorite, + fs.clone(), + cx, + ); + } + }, true, // Use popover styles for picker focus_handle_clone, window, diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 4f759d6a9c7..1622d17f585 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -7,6 +7,7 @@ mod buffer_codegen; mod completion_provider; mod context; mod context_server_configuration; +mod favorite_models; mod inline_assistant; mod inline_prompt_editor; mod language_model_selector; @@ -467,6 +468,7 @@ mod tests { commit_message_model: None, thread_summary_model: None, inline_alternatives: vec![], + favorite_models: vec![], default_profile: AgentProfileId::default(), default_view: DefaultAgentView::Thread, profiles: Default::default(), diff --git a/crates/agent_ui/src/favorite_models.rs b/crates/agent_ui/src/favorite_models.rs new file mode 100644 index 00000000000..d8d4db976fc --- /dev/null +++ b/crates/agent_ui/src/favorite_models.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use agent_client_protocol::ModelId; +use fs::Fs; +use language_model::LanguageModel; +use settings::{LanguageModelSelection, update_settings_file}; +use ui::App; + +fn language_model_to_selection(model: &Arc) -> LanguageModelSelection { + LanguageModelSelection { + provider: model.provider_id().to_string().into(), + model: model.id().0.to_string(), + } +} + +fn model_id_to_selection(model_id: &ModelId) -> LanguageModelSelection { + let id = model_id.0.as_ref(); + let (provider, model) = id.split_once('/').unwrap_or(("", id)); + LanguageModelSelection { + provider: provider.to_owned().into(), + model: model.to_owned(), + } +} + +pub fn toggle_in_settings( + model: Arc, + should_be_favorite: bool, + fs: Arc, + cx: &App, +) { + let selection = language_model_to_selection(&model); + update_settings_file(fs, cx, move |settings, _| { + let agent = settings.agent.get_or_insert_default(); + if should_be_favorite { + agent.add_favorite_model(selection.clone()); + } else { + agent.remove_favorite_model(&selection); + } + }); +} + +pub fn toggle_model_id_in_settings( + model_id: ModelId, + should_be_favorite: bool, + fs: Arc, + cx: &App, +) { + let selection = model_id_to_selection(&model_id); + update_settings_file(fs, cx, move |settings, _| { + let agent = settings.agent.get_or_insert_default(); + if should_be_favorite { + agent.add_favorite_model(selection.clone()); + } else { + agent.remove_favorite_model(&selection); + } + }); +} diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7e1c35eba45..7bb42fb330d 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,16 +1,18 @@ use std::{cmp::Reverse, sync::Arc}; -use collections::IndexMap; +use agent_settings::AgentSettings; +use collections::{HashMap, HashSet, IndexMap}; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; use gpui::{ Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task, }; use language_model::{ - AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, - LanguageModelRegistry, + AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider, + LanguageModelProviderId, LanguageModelRegistry, }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; +use settings::Settings; use ui::prelude::*; use zed_actions::agent::OpenSettings; @@ -18,12 +20,14 @@ use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem} type OnModelChanged = Arc, &mut App) + 'static>; type GetActiveModel = Arc Option + 'static>; +type OnToggleFavorite = Arc, bool, &App) + 'static>; pub type LanguageModelSelector = Picker; pub fn language_model_selector( get_active_model: impl Fn(&App) -> Option + 'static, on_model_changed: impl Fn(Arc, &mut App) + 'static, + on_toggle_favorite: impl Fn(Arc, bool, &App) + 'static, popover_styles: bool, focus_handle: FocusHandle, window: &mut Window, @@ -32,6 +36,7 @@ pub fn language_model_selector( let delegate = LanguageModelPickerDelegate::new( get_active_model, on_model_changed, + on_toggle_favorite, popover_styles, focus_handle, window, @@ -49,7 +54,17 @@ pub fn language_model_selector( } fn all_models(cx: &App) -> GroupedModels { - let providers = LanguageModelRegistry::global(cx).read(cx).providers(); + let lm_registry = LanguageModelRegistry::global(cx).read(cx); + let providers = lm_registry.providers(); + + let mut favorites_index = FavoritesIndex::default(); + + for sel in &AgentSettings::get_global(cx).favorite_models { + favorites_index + .entry(sel.provider.0.clone().into()) + .or_default() + .insert(sel.model.clone().into()); + } let recommended = providers .iter() @@ -57,10 +72,7 @@ fn all_models(cx: &App) -> GroupedModels { provider .recommended_models(cx) .into_iter() - .map(|model| ModelInfo { - model, - icon: provider.icon(), - }) + .map(|model| ModelInfo::new(&**provider, model, &favorites_index)) }) .collect(); @@ -70,25 +82,44 @@ fn all_models(cx: &App) -> GroupedModels { provider .provided_models(cx) .into_iter() - .map(|model| ModelInfo { - model, - icon: provider.icon(), - }) + .map(|model| ModelInfo::new(&**provider, model, &favorites_index)) }) .collect(); GroupedModels::new(all, recommended) } +type FavoritesIndex = HashMap>; + #[derive(Clone)] struct ModelInfo { model: Arc, icon: IconName, + is_favorite: bool, +} + +impl ModelInfo { + fn new( + provider: &dyn LanguageModelProvider, + model: Arc, + favorites_index: &FavoritesIndex, + ) -> Self { + let is_favorite = favorites_index + .get(&provider.id()) + .map_or(false, |set| set.contains(&model.id())); + + Self { + model, + icon: provider.icon(), + is_favorite, + } + } } pub struct LanguageModelPickerDelegate { on_model_changed: OnModelChanged, get_active_model: GetActiveModel, + on_toggle_favorite: OnToggleFavorite, all_models: Arc, filtered_entries: Vec, selected_index: usize, @@ -102,6 +133,7 @@ impl LanguageModelPickerDelegate { fn new( get_active_model: impl Fn(&App) -> Option + 'static, on_model_changed: impl Fn(Arc, &mut App) + 'static, + on_toggle_favorite: impl Fn(Arc, bool, &App) + 'static, popover_styles: bool, focus_handle: FocusHandle, window: &mut Window, @@ -117,6 +149,7 @@ impl LanguageModelPickerDelegate { selected_index: Self::get_active_model_index(&entries, get_active_model(cx)), filtered_entries: entries, get_active_model: Arc::new(get_active_model), + on_toggle_favorite: Arc::new(on_toggle_favorite), _authenticate_all_providers_task: Self::authenticate_all_providers(cx), _subscriptions: vec![cx.subscribe_in( &LanguageModelRegistry::global(cx), @@ -219,12 +252,19 @@ impl LanguageModelPickerDelegate { } struct GroupedModels { + favorites: Vec, recommended: Vec, all: IndexMap>, } impl GroupedModels { pub fn new(all: Vec, recommended: Vec) -> Self { + let favorites = all + .iter() + .filter(|info| info.is_favorite) + .cloned() + .collect(); + let mut all_by_provider: IndexMap<_, Vec> = IndexMap::default(); for model in all { let provider = model.model.provider_id(); @@ -236,6 +276,7 @@ impl GroupedModels { } Self { + favorites, recommended, all: all_by_provider, } @@ -244,13 +285,18 @@ impl GroupedModels { fn entries(&self) -> Vec { let mut entries = Vec::new(); + if !self.favorites.is_empty() { + entries.push(LanguageModelPickerEntry::Separator("Favorite".into())); + for info in &self.favorites { + entries.push(LanguageModelPickerEntry::Model(info.clone())); + } + } + if !self.recommended.is_empty() { entries.push(LanguageModelPickerEntry::Separator("Recommended".into())); - entries.extend( - self.recommended - .iter() - .map(|info| LanguageModelPickerEntry::Model(info.clone())), - ); + for info in &self.recommended { + entries.push(LanguageModelPickerEntry::Model(info.clone())); + } } for models in self.all.values() { @@ -260,12 +306,11 @@ impl GroupedModels { entries.push(LanguageModelPickerEntry::Separator( models[0].model.provider_name().0, )); - entries.extend( - models - .iter() - .map(|info| LanguageModelPickerEntry::Model(info.clone())), - ); + for info in models { + entries.push(LanguageModelPickerEntry::Model(info.clone())); + } } + entries } } @@ -461,7 +506,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { fn render_match( &self, ix: usize, - is_focused: bool, + selected: bool, _: &mut Window, cx: &mut Context>, ) -> Option { @@ -477,11 +522,20 @@ impl PickerDelegate for LanguageModelPickerDelegate { let is_selected = Some(model_info.model.provider_id()) == active_provider_id && Some(model_info.model.id()) == active_model_id; + let is_favorite = model_info.is_favorite; + let handle_action_click = { + let model = model_info.model.clone(); + let on_toggle_favorite = self.on_toggle_favorite.clone(); + move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx) + }; + Some( ModelSelectorListItem::new(ix, model_info.model.name().0) - .is_focused(is_focused) - .is_selected(is_selected) .icon(model_info.icon) + .is_selected(is_selected) + .is_focused(selected) + .is_favorite(is_favorite) + .on_toggle_favorite(handle_action_click) .into_any_element(), ) } @@ -493,12 +547,12 @@ impl PickerDelegate for LanguageModelPickerDelegate { _window: &mut Window, _cx: &mut Context>, ) -> Option { + let focus_handle = self.focus_handle.clone(); + if !self.popover_styles { return None; } - let focus_handle = self.focus_handle.clone(); - Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element()) } } @@ -598,11 +652,24 @@ mod tests { } fn create_models(model_specs: Vec<(&str, &str)>) -> Vec { + create_models_with_favorites(model_specs, vec![]) + } + + fn create_models_with_favorites( + model_specs: Vec<(&str, &str)>, + favorites: Vec<(&str, &str)>, + ) -> Vec { model_specs .into_iter() - .map(|(provider, name)| ModelInfo { - model: Arc::new(TestLanguageModel::new(name, provider)), - icon: IconName::Ai, + .map(|(provider, name)| { + let is_favorite = favorites + .iter() + .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name); + ModelInfo { + model: Arc::new(TestLanguageModel::new(name, provider)), + icon: IconName::Ai, + is_favorite, + } }) .collect() } @@ -740,4 +807,93 @@ mod tests { vec!["zed/claude", "zed/gemini", "copilot/claude"], ); } + + #[gpui::test] + fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) { + let recommended_models = create_models(vec![("zed", "claude")]); + let all_models = create_models_with_favorites( + vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")], + vec![("zed", "gemini")], + ); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + let entries = grouped_models.entries(); + + assert!(matches!( + entries.first(), + Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite" + )); + + assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]); + } + + #[gpui::test] + fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) { + let recommended_models = create_models(vec![("zed", "claude")]); + let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + let entries = grouped_models.entries(); + + assert!(matches!( + entries.first(), + Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended" + )); + + assert!(grouped_models.favorites.is_empty()); + } + + #[gpui::test] + fn test_models_have_correct_actions(_cx: &mut TestAppContext) { + let recommended_models = + create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]); + let all_models = create_models_with_favorites( + vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")], + vec![("zed", "claude")], + ); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + let entries = grouped_models.entries(); + + for entry in &entries { + if let LanguageModelPickerEntry::Model(info) = entry { + if info.model.telemetry_id() == "zed/claude" { + assert!(info.is_favorite, "zed/claude should be a favorite"); + } else { + assert!( + !info.is_favorite, + "{} should not be a favorite", + info.model.telemetry_id() + ); + } + } + } + } + + #[gpui::test] + fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) { + let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")]; + + let recommended_models = + create_models_with_favorites(vec![("zed", "claude")], favorites.clone()); + + let all_models = create_models_with_favorites( + vec![ + ("zed", "claude"), + ("zed", "gemini"), + ("openai", "gpt-4"), + ("openai", "gpt-3.5"), + ], + favorites, + ); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + + assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]); + assert_models_eq(grouped_models.recommended, vec!["zed/claude"]); + assert_models_eq( + grouped_models.all.values().flatten().cloned().collect(), + vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"], + ); + } } diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 5e3f348c17d..881eb213a38 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -304,17 +304,31 @@ impl TextThreadEditor { language_model_selector: cx.new(|cx| { language_model_selector( |cx| LanguageModelRegistry::read_global(cx).default_model(), - move |model, cx| { - update_settings_file(fs.clone(), cx, move |settings, _| { - let provider = model.provider_id().0.to_string(); - let model = model.id().0.to_string(); - settings.agent.get_or_insert_default().set_model( - LanguageModelSelection { - provider: LanguageModelProviderSetting(provider), - model, - }, - ) - }); + { + let fs = fs.clone(); + move |model, cx| { + update_settings_file(fs.clone(), cx, move |settings, _| { + let provider = model.provider_id().0.to_string(); + let model = model.id().0.to_string(); + settings.agent.get_or_insert_default().set_model( + LanguageModelSelection { + provider: LanguageModelProviderSetting(provider), + model, + }, + ) + }); + } + }, + { + let fs = fs.clone(); + move |model, should_be_favorite, cx| { + crate::favorite_models::toggle_in_settings( + model, + should_be_favorite, + fs.clone(), + cx, + ); + } }, true, // Use popover styles for picker focus_handle, diff --git a/crates/agent_ui/src/ui/model_selector_components.rs b/crates/agent_ui/src/ui/model_selector_components.rs index 3218daef7c9..184c8e0ba2d 100644 --- a/crates/agent_ui/src/ui/model_selector_components.rs +++ b/crates/agent_ui/src/ui/model_selector_components.rs @@ -1,5 +1,5 @@ use gpui::{Action, FocusHandle, prelude::*}; -use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*}; +use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*}; #[derive(IntoElement)] pub struct ModelSelectorHeader { @@ -42,6 +42,8 @@ pub struct ModelSelectorListItem { icon: Option, is_selected: bool, is_focused: bool, + is_favorite: bool, + on_toggle_favorite: Option>, } impl ModelSelectorListItem { @@ -52,6 +54,8 @@ impl ModelSelectorListItem { icon: None, is_selected: false, is_focused: false, + is_favorite: false, + on_toggle_favorite: None, } } @@ -69,6 +73,16 @@ impl ModelSelectorListItem { self.is_focused = is_focused; self } + + pub fn is_favorite(mut self, is_favorite: bool) -> Self { + self.is_favorite = is_favorite; + self + } + + pub fn on_toggle_favorite(mut self, handler: impl Fn(&App) + 'static) -> Self { + self.on_toggle_favorite = Some(Box::new(handler)); + self + } } impl RenderOnce for ModelSelectorListItem { @@ -79,6 +93,8 @@ impl RenderOnce for ModelSelectorListItem { Color::Muted }; + let is_favorite = self.is_favorite; + ListItem::new(self.index) .inset(true) .spacing(ListItemSpacing::Sparse) @@ -103,6 +119,23 @@ impl RenderOnce for ModelSelectorListItem { .size(IconSize::Small), ) })) + .end_hover_slot(div().pr_2().when_some(self.on_toggle_favorite, { + |this, handle_click| { + let (icon, color, tooltip) = if is_favorite { + (IconName::StarFilled, Color::Accent, "Unfavorite Model") + } else { + (IconName::Star, Color::Default, "Favorite Model") + }; + this.child( + IconButton::new(("toggle-favorite", self.index), icon) + .layer(ElevationIndex::ElevatedSurface) + .icon_color(color) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text(tooltip)) + .on_click(move |_, _, cx| (handle_click)(cx)), + ) + } + })) } } diff --git a/crates/settings/src/settings_content/agent.rs b/crates/settings/src/settings_content/agent.rs index f7a88deb7d8..d3a8e40084f 100644 --- a/crates/settings/src/settings_content/agent.rs +++ b/crates/settings/src/settings_content/agent.rs @@ -38,6 +38,9 @@ pub struct AgentSettingsContent { pub default_height: Option, /// The default model to use when creating new chats and for other features when a specific model is not specified. pub default_model: Option, + /// Favorite models to show at the top of the model selector. + #[serde(default)] + pub favorite_models: Vec, /// Model to use for the inline assistant. Defaults to default_model when not specified. pub inline_assistant_model: Option, /// Model to use for the inline assistant when streaming tools are enabled. @@ -176,6 +179,16 @@ impl AgentSettingsContent { pub fn set_profile(&mut self, profile_id: Arc) { self.default_profile = Some(profile_id); } + + pub fn add_favorite_model(&mut self, model: LanguageModelSelection) { + if !self.favorite_models.contains(&model) { + self.favorite_models.push(model); + } + } + + pub fn remove_favorite_model(&mut self, model: &LanguageModelSelection) { + self.favorite_models.retain(|m| m != model); + } } #[with_fallible_options]