Add support for provider extensions (but no extensions yet) (#45277)

This adds support for provider extensions but doesn't actually add any
yet.

Release Notes:

- N/A
This commit is contained in:
Richard Feldman 2025-12-18 17:05:04 -05:00 committed by GitHub
parent 88f90c12ed
commit 6055b45ee1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 585 additions and 121 deletions

2
Cargo.lock generated
View file

@ -8932,6 +8932,8 @@ dependencies = [
"credentials_provider",
"deepseek",
"editor",
"extension",
"extension_host",
"fs",
"futures 0.3.31",
"google_ai",

View file

@ -210,12 +210,21 @@ pub trait AgentModelSelector: 'static {
}
}
/// Icon for a model in the model selector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentModelIcon {
/// A built-in icon from Zed's icon set.
Named(IconName),
/// Path to a custom SVG icon file.
Path(SharedString),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
pub id: acp::ModelId,
pub name: SharedString,
pub description: Option<SharedString>,
pub icon: Option<IconName>,
pub icon: Option<AgentModelIcon>,
}
impl From<acp::ModelInfo> for AgentModelInfo {

View file

@ -30,7 +30,7 @@ use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
@ -93,7 +93,7 @@ impl LanguageModels {
fn refresh_list(&mut self, cx: &App) {
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@ -153,7 +153,10 @@ impl LanguageModels {
id: Self::model_id(model),
name: model.name().0,
description: None,
icon: Some(provider.icon()),
icon: Some(match provider.icon() {
IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
}),
}
}
@ -164,7 +167,7 @@ impl LanguageModels {
fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
let authenticate_all_providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.visible_providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
@ -1630,7 +1633,9 @@ mod internal_tests {
id: acp::ModelId::new("fake/fake"),
name: "Fake".into(),
description: None,
icon: Some(ui::IconName::ZedAssistant),
icon: Some(acp_thread::AgentModelIcon::Named(
ui::IconName::ZedAssistant
)),
}]
)])
);

View file

@ -1,6 +1,6 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
use agent_client_protocol::ModelId;
use agent_servers::AgentServer;
use agent_settings::AgentSettings;
@ -350,7 +350,11 @@ impl PickerDelegate for AcpModelPickerDelegate {
})
.child(
ModelSelectorListItem::new(ix, model_info.name.clone())
.when_some(model_info.icon, |this, icon| this.icon(icon))
.map(|this| match &model_info.icon {
Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
None => this,
})
.is_selected(is_selected)
.is_focused(selected)
.when(supports_favorites, |this| {

View file

@ -1,7 +1,7 @@
use std::rc::Rc;
use std::sync::Arc;
use acp_thread::{AgentModelInfo, AgentModelSelector};
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
use agent_servers::AgentServer;
use agent_settings::AgentSettings;
use fs::Fs;
@ -70,7 +70,7 @@ impl Render for AcpModelSelectorPopover {
.map(|model| model.name.clone())
.unwrap_or_else(|| SharedString::from("Select a Model"));
let model_icon = model.as_ref().and_then(|model| model.icon);
let model_icon = model.as_ref().and_then(|model| model.icon.clone());
let focus_handle = self.focus_handle.clone();
@ -125,7 +125,14 @@ impl Render for AcpModelSelectorPopover {
ButtonLike::new("active-model")
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.when_some(model_icon, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
this.child(
match icon {
AgentModelIcon::Path(path) => Icon::from_external_svg(path),
AgentModelIcon::Named(icon_name) => Icon::new(icon_name),
}
.color(color)
.size(IconSize::XSmall),
)
})
.child(
Label::new(model_name)

View file

@ -22,7 +22,8 @@ use gpui::{
};
use language::LanguageRegistry;
use language_model::{
LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
IconOrSvg, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
ZED_CLOUD_PROVIDER_ID,
};
use language_models::AllLanguageModelSettings;
use notifications::status_toast::{StatusToast, ToastIcon};
@ -117,7 +118,7 @@ impl AgentConfiguration {
}
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let providers = LanguageModelRegistry::read_global(cx).providers();
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
for provider in providers {
self.add_provider_configuration_view(&provider, window, cx);
}
@ -261,9 +262,12 @@ impl AgentConfiguration {
.w_full()
.gap_1p5()
.child(
Icon::new(provider.icon())
.size(IconSize::Small)
.color(Color::Muted),
match provider.icon() {
IconOrSvg::Svg(path) => Icon::from_external_svg(path),
IconOrSvg::Icon(name) => Icon::new(name),
}
.size(IconSize::Small)
.color(Color::Muted),
)
.child(
h_flex()
@ -416,7 +420,7 @@ impl AgentConfiguration {
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
let providers = LanguageModelRegistry::read_global(cx).providers();
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
let popover_menu = PopoverMenu::new("add-provider-popover")
.trigger(

View file

@ -4,6 +4,7 @@ use crate::{
};
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
use language_model::IconOrSvg;
use picker::popover_menu::PickerPopoverMenu;
use settings::update_settings_file;
use std::sync::Arc;
@ -103,7 +104,14 @@ impl Render for AgentModelSelector {
self.selector.clone(),
ButtonLike::new("active-model")
.when_some(provider_icon, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
this.child(
match icon {
IconOrSvg::Svg(path) => Icon::from_external_svg(path),
IconOrSvg::Icon(name) => Icon::new(name),
}
.color(color)
.size(IconSize::XSmall),
)
})
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.child(
@ -115,7 +123,7 @@ impl Render for AgentModelSelector {
.child(
Icon::new(IconName::ChevronDown)
.color(color)
.size(IconSize::Small),
.size(IconSize::XSmall),
),
move |_window, cx| {
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)

View file

@ -2428,7 +2428,7 @@ impl AgentPanel {
let history_is_empty = self.history_store.read(cx).is_empty(cx);
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.any(|provider| {
provider.is_authenticated(cx)

View file

@ -348,7 +348,8 @@ fn init_language_model_settings(cx: &mut App) {
|_, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
update_active_language_model_from_settings(cx);
}
_ => {}

View file

@ -7,8 +7,8 @@ use gpui::{
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider,
LanguageModelProviderId, LanguageModelRegistry,
AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@ -55,7 +55,7 @@ pub fn language_model_selector(
fn all_models(cx: &App) -> GroupedModels {
let lm_registry = LanguageModelRegistry::global(cx).read(cx);
let providers = lm_registry.providers();
let providers = lm_registry.visible_providers();
let mut favorites_index = FavoritesIndex::default();
@ -94,7 +94,7 @@ type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
icon: IconOrSvg,
is_favorite: bool,
}
@ -203,7 +203,7 @@ impl LanguageModelPickerDelegate {
fn authenticate_all_providers(cx: &mut App) -> Task<()> {
let authenticate_all_providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.visible_providers()
.iter()
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
@ -474,7 +474,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let configured_providers = language_model_registry
.read(cx)
.providers()
.visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@ -566,7 +566,10 @@ impl PickerDelegate for LanguageModelPickerDelegate {
Some(
ModelSelectorListItem::new(ix, model_info.model.name().0)
.icon(model_info.icon)
.map(|this| match &model_info.icon {
IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
})
.is_selected(is_selected)
.is_focused(selected)
.is_favorite(is_favorite)
@ -702,7 +705,7 @@ mod tests {
.any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
ModelInfo {
model: Arc::new(TestLanguageModel::new(name, provider)),
icon: IconName::Ai,
icon: IconOrSvg::Icon(IconName::Ai),
is_favorite,
}
})

View file

@ -33,7 +33,8 @@ use language::{
language_settings::{SoftWrap, all_language_settings},
};
use language_model::{
ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelRegistry, Role,
ConfigurationError, IconOrSvg, LanguageModelExt, LanguageModelImage, LanguageModelRegistry,
Role,
};
use multi_buffer::MultiBufferRow;
use picker::{Picker, popover_menu::PickerPopoverMenu};
@ -2231,10 +2232,10 @@ impl TextThreadEditor {
.default_model()
.map(|default| default.provider);
let provider_icon = match active_provider {
Some(provider) => provider.icon(),
None => IconName::Ai,
};
let provider_icon = active_provider
.as_ref()
.map(|p| p.icon())
.unwrap_or(IconOrSvg::Icon(IconName::Ai));
let focus_handle = self.editor().focus_handle(cx);
@ -2244,6 +2245,13 @@ impl TextThreadEditor {
(Color::Muted, IconName::ChevronDown)
};
let provider_icon_element = match provider_icon {
IconOrSvg::Svg(path) => Icon::from_external_svg(path),
IconOrSvg::Icon(name) => Icon::new(name),
}
.color(color)
.size(IconSize::XSmall);
let tooltip = Tooltip::element({
move |_, cx| {
let focus_handle = focus_handle.clone();
@ -2291,7 +2299,7 @@ impl TextThreadEditor {
.child(
h_flex()
.gap_0p5()
.child(Icon::new(provider_icon).color(color).size(IconSize::XSmall))
.child(provider_icon_element)
.child(
Label::new(model_name)
.color(color)

View file

@ -1,6 +1,11 @@
use gpui::{Action, FocusHandle, prelude::*};
use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
enum ModelIcon {
Name(IconName),
Path(SharedString),
}
#[derive(IntoElement)]
pub struct ModelSelectorHeader {
title: SharedString,
@ -39,7 +44,7 @@ impl RenderOnce for ModelSelectorHeader {
pub struct ModelSelectorListItem {
index: usize,
title: SharedString,
icon: Option<IconName>,
icon: Option<ModelIcon>,
is_selected: bool,
is_focused: bool,
is_favorite: bool,
@ -60,7 +65,12 @@ impl ModelSelectorListItem {
}
pub fn icon(mut self, icon: IconName) -> Self {
self.icon = Some(icon);
self.icon = Some(ModelIcon::Name(icon));
self
}
pub fn icon_path(mut self, path: SharedString) -> Self {
self.icon = Some(ModelIcon::Path(path));
self
}
@ -105,9 +115,12 @@ impl RenderOnce for ModelSelectorListItem {
.gap_1p5()
.when_some(self.icon, |this, icon| {
this.child(
Icon::new(icon)
.color(model_icon_color)
.size(IconSize::Small),
match icon {
ModelIcon::Name(icon_name) => Icon::new(icon_name),
ModelIcon::Path(icon_path) => Icon::from_external_svg(icon_path),
}
.color(model_icon_color)
.size(IconSize::Small),
)
})
.child(Label::new(self.title).truncate()),

View file

@ -1,9 +1,9 @@
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use language_model::{IconOrSvg, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use ui::{Divider, List, ListBulletItem, prelude::*};
pub struct ApiKeysWithProviders {
configured_providers: Vec<(IconName, SharedString)>,
configured_providers: Vec<(IconOrSvg, SharedString)>,
}
impl ApiKeysWithProviders {
@ -13,7 +13,8 @@ impl ApiKeysWithProviders {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
this.configured_providers = Self::compute_configured_providers(cx)
}
_ => {}
@ -26,9 +27,9 @@ impl ApiKeysWithProviders {
}
}
fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> {
fn compute_configured_providers(cx: &App) -> Vec<(IconOrSvg, SharedString)> {
LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
@ -47,7 +48,14 @@ impl Render for ApiKeysWithProviders {
.map(|(icon, name)| {
h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
.child(
match icon {
IconOrSvg::Icon(icon_name) => Icon::new(icon_name),
IconOrSvg::Svg(icon_path) => Icon::from_external_svg(icon_path),
}
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(name))
});
div()

View file

@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
pub struct AgentPanelOnboarding {
user_store: Entity<UserStore>,
client: Arc<Client>,
configured_providers: Vec<(IconName, SharedString)>,
has_configured_providers: bool,
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
}
@ -27,8 +27,9 @@ impl AgentPanelOnboarding {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_available_providers(cx)
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
this.has_configured_providers = Self::has_configured_providers(cx)
}
_ => {}
},
@ -38,20 +39,16 @@ impl AgentPanelOnboarding {
Self {
user_store,
client,
configured_providers: Self::compute_available_providers(cx),
has_configured_providers: Self::has_configured_providers(cx),
continue_with_zed_ai: Arc::new(continue_with_zed_ai),
}
}
fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> {
fn has_configured_providers(cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
})
.map(|provider| (provider.icon(), provider.name().0))
.collect()
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
}
}
@ -81,7 +78,7 @@ impl Render for AgentPanelOnboarding {
}),
)
.map(|this| {
if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() {
if enrolled_in_trial || is_pro_user || self.has_configured_providers {
this
} else {
this.child(ApiKeysWithoutProviders::new())

View file

@ -19,6 +19,9 @@ impl Global for GlobalExtensionHostProxy {}
///
/// This object implements each of the individual proxy types so that their
/// methods can be called directly on it.
/// Registration function for language model providers.
pub type LanguageModelProviderRegistration = Box<dyn FnOnce(&mut App) + Send>;
#[derive(Default)]
pub struct ExtensionHostProxy {
theme_proxy: RwLock<Option<Arc<dyn ExtensionThemeProxy>>>,
@ -29,6 +32,7 @@ pub struct ExtensionHostProxy {
slash_command_proxy: RwLock<Option<Arc<dyn ExtensionSlashCommandProxy>>>,
context_server_proxy: RwLock<Option<Arc<dyn ExtensionContextServerProxy>>>,
debug_adapter_provider_proxy: RwLock<Option<Arc<dyn ExtensionDebugAdapterProviderProxy>>>,
language_model_provider_proxy: RwLock<Option<Arc<dyn ExtensionLanguageModelProviderProxy>>>,
}
impl ExtensionHostProxy {
@ -54,6 +58,7 @@ impl ExtensionHostProxy {
slash_command_proxy: RwLock::default(),
context_server_proxy: RwLock::default(),
debug_adapter_provider_proxy: RwLock::default(),
language_model_provider_proxy: RwLock::default(),
}
}
@ -90,6 +95,15 @@ impl ExtensionHostProxy {
.write()
.replace(Arc::new(proxy));
}
pub fn register_language_model_provider_proxy(
&self,
proxy: impl ExtensionLanguageModelProviderProxy,
) {
self.language_model_provider_proxy
.write()
.replace(Arc::new(proxy));
}
}
pub trait ExtensionThemeProxy: Send + Sync + 'static {
@ -446,3 +460,37 @@ impl ExtensionDebugAdapterProviderProxy for ExtensionHostProxy {
proxy.unregister_debug_locator(locator_name)
}
}
pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static {
fn register_language_model_provider(
&self,
provider_id: Arc<str>,
register_fn: LanguageModelProviderRegistration,
cx: &mut App,
);
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App);
}
impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
fn register_language_model_provider(
&self,
provider_id: Arc<str>,
register_fn: LanguageModelProviderRegistration,
cx: &mut App,
) {
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
return;
};
proxy.register_language_model_provider(provider_id, register_fn, cx)
}
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
return;
};
proxy.unregister_language_model_provider(provider_id, cx)
}
}

View file

@ -93,6 +93,8 @@ pub struct ExtensionManifest {
pub debug_adapters: BTreeMap<Arc<str>, DebugAdapterManifestEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub debug_locators: BTreeMap<Arc<str>, DebugLocatorManifestEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub language_model_providers: BTreeMap<Arc<str>, LanguageModelProviderManifestEntry>,
}
impl ExtensionManifest {
@ -288,6 +290,16 @@ pub struct DebugAdapterManifestEntry {
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct DebugLocatorManifestEntry {}
/// Manifest entry for a language model provider.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LanguageModelProviderManifestEntry {
/// Display name for the provider.
pub name: String,
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
#[serde(default)]
pub icon: Option<String>,
}
impl ExtensionManifest {
pub async fn load(fs: Arc<dyn Fs>, extension_dir: &Path) -> Result<Self> {
let extension_name = extension_dir
@ -358,6 +370,7 @@ fn manifest_from_old_manifest(
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: Default::default(),
}
}
@ -391,6 +404,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View file

@ -148,6 +148,7 @@ fn manifest() -> ExtensionManifest {
)],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View file

@ -113,6 +113,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View file

@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},

View file

@ -797,11 +797,26 @@ pub enum AuthenticateError {
Other(#[from] anyhow::Error),
}
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
/// A built-in icon from Zed's icon set.
Icon(IconName),
/// Path to a custom SVG icon file.
Svg(SharedString),
}
impl Default for IconOrSvg {
fn default() -> Self {
Self::Icon(IconName::ZedAssistant)
}
}
pub trait LanguageModelProvider: 'static {
fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
fn icon(&self) -> IconName {
IconName::ZedAssistant
fn icon(&self) -> IconOrSvg {
IconOrSvg::default()
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
@ -820,7 +835,7 @@ pub trait LanguageModelProvider: 'static {
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
#[derive(Default, Clone)]
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
#[default]
ZedAgent,

View file

@ -2,12 +2,16 @@ use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
use collections::BTreeMap;
use collections::{BTreeMap, HashSet};
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use util::maybe;
/// Function type for checking if a built-in provider should be hidden.
/// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
cx.set_global(GlobalLanguageModelRegistry(registry));
@ -48,6 +52,11 @@ pub struct LanguageModelRegistry {
thread_summary_model: Option<ConfiguredModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
/// Set of installed extension IDs that provide language models.
/// Used to determine which built-in providers should be hidden.
installed_llm_extension_ids: HashSet<Arc<str>>,
/// Function to check if a built-in provider should be hidden by an extension.
builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
#[derive(Debug)]
@ -104,6 +113,8 @@ pub enum Event {
ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
/// Emitted when provider visibility changes due to extension install/uninstall.
ProvidersChanged,
}
impl EventEmitter<Event> for LanguageModelRegistry {}
@ -183,6 +194,60 @@ impl LanguageModelRegistry {
providers
}
/// Returns providers, filtering out hidden built-in providers.
pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
self.providers()
.into_iter()
.filter(|p| !self.should_hide_provider(&p.id()))
.collect()
}
/// Sets the function used to check if a built-in provider should be hidden.
pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
self.builtin_provider_hiding_fn = Some(hiding_fn);
}
/// Called when an extension is installed/loaded.
/// If the extension provides language models, track it so we can hide the corresponding built-in.
pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
if self.installed_llm_extension_ids.insert(extension_id) {
cx.emit(Event::ProvidersChanged);
cx.notify();
}
}
/// Called when an extension is uninstalled/unloaded.
pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
if self.installed_llm_extension_ids.remove(extension_id) {
cx.emit(Event::ProvidersChanged);
cx.notify();
}
}
/// Sync the set of installed LLM extension IDs.
pub fn sync_installed_llm_extensions(
&mut self,
extension_ids: HashSet<Arc<str>>,
cx: &mut Context<Self>,
) {
if extension_ids != self.installed_llm_extension_ids {
self.installed_llm_extension_ids = extension_ids;
cx.emit(Event::ProvidersChanged);
cx.notify();
}
}
/// Returns true if a provider should be hidden from the UI.
/// Built-in providers are hidden when their corresponding extension is installed.
pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
if let Some(extension_id) = hiding_fn(&provider_id.0) {
return self.installed_llm_extension_ids.contains(extension_id);
}
}
false
}
pub fn configuration_error(
&self,
model: Option<ConfiguredModel>,
@ -416,4 +481,132 @@ mod tests {
let providers = registry.read(cx).providers();
assert!(providers.is_empty());
}
#[gpui::test]
fn test_provider_hiding_on_extension_install(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = Arc::new(FakeLanguageModelProvider::default());
let provider_id = provider.id();
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "fake" {
Some("fake-extension")
} else {
None
}
}));
});
let visible = registry.read(cx).visible_providers();
assert_eq!(visible.len(), 1);
assert_eq!(visible[0].id(), provider_id);
registry.update(cx, |registry, cx| {
registry.extension_installed("fake-extension".into(), cx);
});
let visible = registry.read(cx).visible_providers();
assert!(visible.is_empty());
let all = registry.read(cx).providers();
assert_eq!(all.len(), 1);
}
#[gpui::test]
fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = Arc::new(FakeLanguageModelProvider::default());
let provider_id = provider.id();
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "fake" {
Some("fake-extension")
} else {
None
}
}));
registry.extension_installed("fake-extension".into(), cx);
});
let visible = registry.read(cx).visible_providers();
assert!(visible.is_empty());
registry.update(cx, |registry, cx| {
registry.extension_uninstalled("fake-extension", cx);
});
let visible = registry.read(cx).visible_providers();
assert_eq!(visible.len(), 1);
assert_eq!(visible[0].id(), provider_id);
}
#[gpui::test]
fn test_should_hide_provider(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
registry.update(cx, |registry, cx| {
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "anthropic" {
Some("anthropic")
} else if id == "openai" {
Some("openai")
} else {
None
}
}));
registry.extension_installed("anthropic".into(), cx);
});
let registry_read = registry.read(cx);
assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));
assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));
assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
}
#[gpui::test]
fn test_sync_installed_llm_extensions(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = Arc::new(FakeLanguageModelProvider::default());
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "fake" {
Some("fake-extension")
} else {
None
}
}));
});
let mut extension_ids = HashSet::default();
extension_ids.insert(Arc::from("fake-extension"));
registry.update(cx, |registry, cx| {
registry.sync_installed_llm_extensions(extension_ids, cx);
});
assert!(registry.read(cx).visible_providers().is_empty());
registry.update(cx, |registry, cx| {
registry.sync_installed_llm_extensions(HashSet::default(), cx);
});
assert_eq!(registry.read(cx).visible_providers().len(), 1);
}
}

View file

@ -28,6 +28,8 @@ convert_case.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
extension.workspace = true
extension_host.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }

View file

@ -0,0 +1,67 @@
use collections::HashMap;
use extension::{
ExtensionHostProxy, ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration,
};
use gpui::{App, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use std::sync::{Arc, LazyLock};
/// Maps built-in provider IDs to their corresponding extension IDs.
/// When an extension with this ID is installed, the built-in provider should be hidden.
static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
LazyLock::new(|| {
let mut map = HashMap::default();
map.insert("anthropic", "anthropic");
map.insert("openai", "openai");
map.insert("google", "google-ai");
map.insert("openrouter", "openrouter");
map.insert("copilot_chat", "copilot-chat");
map
});
/// Returns the extension ID that should hide the given built-in provider.
pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
}
/// Proxy that registers extension language model providers with the LanguageModelRegistry.
pub struct LanguageModelProviderRegistryProxy {
registry: Entity<LanguageModelRegistry>,
}
impl LanguageModelProviderRegistryProxy {
pub fn new(registry: Entity<LanguageModelRegistry>) -> Self {
Self { registry }
}
}
impl ExtensionLanguageModelProviderProxy for LanguageModelProviderRegistryProxy {
fn register_language_model_provider(
&self,
_provider_id: Arc<str>,
register_fn: LanguageModelProviderRegistration,
cx: &mut App,
) {
register_fn(cx);
}
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
self.registry.update(cx, |registry, cx| {
registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx);
});
}
}
/// Initialize the extension language model provider proxy.
/// This must be called BEFORE extension_host::init to ensure the proxy is available
/// when extensions try to register their language model providers.
pub fn init_proxy(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, _cx| {
registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider));
});
proxy.register_language_model_provider_proxy(LanguageModelProviderRegistryProxy::new(registry));
}

View file

@ -7,9 +7,12 @@ use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod extension;
pub mod provider;
mod settings;
pub use crate::extension::init_proxy as init_extension_proxy;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
@ -31,6 +34,56 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
register_language_model_providers(registry, user_store, client.clone(), cx);
});
// Subscribe to extension store events to track LLM extension installations
if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
cx.subscribe(&extension_store, {
let registry = registry.clone();
move |extension_store, event, cx| match event {
extension_host::Event::ExtensionInstalled(extension_id) => {
if let Some(manifest) = extension_store
.read(cx)
.extension_manifest_for_id(extension_id)
{
if !manifest.language_model_providers.is_empty() {
registry.update(cx, |registry, cx| {
registry.extension_installed(extension_id.clone(), cx);
});
}
}
}
extension_host::Event::ExtensionUninstalled(extension_id) => {
registry.update(cx, |registry, cx| {
registry.extension_uninstalled(extension_id, cx);
});
}
extension_host::Event::ExtensionsUpdated => {
let mut new_ids = HashSet::default();
for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
if !entry.manifest.language_model_providers.is_empty() {
new_ids.insert(extension_id.clone());
}
}
registry.update(cx, |registry, cx| {
registry.sync_installed_llm_extensions(new_ids, cx);
});
}
_ => {}
}
})
.detach();
// Initialize with currently installed extensions
registry.update(cx, |registry, cx| {
let mut initial_ids = HashSet::default();
for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
if !entry.manifest.language_model_providers.is_empty() {
initial_ids.insert(extension_id.clone());
}
}
registry.sync_installed_llm_extensions(initial_ids, cx);
});
}
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
.openai_compatible
.keys()

View file

@ -8,7 +8,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::B
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModel,
ApiKeyState, AuthenticateError, ConfigurationViewTargetAgent, EnvVar, IconOrSvg, LanguageModel,
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
@ -125,8 +125,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiAnthropic
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiAnthropic)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -30,7 +30,7 @@ use gpui::{
use gpui_tokio::Tokio;
use http_client::HttpClient;
use language_model::{
AuthenticateError, EnvVar, LanguageModel, LanguageModelCacheConfiguration,
AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
@ -426,8 +426,8 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiBedrock
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiBedrock)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -19,7 +19,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
@ -304,8 +304,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiZed
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiZed)
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -18,12 +18,12 @@ use gpui::{AnyView, App, AsyncApp, Entity, Subscription, Task};
use http_client::StatusCode;
use language::language_settings::all_language_settings;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason, TokenUsage,
AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice,
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use settings::SettingsStore;
use ui::prelude::*;
@ -104,8 +104,8 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::Copilot
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::Copilot)
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -7,7 +7,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@ -127,8 +127,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiDeepSeek
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiDeepSeek)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -14,7 +14,7 @@ use language_model::{
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
@ -164,8 +164,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiGoogle
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiGoogle)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -10,7 +10,7 @@ use language_model::{
StopReason, TokenUsage,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
@ -175,8 +175,8 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiLmStudio
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiLmStudio)
}
fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -5,7 +5,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::B
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@ -176,8 +176,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiMistral
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiMistral)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -5,7 +5,7 @@ use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Task};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
@ -221,8 +221,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiOllama
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiOllama)
}
fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -5,7 +5,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@ -122,8 +122,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiOpenAi
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiOpenAi)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
@ -133,8 +133,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
self.name.clone()
}
fn icon(&self) -> IconName {
IconName::AiOpenAiCompat
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiOpenAiCompat)
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -4,7 +4,7 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
@ -180,8 +180,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiOpenRouter
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiOpenRouter)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var,
@ -117,8 +117,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiVZero
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiVZero)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -4,7 +4,7 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
@ -118,8 +118,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiXAi
fn icon(&self) -> IconOrSvg {
IconOrSvg::Icon(IconName::AiXAi)
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {

View file

@ -126,17 +126,6 @@ enum IconSource {
ExternalSvg(SharedString),
}
impl IconSource {
fn from_path(path: impl Into<SharedString>) -> Self {
let path = path.into();
if path.starts_with("icons/") {
Self::Embedded(path)
} else {
Self::External(Arc::from(PathBuf::from(path.as_ref())))
}
}
}
#[derive(IntoElement, RegisterComponent)]
pub struct Icon {
source: IconSource,
@ -155,9 +144,18 @@ impl Icon {
}
}
/// Create an icon from a path. Uses a heuristic to determine if it's embedded or external:
/// - Paths starting with "icons/" are treated as embedded SVGs
/// - Other paths are treated as external raster images (from icon themes)
pub fn from_path(path: impl Into<SharedString>) -> Self {
let path = path.into();
let source = if path.starts_with("icons/") {
IconSource::Embedded(path)
} else {
IconSource::External(Arc::from(PathBuf::from(path.as_ref())))
};
Self {
source: IconSource::from_path(path),
source,
color: Color::default(),
size: IconSize::default().rems(),
transformation: Transformation::default(),