language_model: Decouple from Zed-specific implementation details (#52913)

This PR decouples `language_model`'s dependence on Zed-specific
implementation details. In particular
* `credentials_provider` is split into a generic `credentials_provider`
crate that provides a trait, and `zed_credentials_provider` that
implements the said trait for Zed-specific providers and has functions
that can populate a global state with them
* `zed_env_vars` is split into a generic `env_var` crate that provides
generic tooling for managing env vars, and `zed_env_vars` that contains
Zed-specific statics
* `client` is now dependent on `language_model` and not vice versa

Release Notes:

- N/A
This commit is contained in:
Jakub Konka 2026-04-02 22:06:57 +02:00 committed by GitHub
parent 34c77a0eb9
commit 29609d3599
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
63 changed files with 1122 additions and 561 deletions

40
Cargo.lock generated
View file

@ -260,7 +260,6 @@ dependencies = [
"chrono",
"client",
"collections",
"credentials_provider",
"env_logger 0.11.8",
"feature_flags",
"fs",
@ -289,6 +288,7 @@ dependencies = [
"util",
"uuid",
"watch",
"zed_credentials_provider",
]
[[package]]
@ -2856,6 +2856,7 @@ dependencies = [
"chrono",
"clock",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
@ -2869,6 +2870,7 @@ dependencies = [
"http_client",
"http_client_tls",
"httparse",
"language_model",
"log",
"objc2-foundation",
"parking_lot",
@ -2900,6 +2902,7 @@ dependencies = [
"util",
"windows 0.61.3",
"worktree",
"zed_credentials_provider",
]
[[package]]
@ -3059,6 +3062,7 @@ dependencies = [
"serde",
"serde_json",
"text",
"zed_credentials_provider",
"zeta_prompt",
]
@ -4035,12 +4039,8 @@ name = "credentials_provider"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.31",
"gpui",
"paths",
"release_channel",
"serde",
"serde_json",
]
[[package]]
@ -5115,6 +5115,7 @@ dependencies = [
"collections",
"copilot",
"copilot_ui",
"credentials_provider",
"ctor",
"db",
"edit_prediction_context",
@ -5157,6 +5158,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
"zed_credentials_provider",
"zeta_prompt",
"zlog",
"zstd",
@ -5583,6 +5585,13 @@ dependencies = [
"log",
]
[[package]]
name = "env_var"
version = "0.1.0"
dependencies = [
"gpui",
]
[[package]]
name = "envy"
version = "0.4.2"
@ -9315,12 +9324,12 @@ dependencies = [
"anthropic",
"anyhow",
"base64 0.22.1",
"client",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
"env_var",
"futures 0.3.31",
"gpui",
"http_client",
@ -9336,7 +9345,6 @@ dependencies = [
"smol",
"thiserror 2.0.17",
"util",
"zed_env_vars",
]
[[package]]
@ -13137,6 +13145,7 @@ dependencies = [
"wax",
"which 6.0.3",
"worktree",
"zed_credentials_provider",
"zeroize",
"zlog",
"ztracing",
@ -15746,6 +15755,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
"zed_credentials_provider",
]
[[package]]
@ -22179,11 +22189,25 @@ dependencies = [
"uuid",
]
[[package]]
name = "zed_credentials_provider"
version = "0.1.0"
dependencies = [
"anyhow",
"credentials_provider",
"futures 0.3.31",
"gpui",
"paths",
"release_channel",
"serde",
"serde_json",
]
[[package]]
name = "zed_env_vars"
version = "0.1.0"
dependencies = [
"gpui",
"env_var",
]
[[package]]

View file

@ -61,6 +61,7 @@ members = [
"crates/edit_prediction_ui",
"crates/editor",
"crates/encoding_selector",
"crates/env_var",
"crates/etw_tracing",
"crates/eval_cli",
"crates/eval_utils",
@ -220,6 +221,7 @@ members = [
"crates/x_ai",
"crates/zed",
"crates/zed_actions",
"crates/zed_credentials_provider",
"crates/zed_env_vars",
"crates/zeta_prompt",
"crates/zlog",
@ -309,6 +311,7 @@ dev_container = { path = "crates/dev_container" }
diagnostics = { path = "crates/diagnostics" }
editor = { path = "crates/editor" }
encoding_selector = { path = "crates/encoding_selector" }
env_var = { path = "crates/env_var" }
etw_tracing = { path = "crates/etw_tracing" }
eval_utils = { path = "crates/eval_utils" }
extension = { path = "crates/extension" }
@ -465,6 +468,7 @@ worktree = { path = "crates/worktree" }
x_ai = { path = "crates/x_ai" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_credentials_provider = { path = "crates/zed_credentials_provider" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
zeta_prompt = { path = "crates/zeta_prompt" }

View file

@ -4,7 +4,7 @@ use crate::{
ListDirectoryTool, ListDirectoryToolInput, ReadFileTool, ReadFileToolInput,
};
use Role::*;
use client::{Client, UserStore};
use client::{Client, RefreshLlmTokenListener, UserStore};
use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind};
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
@ -1423,7 +1423,8 @@ impl EditAgentTest {
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
settings::init(cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
});

View file

@ -6,7 +6,7 @@ use acp_thread::{
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use client::{Client, RefreshLlmTokenListener, UserStore};
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use feature_flags::FeatureFlagAppExt as _;
@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
LanguageModelRegistry::test(cx);
});
@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
}
};

View file

@ -6,7 +6,7 @@ use crate::{
};
use Role::*;
use anyhow::{Context as _, Result};
use client::{Client, UserStore};
use client::{Client, RefreshLlmTokenListener, UserStore};
use fs::FakeFs;
use futures::{FutureExt, StreamExt, future::LocalBoxFuture};
use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _};
@ -274,7 +274,8 @@ impl StreamingEditToolTest {
cx.set_http_client(http_client);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client, cx);
});

View file

@ -32,7 +32,6 @@ futures.workspace = true
gpui.workspace = true
feature_flags.workspace = true
gpui_tokio = { workspace = true, optional = true }
credentials_provider.workspace = true
google_ai.workspace = true
http_client.workspace = true
indoc.workspace = true
@ -53,6 +52,7 @@ terminal.workspace = true
uuid.workspace = true
util.workspace = true
watch.workspace = true
zed_credentials_provider.workspace = true
[target.'cfg(unix)'.dependencies]
libc.workspace = true

View file

@ -3,7 +3,6 @@ use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use collections::HashSet;
use credentials_provider::CredentialsProvider;
use fs::Fs;
use gpui::{App, AppContext as _, Entity, Task};
use language_model::{ApiKey, EnvVar};
@ -392,7 +391,7 @@ fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
if let Some(key) = env_var.value {
return Task::ready(Ok(key));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let credentials_provider = zed_credentials_provider::global(cx);
let api_url = google_ai::API_URL.to_string();
cx.spawn(async move |cx| {
Ok(

View file

@ -1,6 +1,7 @@
use crate::{AgentServer, AgentServerDelegate};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
use client::RefreshLlmTokenListener;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::AppContext;
use gpui::{Entity, TestAppContext};
@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
cx.set_http_client(Arc::new(http_client));
let client = client::Client::production(cx);
let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
language_model::init(user_store, client, cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store, cx);
#[cfg(test)]
project::agent_server_store::AllAgentServersSettings::override_global(

View file

@ -815,7 +815,7 @@ mod tests {
cx.set_global(store);
theme_settings::init(theme::LoadThemes::JustBase, cx);
language_model::init_settings(cx);
language_model::init(cx);
editor::init(cx);
});

View file

@ -1809,7 +1809,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
language_model::init_settings(cx);
language_model::init(cx);
});
let fs = FakeFs::new(cx.executor());
@ -1966,7 +1966,7 @@ mod tests {
cx.set_global(settings_store);
prompt_store::init(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
language_model::init_settings(cx);
language_model::init(cx);
workspace::register_project_item::<Editor>(cx);
});

View file

@ -2025,7 +2025,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
pub mod evals {
use crate::InlineAssistant;
use agent::ThreadStore;
use client::{Client, UserStore};
use client::{Client, RefreshLlmTokenListener, UserStore};
use editor::{Editor, MultiBuffer, MultiBufferOffset};
use eval_utils::{EvalOutput, NoProcessor};
use fs::FakeFs;
@ -2091,7 +2091,8 @@ pub mod evals {
client::init(&client, cx);
workspace::init(app_state.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store, client.clone(), cx);
cx.set_global(inline_assistant);

View file

@ -22,6 +22,7 @@ base64.workspace = true
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
@ -35,6 +36,7 @@ gpui_tokio.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
paths.workspace = true
@ -60,6 +62,7 @@ tokio.workspace = true
url.workspace = true
util.workspace = true
worktree.workspace = true
zed_credentials_provider.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View file

@ -1,6 +1,7 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
mod llm_token;
mod proxy;
pub mod telemetry;
pub mod user;
@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{
http::{HeaderValue, Request, StatusCode},
};
use clock::SystemClock;
use cloud_api_client::CloudApiClient;
use cloud_api_client::websocket_protocol::MessageToClient;
use cloud_api_client::{ClientApiError, CloudApiClient};
use cloud_api_types::OrganizationId;
use credentials_provider::CredentialsProvider;
use feature_flags::FeatureFlagAppExt as _;
use futures::{
@ -24,6 +26,7 @@ use futures::{
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
use language_model::LlmApiToken;
use parking_lot::{Mutex, RwLock};
use postage::watch;
use proxy::connect_proxy_stream;
@ -51,6 +54,7 @@ use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
pub use llm_token::*;
pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
@ -339,7 +343,7 @@ pub struct ClientCredentialsProvider {
impl ClientCredentialsProvider {
pub fn new(cx: &App) -> Self {
Self {
provider: <dyn CredentialsProvider>::global(cx),
provider: zed_credentials_provider::global(cx),
}
}
@ -568,6 +572,10 @@ impl Client {
self.http.clone()
}
pub fn credentials_provider(&self) -> Arc<dyn CredentialsProvider> {
self.credentials_provider.provider.clone()
}
pub fn cloud_client(&self) -> Arc<CloudApiClient> {
self.cloud_client.clone()
}
@ -1513,6 +1521,66 @@ impl Client {
})
}
pub async fn acquire_llm_token(
&self,
llm_token: &LlmApiToken,
organization_id: Option<OrganizationId>,
) -> Result<String> {
let system_id = self.telemetry().system_id().map(|x| x.to_string());
let cloud_client = self.cloud_client();
match llm_token
.acquire(&cloud_client, system_id, organization_id)
.await
{
Ok(token) => Ok(token),
Err(ClientApiError::Unauthorized) => {
self.request_sign_out();
Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
}
Err(err) => Err(anyhow::Error::from(err)),
}
}
pub async fn refresh_llm_token(
&self,
llm_token: &LlmApiToken,
organization_id: Option<OrganizationId>,
) -> Result<String> {
let system_id = self.telemetry().system_id().map(|x| x.to_string());
let cloud_client = self.cloud_client();
match llm_token
.refresh(&cloud_client, system_id, organization_id)
.await
{
Ok(token) => Ok(token),
Err(ClientApiError::Unauthorized) => {
self.request_sign_out();
return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
}
Err(err) => return Err(anyhow::Error::from(err)),
}
}
pub async fn clear_and_refresh_llm_token(
&self,
llm_token: &LlmApiToken,
organization_id: Option<OrganizationId>,
) -> Result<String> {
let system_id = self.telemetry().system_id().map(|x| x.to_string());
let cloud_client = self.cloud_client();
match llm_token
.clear_and_refresh(&cloud_client, system_id, organization_id)
.await
{
Ok(token) => Ok(token),
Err(ClientApiError::Unauthorized) => {
self.request_sign_out();
return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
}
Err(err) => return Err(anyhow::Error::from(err)),
}
}
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
self.state.write().credentials = None;
self.cloud_client.clear_credentials();

View file

@ -0,0 +1,116 @@
use super::{Client, UserStore};
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
};
use language_model::LlmApiToken;
use std::sync::Arc;
pub trait NeedsLlmTokenRefresh {
/// Returns whether the LLM token needs to be refreshed.
fn needs_llm_token_refresh(&self) -> bool;
}
impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
fn needs_llm_token_refresh(&self) -> bool {
self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
|| self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
}
}
enum TokenRefreshMode {
Refresh,
ClearAndRefresh,
}
pub fn global_llm_token(cx: &App) -> LlmApiToken {
RefreshLlmTokenListener::global(cx)
.read(cx)
.llm_api_token
.clone()
}
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
pub struct LlmTokenRefreshedEvent;
pub struct RefreshLlmTokenListener {
client: Arc<Client>,
user_store: Entity<UserStore>,
llm_api_token: LlmApiToken,
_subscription: Subscription,
}
impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
pub fn global(cx: &App) -> Entity<Self> {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
client.add_message_to_client_handler({
let this = cx.weak_entity();
move |message, cx| {
if let Some(this) = this.upgrade() {
Self::handle_refresh_llm_token(this, message, cx);
}
}
});
let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
if matches!(event, super::user::Event::OrganizationChanged) {
this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
}
});
Self {
client,
user_store,
llm_api_token: LlmApiToken::default(),
_subscription: subscription,
}
}
fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let organization_id = self
.user_store
.read(cx)
.current_organization()
.map(|organization| organization.id.clone());
cx.spawn(async move |this, cx| {
match mode {
TokenRefreshMode::Refresh => {
client
.refresh_llm_token(&llm_api_token, organization_id)
.await?;
}
TokenRefreshMode::ClearAndRefresh => {
client
.clear_and_refresh_llm_token(&llm_api_token, organization_id)
.await?;
}
}
this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
})
.detach_and_log_err(cx);
}
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
match message {
MessageToClient::UserUpdated => {
this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
}
}
}
}

View file

@ -22,6 +22,7 @@ log.workspace = true
serde.workspace = true
serde_json.workspace = true
text.workspace = true
zed_credentials_provider.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]

View file

@ -48,9 +48,10 @@ pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
}
pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = zed_credentials_provider::global(cx);
let api_url = codestral_api_url(cx);
codestral_api_key_state(cx).update(cx, |key_state, cx| {
key_state.load_if_needed(api_url, |s| s, cx)
key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
})
}

View file

@ -13,9 +13,5 @@ path = "src/credentials_provider.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
paths.workspace = true
release_channel.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -1,26 +1,8 @@
use std::collections::HashMap;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{Arc, LazyLock};
use anyhow::Result;
use futures::FutureExt as _;
use gpui::{App, AsyncApp};
use release_channel::ReleaseChannel;
/// An environment variable whose presence indicates that the system keychain
/// should be used in development.
///
/// By default, running Zed in development uses the development credentials
/// provider. Setting this environment variable allows you to interact with the
/// system keychain (for instance, if you need to test something).
///
/// Only works in development. Setting this environment variable in other
/// release channels is a no-op.
static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
});
use gpui::AsyncApp;
/// A provider for credentials.
///
@ -50,150 +32,3 @@ pub trait CredentialsProvider: Send + Sync {
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
impl dyn CredentialsProvider {
/// Returns the global [`CredentialsProvider`].
pub fn global(cx: &App) -> Arc<Self> {
// The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
// seems like this is a false positive from Clippy.
#[allow(clippy::arc_with_non_send_sync)]
Self::new(cx)
}
fn new(cx: &App) -> Arc<Self> {
let use_development_provider = match ReleaseChannel::try_global(cx) {
Some(ReleaseChannel::Dev) => {
// In development we default to using the development
// credentials provider to avoid getting spammed by relentless
// keychain access prompts.
//
// However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
// variable is set, we will use the actual keychain.
!*ZED_DEVELOPMENT_USE_KEYCHAIN
}
Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
| None => false,
};
if use_development_provider {
Arc::new(DevelopmentCredentialsProvider::new())
} else {
Arc::new(KeychainCredentialsProvider)
}
}
}
/// A credentials provider that stores credentials in the system keychain.
struct KeychainCredentialsProvider;
impl CredentialsProvider for KeychainCredentialsProvider {
fn read_credentials<'a>(
&'a self,
url: &'a str,
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
}
fn write_credentials<'a>(
&'a self,
url: &'a str,
username: &'a str,
password: &'a [u8],
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
cx.update(move |cx| cx.write_credentials(url, username, password))
.await
}
.boxed_local()
}
fn delete_credentials<'a>(
&'a self,
url: &'a str,
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
}
}
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
path: PathBuf,
}
impl DevelopmentCredentialsProvider {
fn new() -> Self {
let path = paths::config_dir().join("development_credentials");
Self { path }
}
fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
let json = std::fs::read(&self.path)?;
let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
Ok(credentials)
}
fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
let json = serde_json::to_string(credentials)?;
std::fs::write(&self.path, json)?;
Ok(())
}
}
impl CredentialsProvider for DevelopmentCredentialsProvider {
fn read_credentials<'a>(
&'a self,
url: &'a str,
_cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
async move {
Ok(self
.load_credentials()
.unwrap_or_default()
.get(url)
.cloned())
}
.boxed_local()
}
fn write_credentials<'a>(
&'a self,
url: &'a str,
username: &'a str,
password: &'a [u8],
_cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
let mut credentials = self.load_credentials().unwrap_or_default();
credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
self.save_credentials(&credentials)
}
.boxed_local()
}
fn delete_credentials<'a>(
&'a self,
url: &'a str,
_cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
let mut credentials = self.load_credentials()?;
credentials.remove(url);
self.save_credentials(&credentials)
}
.boxed_local()
}
}

View file

@ -26,6 +26,7 @@ cloud_llm_client.workspace = true
collections.workspace = true
copilot.workspace = true
copilot_ui.workspace = true
credentials_provider.workspace = true
db.workspace = true
edit_prediction_types.workspace = true
edit_prediction_context.workspace = true
@ -65,6 +66,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
zed_credentials_provider.workspace = true
zeta_prompt.workspace = true
zstd.workspace = true

View file

@ -258,6 +258,7 @@ fn generate_timestamp_name() -> String {
mod tests {
use super::*;
use crate::EditPredictionStore;
use client::RefreshLlmTokenListener;
use client::{Client, UserStore};
use clock::FakeSystemClock;
use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
@ -548,7 +549,8 @@ mod tests {
let http_client = FakeHttpClient::with_404_response();
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
EditPredictionStore::global(&client, &user_store, cx);
})
}

View file

@ -1,5 +1,5 @@
use anyhow::Result;
use client::{Client, EditPredictionUsage, UserStore};
use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@ -11,6 +11,7 @@ use cloud_llm_client::{
};
use collections::{HashMap, HashSet};
use copilot::{Copilot, Reinstall, SignIn, SignOut};
use credentials_provider::CredentialsProvider;
use db::kvp::{Dismissable, KeyValueStore};
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
@ -30,7 +31,7 @@ use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
use language_model::LlmApiToken;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@ -150,6 +151,7 @@ pub struct EditPredictionStore {
rated_predictions: HashSet<EditPredictionId>,
#[cfg(test)]
settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
credentials_provider: Arc<dyn CredentialsProvider>,
}
pub(crate) struct EditPredictionRejectionPayload {
@ -746,7 +748,7 @@ impl EditPredictionStore {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let data_collection_choice = Self::load_data_collection_choice(cx);
let llm_token = LlmApiToken::global(cx);
let llm_token = global_llm_token(cx);
let (reject_tx, reject_rx) = mpsc::unbounded();
cx.background_spawn({
@ -787,6 +789,8 @@ impl EditPredictionStore {
.log_err();
});
let credentials_provider = zed_credentials_provider::global(cx);
let this = Self {
projects: HashMap::default(),
client,
@ -807,6 +811,8 @@ impl EditPredictionStore {
shown_predictions: Default::default(),
#[cfg(test)]
settled_event_callback: None,
credentials_provider,
};
this
@ -871,7 +877,9 @@ impl EditPredictionStore {
let experiments = cx
.background_spawn(async move {
let http_client = client.http_client();
let token = llm_token.acquire(&client, organization_id).await?;
let token = client
.acquire_llm_token(&llm_token, organization_id.clone())
.await?;
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let request = http_client::Request::builder()
.method(Method::GET)
@ -2315,7 +2323,10 @@ impl EditPredictionStore {
zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
}
EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
EditPredictionModel::Mercury => {
self.mercury
.request_prediction(inputs, self.credentials_provider.clone(), cx)
}
};
cx.spawn(async move |this, cx| {
@ -2536,12 +2547,15 @@ impl EditPredictionStore {
Res: DeserializeOwned,
{
let http_client = client.http_client();
let mut token = if require_auth {
Some(llm_token.acquire(&client, organization_id.clone()).await?)
Some(
client
.acquire_llm_token(&llm_token, organization_id.clone())
.await?,
)
} else {
llm_token
.acquire(&client, organization_id.clone())
client
.acquire_llm_token(&llm_token, organization_id.clone())
.await
.ok()
};
@ -2585,7 +2599,11 @@ impl EditPredictionStore {
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
token = Some(
client
.refresh_llm_token(&llm_token, organization_id.clone())
.await?,
);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;

View file

@ -1,6 +1,6 @@
use super::*;
use crate::udiff::apply_diff_to_string;
use client::{UserStore, test::FakeServer};
use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
use clock::FakeSystemClock;
use clock::ReplicaId;
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@ -23,7 +23,7 @@ use language::{
Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
};
use language_model::RefreshLlmTokenListener;
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
@ -2439,7 +2439,8 @@ fn init_test_with_fake_client(
client.cloud_client().set_credentials(1, "test".into());
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
(
@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

View file

@ -5,6 +5,7 @@ use crate::{
};
use anyhow::{Context as _, Result};
use cloud_llm_client::EditPredictionRejectReason;
use credentials_provider::CredentialsProvider;
use futures::AsyncReadExt as _;
use gpui::{
App, AppContext as _, Context, Entity, Global, SharedString, Task,
@ -51,10 +52,11 @@ impl Mercury {
debug_tx,
..
}: EditPredictionModelInput,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
self.api_token.update(cx, |key_state, cx| {
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
});
let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
return Task::ready(Ok(None));
@ -387,8 +389,9 @@ pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
}
pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
let credentials_provider = zed_credentials_provider::global(cx);
mercury_api_token(cx).update(cx, |key_state, cx| {
key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
})
}

View file

@ -42,9 +42,10 @@ pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
pub fn load_open_ai_compatible_api_token(
cx: &mut App,
) -> Task<Result<(), language_model::AuthenticateError>> {
let credentials_provider = zed_credentials_provider::global(cx);
let api_url = open_ai_compatible_api_url(cx);
open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
key_state.load_if_needed(api_url, |s| s, cx)
key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
})
}

View file

@ -1,4 +1,4 @@
use client::{Client, ProxySettings, UserStore};
use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState {
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);

15
crates/env_var/Cargo.toml Normal file
View file

@ -0,0 +1,15 @@
[package]
name = "env_var"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/env_var.rs"
[dependencies]
gpui.workspace = true

1
crates/env_var/LICENSE-GPL Symbolic link
View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,40 @@
use gpui::SharedString;
#[derive(Clone)]
pub struct EnvVar {
pub name: SharedString,
/// Value of the environment variable. Also `None` when set to an empty string.
pub value: Option<String>,
}
impl EnvVar {
pub fn new(name: SharedString) -> Self {
let value = std::env::var(name.as_str()).ok();
if value.as_ref().is_some_and(|v| v.is_empty()) {
Self { name, value: None }
} else {
Self { name, value }
}
}
pub fn or(self, other: EnvVar) -> EnvVar {
if self.value.is_some() { self } else { other }
}
}
/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
#[macro_export]
macro_rules! env_var {
($name:expr) => {
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
};
}
/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
/// environment variable exists and is non-empty.
#[macro_export]
macro_rules! bool_env_var {
($name:expr) => {
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
};
}

View file

@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::sync::Arc;
use client::{Client, ProxySettings, UserStore};
use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
language_model::init(user_store.clone(), client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);

View file

@ -20,11 +20,11 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
credentials_provider.workspace = true
base64.workspace = true
client.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
env_var.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
@ -40,7 +40,6 @@ serde_json.workspace = true
smol.workspace = true
thiserror.workspace = true
util.workspace = true
zed_env_vars.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View file

@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
use env_var::EnvVar;
use futures::{FutureExt, future};
use gpui::{AsyncApp, Context, SharedString, Task};
use std::{
@ -7,7 +8,6 @@ use std::{
sync::Arc,
};
use util::ResultExt as _;
use zed_env_vars::EnvVar;
use crate::AuthenticateError;
@ -101,6 +101,7 @@ impl ApiKeyState {
url: SharedString,
key: Option<String>,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
provider: Arc<dyn CredentialsProvider>,
cx: &Context<Ent>,
) -> Task<Result<()>> {
if self.is_from_env_var() {
@ -108,18 +109,14 @@ impl ApiKeyState {
"bug: attempted to store API key in system keychain when API key is from env var",
)));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |ent, cx| {
if let Some(key) = &key {
credentials_provider
provider
.write_credentials(&url, "Bearer", key.as_bytes(), cx)
.await
.log_err();
} else {
credentials_provider
.delete_credentials(&url, cx)
.await
.log_err();
provider.delete_credentials(&url, cx).await.log_err();
}
ent.update(cx, |ent, cx| {
let this = get_this(ent);
@ -144,12 +141,13 @@ impl ApiKeyState {
&mut self,
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<Ent>,
) {
if url != self.url {
if !self.is_from_env_var() {
// loading will continue even though this result task is dropped
let _task = self.load_if_needed(url, get_this, cx);
let _task = self.load_if_needed(url, get_this, provider, cx);
}
}
}
@ -163,6 +161,7 @@ impl ApiKeyState {
&mut self,
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<Ent>,
) -> Task<Result<(), AuthenticateError>> {
if let LoadStatus::Loaded { .. } = &self.load_status
@ -185,7 +184,7 @@ impl ApiKeyState {
let task = if let Some(load_task) = &self.load_task {
load_task.clone()
} else {
let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared();
self.url = url;
self.load_status = LoadStatus::NotPresent;
self.load_task = Some(load_task.clone());
@ -206,14 +205,13 @@ impl ApiKeyState {
fn load<Ent: 'static>(
url: SharedString,
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
provider: Arc<dyn CredentialsProvider>,
cx: &Context<Ent>,
) -> Task<()> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn({
async move |ent, cx| {
let load_status =
ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
.await;
ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await;
ent.update(cx, |ent, cx| {
let this = get_this(ent);
this.url = url;

View file

@ -11,12 +11,10 @@ pub mod tool_schema;
pub mod fake_provider;
use anyhow::{Result, anyhow};
use client::Client;
use client::UserStore;
use cloud_llm_client::CompletionRequestStatus;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
@ -36,15 +34,10 @@ pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub use env_var::{EnvVar, env_var};
pub use provider::*;
pub use zed_env_vars::{EnvVar, env_var};
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
init_settings(cx);
RefreshLlmTokenListener::register(client, user_store, cx);
}
pub fn init_settings(cx: &mut App) {
pub fn init(cx: &mut App) {
registry::init(cx);
}

View file

@ -1,16 +1,9 @@
use std::fmt;
use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::Client;
use client::UserStore;
use cloud_api_client::ClientApiError;
use cloud_api_client::CloudApiClient;
use cloud_api_types::OrganizationId;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
pub fn global(cx: &App) -> Self {
RefreshLlmTokenListener::global(cx)
.read(cx)
.llm_api_token
.clone()
}
pub async fn acquire(
&self,
client: &Arc<Client>,
client: &CloudApiClient,
system_id: Option<String>,
organization_id: Option<OrganizationId>,
) -> Result<String> {
) -> Result<String, ClientApiError> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
@ -49,6 +36,7 @@ impl LlmApiToken {
Self::fetch(
RwLockUpgradableReadGuard::upgrade(lock).await,
client,
system_id,
organization_id,
)
.await
@ -57,10 +45,11 @@ impl LlmApiToken {
pub async fn refresh(
&self,
client: &Arc<Client>,
client: &CloudApiClient,
system_id: Option<String>,
organization_id: Option<OrganizationId>,
) -> Result<String> {
Self::fetch(self.0.write().await, client, organization_id).await
) -> Result<String, ClientApiError> {
Self::fetch(self.0.write().await, client, system_id, organization_id).await
}
/// Clears the existing token before attempting to fetch a new one.
@ -69,28 +58,22 @@ impl LlmApiToken {
/// leave a token for the wrong organization.
pub async fn clear_and_refresh(
&self,
client: &Arc<Client>,
client: &CloudApiClient,
system_id: Option<String>,
organization_id: Option<OrganizationId>,
) -> Result<String> {
) -> Result<String, ClientApiError> {
let mut lock = self.0.write().await;
*lock = None;
Self::fetch(lock, client, organization_id).await
Self::fetch(lock, client, system_id, organization_id).await
}
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option<String>>,
client: &Arc<Client>,
client: &CloudApiClient,
system_id: Option<String>,
organization_id: Option<OrganizationId>,
) -> Result<String> {
let system_id = client
.telemetry()
.system_id()
.map(|system_id| system_id.to_string());
let result = client
.cloud_client()
.create_llm_token(system_id, organization_id)
.await;
) -> Result<String, ClientApiError> {
let result = client.create_llm_token(system_id, organization_id).await;
match result {
Ok(response) => {
*lock = Some(response.token.0.clone());
@ -98,112 +81,7 @@ impl LlmApiToken {
}
Err(err) => {
*lock = None;
match err {
ClientApiError::Unauthorized => {
client.request_sign_out();
Err(err).context("Failed to create LLM token")
}
ClientApiError::Other(err) => Err(err),
}
}
}
}
}
pub trait NeedsLlmTokenRefresh {
/// Returns whether the LLM token needs to be refreshed.
fn needs_llm_token_refresh(&self) -> bool;
}
impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
fn needs_llm_token_refresh(&self) -> bool {
self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
|| self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
}
}
enum TokenRefreshMode {
Refresh,
ClearAndRefresh,
}
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
pub struct LlmTokenRefreshedEvent;
pub struct RefreshLlmTokenListener {
client: Arc<Client>,
user_store: Entity<UserStore>,
llm_api_token: LlmApiToken,
_subscription: Subscription,
}
impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
pub fn global(cx: &App) -> Entity<Self> {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
client.add_message_to_client_handler({
let this = cx.weak_entity();
move |message, cx| {
if let Some(this) = this.upgrade() {
Self::handle_refresh_llm_token(this, message, cx);
}
}
});
let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
if matches!(event, client::user::Event::OrganizationChanged) {
this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
}
});
Self {
client,
user_store,
llm_api_token: LlmApiToken::default(),
_subscription: subscription,
}
}
fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let organization_id = self
.user_store
.read(cx)
.current_organization()
.map(|organization| organization.id.clone());
cx.spawn(async move |this, cx| {
match mode {
TokenRefreshMode::Refresh => {
llm_api_token.refresh(&client, organization_id).await?;
}
TokenRefreshMode::ClearAndRefresh => {
llm_api_token
.clear_and_refresh(&client, organization_id)
.await?;
}
}
this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
})
.detach_and_log_err(cx);
}
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
match message {
MessageToClient::UserUpdated => {
this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
Err(err)
}
}
}

View file

@ -3,6 +3,7 @@ use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
use credentials_provider::CredentialsProvider;
use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
@ -31,9 +32,16 @@ use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
let credentials_provider = client.credentials_provider();
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_language_model_providers(registry, user_store, client.clone(), cx);
register_language_model_providers(
registry,
user_store,
client.clone(),
credentials_provider.clone(),
cx,
);
});
// Subscribe to extension store events to track LLM extension installations
@ -104,6 +112,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
&HashSet::default(),
&openai_compatible_providers,
client.clone(),
credentials_provider.clone(),
cx,
);
});
@ -124,6 +133,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
&openai_compatible_providers,
&openai_compatible_providers_new,
client.clone(),
credentials_provider.clone(),
cx,
);
});
@ -138,6 +148,7 @@ fn register_openai_compatible_providers(
old: &HashSet<Arc<str>>,
new: &HashSet<Arc<str>>,
client: Arc<Client>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<LanguageModelRegistry>,
) {
for provider_id in old {
@ -152,6 +163,7 @@ fn register_openai_compatible_providers(
Arc::new(OpenAiCompatibleLanguageModelProvider::new(
provider_id.clone(),
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
@ -164,6 +176,7 @@ fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Entity<UserStore>,
client: Arc<Client>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut Context<LanguageModelRegistry>,
) {
registry.register_provider(
@ -177,62 +190,105 @@ fn register_language_model_providers(
registry.register_provider(
Arc::new(AnthropicLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(OpenAiLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(OllamaLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(LmStudioLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(DeepSeekLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(GoogleLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
MistralLanguageModelProvider::global(client.http_client(), cx),
MistralLanguageModelProvider::global(
client.http_client(),
credentials_provider.clone(),
cx,
),
cx,
);
registry.register_provider(
Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(BedrockLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(OpenRouterLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(VercelLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(VercelAiGatewayLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(XAiLanguageModelProvider::new(
client.http_client(),
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(
Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
Arc::new(OpenCodeLanguageModelProvider::new(
client.http_client(),
credentials_provider,
cx,
)),
cx,
);
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);

View file

@ -6,6 +6,7 @@ use anthropic::{
};
use anyhow::Result;
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
@ -51,6 +52,7 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -59,30 +61,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = AnthropicLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = AnthropicLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -195,12 +195,13 @@ pub struct State {
settings: Option<AmazonBedrockSettings>,
/// Whether credentials came from environment variables (only relevant for static credentials)
credentials_from_env: bool,
credentials_provider: Arc<dyn CredentialsProvider>,
_subscription: Subscription,
}
impl State {
fn reset_auth(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(AMAZON_AWS_URL, cx)
@ -220,7 +221,7 @@ impl State {
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let auth = credentials.clone().into_auth();
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(
@ -287,7 +288,7 @@ impl State {
&self,
cx: &mut Context<Self>,
) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let credentials_provider = self.credentials_provider.clone();
cx.spawn(async move |this, cx| {
// Try environment variables first
let (auth, from_env) = if let Some(bearer_token) = &ZED_BEDROCK_BEARER_TOKEN_VAR.value {
@ -400,11 +401,16 @@ pub struct BedrockLanguageModelProvider {
}
impl BedrockLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| State {
auth: None,
settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
credentials_from_env: false,
credentials_provider,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),

View file

@ -1,7 +1,9 @@
use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, UserStore, zed_urls};
use client::{
Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls,
};
use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
@ -24,10 +26,9 @@ use language_model::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
ZED_CLOUD_PROVIDER_NAME,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
@ -111,7 +112,7 @@ impl State {
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
let llm_api_token = LlmApiToken::global(cx);
let llm_api_token = global_llm_token(cx);
Self {
client: client.clone(),
llm_api_token,
@ -226,7 +227,9 @@ impl State {
organization_id: Option<OrganizationId>,
) -> Result<ListModelsResponse> {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client, organization_id).await?;
let token = client
.acquire_llm_token(&llm_api_token, organization_id)
.await?;
let request = http_client::Request::builder()
.method(Method::GET)
@ -414,8 +417,8 @@ impl CloudLanguageModel {
) -> Result<PerformLlmCompletionResponse> {
let http_client = &client.http_client();
let mut token = llm_api_token
.acquire(&client, organization_id.clone())
let mut token = client
.acquire_llm_token(&llm_api_token, organization_id.clone())
.await?;
let mut refreshed_token = false;
@ -447,8 +450,8 @@ impl CloudLanguageModel {
}
if !refreshed_token && response.needs_llm_token_refresh() {
token = llm_api_token
.refresh(&client, organization_id.clone())
token = client
.refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
refreshed_token = true;
continue;
@ -713,7 +716,9 @@ impl LanguageModel for CloudLanguageModel {
into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client, organization_id).await?;
let token = client
.acquire_llm_token(&llm_api_token, organization_id)
.await?;
let request_body = CountTokensBody {
provider: cloud_llm_client::LanguageModelProvider::Google,

View file

@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use deepseek::DEEPSEEK_API_URL;
use futures::Stream;
@ -49,6 +50,7 @@ pub struct DeepSeekLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -57,30 +59,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl DeepSeekLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -1,5 +1,6 @@
use anyhow::{Context as _, Result};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
@ -60,6 +61,7 @@ pub struct GoogleLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
@ -76,30 +78,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = GoogleLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = GoogleLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl GoogleLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::HashMap;
use credentials_provider::CredentialsProvider;
use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
@ -52,6 +53,7 @@ pub struct LmStudioLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
available_models: Vec<lmstudio::Model>,
fetch_model_task: Option<Task<Result<()>>>,
@ -64,10 +66,15 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
let task = self
.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx);
let task = self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
self.restart_fetch_models_task(cx);
task
}
@ -114,10 +121,14 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
let _task = self
.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
let _task = self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
if self.is_authenticated() {
return Task::ready(Ok(()));
@ -152,16 +163,29 @@ impl State {
}
impl LmStudioLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
move |this: &mut State, cx| {
let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
if &settings != new_settings {
settings = new_settings.clone();
let new_settings =
AllLanguageModelSettings::get_global(cx).lmstudio.clone();
if settings != new_settings {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx).into();
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
settings = new_settings;
this.restart_fetch_models_task(cx);
cx.notify();
}
@ -173,6 +197,7 @@ impl LmStudioLanguageModelProvider {
Self::api_url(cx).into(),
(*API_KEY_ENV_VAR).clone(),
),
credentials_provider,
http_client,
available_models: Default::default(),
fetch_model_task: None,

View file

@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
@ -43,6 +44,7 @@ pub struct MistralLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -51,15 +53,26 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = MistralLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = MistralLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
@ -73,20 +86,30 @@ impl MistralLanguageModelProvider {
.map(|this| &this.0)
}
pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
pub fn global(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Arc<Self> {
if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
return this.0.clone();
}
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -1,4 +1,5 @@
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
use fs::Fs;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
@ -54,6 +55,7 @@ pub struct OllamaLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
fetched_models: Vec<ollama::Model>,
fetch_model_task: Option<Task<Result<()>>>,
@ -65,10 +67,15 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OllamaLanguageModelProvider::api_url(cx);
let task = self
.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx);
let task = self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
self.fetched_models.clear();
cx.spawn(async move |this, cx| {
@ -80,10 +87,14 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OllamaLanguageModelProvider::api_url(cx);
let task = self
.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
let task = self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
// Always try to fetch models - if no API key is needed (local Ollama), it will work
// If API key is needed and provided, it will work
@ -157,7 +168,11 @@ impl State {
}
impl OllamaLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
@ -170,6 +185,14 @@ impl OllamaLanguageModelProvider {
let url_changed = last_settings.api_url != current_settings.api_url;
last_settings = current_settings.clone();
if url_changed {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
this.fetched_models.clear();
this.authenticate(cx).detach();
}
@ -184,6 +207,7 @@ impl OllamaLanguageModelProvider {
fetched_models: Default::default(),
fetch_model_task: None,
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
}),
};

View file

@ -1,5 +1,6 @@
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
@ -55,6 +56,7 @@ pub struct OpenAiLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -63,30 +65,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OpenAiLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OpenAiLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -1,5 +1,6 @@
use anyhow::Result;
use convert_case::{Case, Casing};
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
@ -44,6 +45,7 @@ pub struct State {
id: Arc<str>,
api_key_state: ApiKeyState,
settings: OpenAiCompatibleSettings,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -52,20 +54,36 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = SharedString::new(self.settings.api_url.as_str());
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = SharedString::new(self.settings.api_url.clone());
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl OpenAiCompatibleLanguageModelProvider {
pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
id: Arc<str>,
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
crate::AllLanguageModelSettings::get_global(cx)
.openai_compatible
@ -79,10 +97,12 @@ impl OpenAiCompatibleLanguageModelProvider {
return;
};
if &this.settings != &settings {
let credentials_provider = this.credentials_provider.clone();
let api_url = SharedString::new(settings.api_url.as_str());
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
this.settings = settings;
@ -98,6 +118,7 @@ impl OpenAiCompatibleLanguageModelProvider {
EnvVar::new(api_key_env_var_name),
),
settings,
credentials_provider,
}
});

View file

@ -1,5 +1,6 @@
use anyhow::Result;
use collections::HashMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
use http_client::HttpClient;
@ -42,6 +43,7 @@ pub struct OpenRouterLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
available_models: Vec<open_router::Model>,
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
@ -53,16 +55,26 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
let task = self
.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
let task = self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.spawn(async move |this, cx| {
let result = task.await;
@ -114,7 +126,11 @@ impl State {
}
impl OpenRouterLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>({
let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
@ -131,6 +147,7 @@ impl OpenRouterLanguageModelProvider {
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,

View file

@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
@ -43,6 +44,7 @@ pub struct OpenCodeLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -51,30 +53,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OpenCodeLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = OpenCodeLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl OpenCodeLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
@ -38,6 +39,7 @@ pub struct VercelLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -46,30 +48,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = VercelLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = VercelLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl VercelLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
@ -41,6 +42,7 @@ pub struct VercelAiGatewayLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
http_client: Arc<dyn HttpClient>,
available_models: Vec<AvailableModel>,
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
@ -52,16 +54,26 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
let task = self
.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
let task = self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.spawn(async move |this, cx| {
let result = task.await;
@ -100,7 +112,11 @@ impl State {
}
impl VercelAiGatewayLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>({
let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
@ -116,6 +132,7 @@ impl VercelAiGatewayLanguageModelProvider {
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,

View file

@ -1,5 +1,6 @@
use anyhow::Result;
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
@ -39,6 +40,7 @@ pub struct XAiLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
@ -47,30 +49,51 @@ impl State {
}
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = XAiLanguageModelProvider::api_url(cx);
self.api_key_state
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
self.api_key_state.store(
api_url,
api_key,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let credentials_provider = self.credentials_provider.clone();
let api_url = XAiLanguageModelProvider::api_url(cx);
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
self.api_key_state.load_if_needed(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
)
}
}
impl XAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
credentials_provider: Arc<dyn CredentialsProvider>,
cx: &mut App,
) -> Self {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let credentials_provider = this.credentials_provider.clone();
let api_url = Self::api_url(cx);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
this.api_key_state.handle_url_change(
api_url,
|this| &mut this.api_key_state,
credentials_provider,
cx,
);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
credentials_provider,
}
});

View file

@ -98,6 +98,7 @@ watch.workspace = true
wax.workspace = true
which.workspace = true
worktree.workspace = true
zed_credentials_provider.workspace = true
zeroize.workspace = true
zlog.workspace = true
ztracing.workspace = true

View file

@ -684,7 +684,7 @@ impl ContextServerStore {
let server_url = url.clone();
let id = id.clone();
cx.spawn(async move |_this, cx| {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
{
log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
@ -797,8 +797,7 @@ impl ContextServerStore {
if configuration.has_static_auth_header() {
None
} else {
let credentials_provider =
cx.update(|cx| <dyn CredentialsProvider>::global(cx));
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
let http_client = cx.update(|cx| cx.http_client());
match Self::load_session(&credentials_provider, url, &cx).await {
@ -1070,7 +1069,7 @@ impl ContextServerStore {
.context("Failed to start OAuth callback server")?;
let http_client = cx.update(|cx| cx.http_client());
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
let server_url = match configuration.as_ref() {
ContextServerConfiguration::Http { url, .. } => url.clone(),
_ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
@ -1233,7 +1232,7 @@ impl ContextServerStore {
self.stop_server(&id, cx)?;
cx.spawn(async move |this, cx| {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
log::error!("{} failed to clear OAuth session: {}", id, err);
}
@ -1451,7 +1450,7 @@ async fn resolve_start_failure(
// (e.g. timeout because the server rejected the token silently). Clear it
// so the next start attempt can get a clean 401 and trigger the auth flow.
if www_authenticate.is_none() {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
Ok(Some(_)) => {
log::info!("{id} start failed with a cached OAuth session present; clearing it");

View file

@ -59,6 +59,7 @@ ui.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_credentials_provider.workspace = true
[dev-dependencies]
fs = { workspace = true, features = ["test-support"] }

View file

@ -185,9 +185,15 @@ fn render_api_key_provider(
cx: &mut Context<SettingsWindow>,
) -> impl IntoElement {
let weak_page = cx.weak_entity();
let credentials_provider = zed_credentials_provider::global(cx);
_ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
let task = api_key_state.update(cx, |key_state, cx| {
key_state.load_if_needed(current_url(cx), |state| state, cx)
key_state.load_if_needed(
current_url(cx),
|state| state,
credentials_provider.clone(),
cx,
)
});
cx.spawn(async move |_, cx| {
task.await.ok();
@ -208,10 +214,17 @@ fn render_api_key_provider(
});
let write_key = move |api_key: Option<String>, cx: &mut App| {
let credentials_provider = zed_credentials_provider::global(cx);
api_key_state
.update(cx, |key_state, cx| {
let url = current_url(cx);
key_state.store(url, api_key, |key_state| key_state, cx)
key_state.store(
url,
api_key,
|key_state| key_state,
credentials_provider,
cx,
)
})
.detach_and_log_err(cx);
};

View file

@ -1,13 +1,13 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::{Client, UserStore};
use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
use language_model::LlmApiToken;
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@ -30,7 +30,7 @@ pub struct State {
impl State {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let llm_api_token = LlmApiToken::global(cx);
let llm_api_token = global_llm_token(cx);
Self {
client,
@ -73,8 +73,8 @@ async fn perform_web_search(
let http_client = &client.http_client();
let mut retries_remaining = MAX_RETRIES;
let mut token = llm_api_token
.acquire(&client, organization_id.clone())
let mut token = client
.acquire_llm_token(&llm_api_token, organization_id.clone())
.await?;
loop {
@ -100,8 +100,8 @@ async fn perform_web_search(
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else if response.needs_llm_token_refresh() {
token = llm_api_token
.refresh(&client, organization_id.clone())
token = client
.refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
retries_remaining -= 1;
} else {

View file

@ -10,7 +10,7 @@ use agent_ui::AgentPanel;
use anyhow::{Context as _, Error, Result};
use clap::Parser;
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
use client::{Client, ProxySettings, UserStore, parse_zed_link};
use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link};
use collab_ui::channel_view::ChannelView;
use collections::HashMap;
use crashes::InitCrashHandler;
@ -664,7 +664,12 @@ fn main() {
);
copilot_ui::init(&app_state, cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_model::init(cx);
RefreshLlmTokenListener::register(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
zed::telemetry_log::init(cx);

View file

@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
});
prompt_store::init(cx);
let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_model::init(cx);
client::RefreshLlmTokenListener::register(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
git_ui::init(cx);
project::AgentRegistryStore::init_global(

View file

@ -5189,7 +5189,12 @@ mod tests {
cx,
);
image_viewer::init(cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_model::init(cx);
client::RefreshLlmTokenListener::register(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
web_search::init(cx);
git_graph::init(cx);

View file

@ -313,7 +313,12 @@ mod tests {
let app_state = cx.update(|cx| {
let app_state = AppState::test(cx);
client::init(&app_state.client, cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_model::init(cx);
client::RefreshLlmTokenListener::register(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
);
editor::init(cx);
app_state
});

View file

@ -0,0 +1,22 @@
[package]
name = "zed_credentials_provider"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/zed_credentials_provider.rs"
[dependencies]
anyhow.workspace = true
credentials_provider.workspace = true
futures.workspace = true
gpui.workspace = true
paths.workspace = true
release_channel.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,181 @@
use std::collections::HashMap;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{Arc, LazyLock};
use anyhow::Result;
use credentials_provider::CredentialsProvider;
use futures::FutureExt as _;
use gpui::{App, AsyncApp, Global};
use release_channel::ReleaseChannel;
/// An environment variable whose presence indicates that the system keychain
/// should be used in development.
///
/// By default, running Zed in development uses the development credentials
/// provider. Setting this environment variable allows you to interact with the
/// system keychain (for instance, if you need to test something).
///
/// Only works in development. Setting this environment variable in other
/// release channels is a no-op.
static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
});
pub struct ZedCredentialsProvider(pub Arc<dyn CredentialsProvider>);
impl Global for ZedCredentialsProvider {}
/// Returns the global [`CredentialsProvider`].
pub fn init_global(cx: &mut App) {
// The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
// seems like this is a false positive from Clippy.
#[allow(clippy::arc_with_non_send_sync)]
let provider = new(cx);
cx.set_global(ZedCredentialsProvider(provider));
}
pub fn global(cx: &App) -> Arc<dyn CredentialsProvider> {
cx.try_global::<ZedCredentialsProvider>()
.map(|provider| provider.0.clone())
.unwrap_or_else(|| new(cx))
}
fn new(cx: &App) -> Arc<dyn CredentialsProvider> {
let use_development_provider = match ReleaseChannel::try_global(cx) {
Some(ReleaseChannel::Dev) => {
// In development we default to using the development
// credentials provider to avoid getting spammed by relentless
// keychain access prompts.
//
// However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
// variable is set, we will use the actual keychain.
!*ZED_DEVELOPMENT_USE_KEYCHAIN
}
Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) | None => {
false
}
};
if use_development_provider {
Arc::new(DevelopmentCredentialsProvider::new())
} else {
Arc::new(KeychainCredentialsProvider)
}
}
/// A credentials provider that stores credentials in the system keychain.
struct KeychainCredentialsProvider;
impl CredentialsProvider for KeychainCredentialsProvider {
fn read_credentials<'a>(
&'a self,
url: &'a str,
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
}
fn write_credentials<'a>(
&'a self,
url: &'a str,
username: &'a str,
password: &'a [u8],
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
cx.update(move |cx| cx.write_credentials(url, username, password))
.await
}
.boxed_local()
}
fn delete_credentials<'a>(
&'a self,
url: &'a str,
cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
}
}
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
path: PathBuf,
}
impl DevelopmentCredentialsProvider {
fn new() -> Self {
let path = paths::config_dir().join("development_credentials");
Self { path }
}
fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
let json = std::fs::read(&self.path)?;
let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
Ok(credentials)
}
fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
let json = serde_json::to_string(credentials)?;
std::fs::write(&self.path, json)?;
Ok(())
}
}
impl CredentialsProvider for DevelopmentCredentialsProvider {
fn read_credentials<'a>(
&'a self,
url: &'a str,
_cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
async move {
Ok(self
.load_credentials()
.unwrap_or_default()
.get(url)
.cloned())
}
.boxed_local()
}
fn write_credentials<'a>(
&'a self,
url: &'a str,
username: &'a str,
password: &'a [u8],
_cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
let mut credentials = self.load_credentials().unwrap_or_default();
credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
self.save_credentials(&credentials)
}
.boxed_local()
}
fn delete_credentials<'a>(
&'a self,
url: &'a str,
_cx: &'a AsyncApp,
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
async move {
let mut credentials = self.load_credentials()?;
credentials.remove(url);
self.save_credentials(&credentials)
}
.boxed_local()
}
}

View file

@ -15,4 +15,4 @@ path = "src/zed_env_vars.rs"
default = []
[dependencies]
gpui.workspace = true
env_var.workspace = true

View file

@ -1,45 +1,6 @@
use gpui::SharedString;
pub use env_var::{EnvVar, bool_env_var, env_var};
use std::sync::LazyLock;
/// Whether Zed is running in stateless mode.
/// When true, Zed will use in-memory databases instead of persistent storage.
pub static ZED_STATELESS: LazyLock<bool> = bool_env_var!("ZED_STATELESS");
#[derive(Clone)]
pub struct EnvVar {
pub name: SharedString,
/// Value of the environment variable. Also `None` when set to an empty string.
pub value: Option<String>,
}
impl EnvVar {
pub fn new(name: SharedString) -> Self {
let value = std::env::var(name.as_str()).ok();
if value.as_ref().is_some_and(|v| v.is_empty()) {
Self { name, value: None }
} else {
Self { name, value }
}
}
pub fn or(self, other: EnvVar) -> EnvVar {
if self.value.is_some() { self } else { other }
}
}
/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
#[macro_export]
macro_rules! env_var {
($name:expr) => {
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
};
}
/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
/// environment variable exists and is non-empty.
#[macro_export]
macro_rules! bool_env_var {
($name:expr) => {
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
};
}