mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-31 19:05:00 +07:00
language_model: Decouple from Zed-specific implementation details (#52913)
This PR decouples `language_model`'s dependence on Zed-specific implementation details. In particular * `credentials_provider` is split into a generic `credentials_provider` crate that provides a trait, and `zed_credentials_provider` that implements the said trait for Zed-specific providers and has functions that can populate a global state with them * `zed_env_vars` is split into a generic `env_var` crate that provides generic tooling for managing env vars, and `zed_env_vars` that contains Zed-specific statics * `client` is now dependent on `language_model` and not vice versa Release Notes: - N/A
This commit is contained in:
parent
34c77a0eb9
commit
29609d3599
63 changed files with 1122 additions and 561 deletions
40
Cargo.lock
generated
40
Cargo.lock
generated
|
|
@ -260,7 +260,6 @@ dependencies = [
|
|||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
"env_logger 0.11.8",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
|
|
@ -289,6 +288,7 @@ dependencies = [
|
|||
"util",
|
||||
"uuid",
|
||||
"watch",
|
||||
"zed_credentials_provider",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2856,6 +2856,7 @@ dependencies = [
|
|||
"chrono",
|
||||
"clock",
|
||||
"cloud_api_client",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
|
|
@ -2869,6 +2870,7 @@ dependencies = [
|
|||
"http_client",
|
||||
"http_client_tls",
|
||||
"httparse",
|
||||
"language_model",
|
||||
"log",
|
||||
"objc2-foundation",
|
||||
"parking_lot",
|
||||
|
|
@ -2900,6 +2902,7 @@ dependencies = [
|
|||
"util",
|
||||
"windows 0.61.3",
|
||||
"worktree",
|
||||
"zed_credentials_provider",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3059,6 +3062,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"text",
|
||||
"zed_credentials_provider",
|
||||
"zeta_prompt",
|
||||
]
|
||||
|
||||
|
|
@ -4035,12 +4039,8 @@ name = "credentials_provider"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"paths",
|
||||
"release_channel",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -5115,6 +5115,7 @@ dependencies = [
|
|||
"collections",
|
||||
"copilot",
|
||||
"copilot_ui",
|
||||
"credentials_provider",
|
||||
"ctor",
|
||||
"db",
|
||||
"edit_prediction_context",
|
||||
|
|
@ -5157,6 +5158,7 @@ dependencies = [
|
|||
"workspace",
|
||||
"worktree",
|
||||
"zed_actions",
|
||||
"zed_credentials_provider",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
"zstd",
|
||||
|
|
@ -5583,6 +5585,13 @@ dependencies = [
|
|||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_var"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"gpui",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "envy"
|
||||
version = "0.4.2"
|
||||
|
|
@ -9315,12 +9324,12 @@ dependencies = [
|
|||
"anthropic",
|
||||
"anyhow",
|
||||
"base64 0.22.1",
|
||||
"client",
|
||||
"cloud_api_client",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
"env_var",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"http_client",
|
||||
|
|
@ -9336,7 +9345,6 @@ dependencies = [
|
|||
"smol",
|
||||
"thiserror 2.0.17",
|
||||
"util",
|
||||
"zed_env_vars",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -13137,6 +13145,7 @@ dependencies = [
|
|||
"wax",
|
||||
"which 6.0.3",
|
||||
"worktree",
|
||||
"zed_credentials_provider",
|
||||
"zeroize",
|
||||
"zlog",
|
||||
"ztracing",
|
||||
|
|
@ -15746,6 +15755,7 @@ dependencies = [
|
|||
"util",
|
||||
"workspace",
|
||||
"zed_actions",
|
||||
"zed_credentials_provider",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -22179,11 +22189,25 @@ dependencies = [
|
|||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zed_credentials_provider"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"credentials_provider",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"paths",
|
||||
"release_channel",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zed_env_vars"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"gpui",
|
||||
"env_var",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ members = [
|
|||
"crates/edit_prediction_ui",
|
||||
"crates/editor",
|
||||
"crates/encoding_selector",
|
||||
"crates/env_var",
|
||||
"crates/etw_tracing",
|
||||
"crates/eval_cli",
|
||||
"crates/eval_utils",
|
||||
|
|
@ -220,6 +221,7 @@ members = [
|
|||
"crates/x_ai",
|
||||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
"crates/zed_credentials_provider",
|
||||
"crates/zed_env_vars",
|
||||
"crates/zeta_prompt",
|
||||
"crates/zlog",
|
||||
|
|
@ -309,6 +311,7 @@ dev_container = { path = "crates/dev_container" }
|
|||
diagnostics = { path = "crates/diagnostics" }
|
||||
editor = { path = "crates/editor" }
|
||||
encoding_selector = { path = "crates/encoding_selector" }
|
||||
env_var = { path = "crates/env_var" }
|
||||
etw_tracing = { path = "crates/etw_tracing" }
|
||||
eval_utils = { path = "crates/eval_utils" }
|
||||
extension = { path = "crates/extension" }
|
||||
|
|
@ -465,6 +468,7 @@ worktree = { path = "crates/worktree" }
|
|||
x_ai = { path = "crates/x_ai" }
|
||||
zed = { path = "crates/zed" }
|
||||
zed_actions = { path = "crates/zed_actions" }
|
||||
zed_credentials_provider = { path = "crates/zed_credentials_provider" }
|
||||
zed_env_vars = { path = "crates/zed_env_vars" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
zeta_prompt = { path = "crates/zeta_prompt" }
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
ListDirectoryTool, ListDirectoryToolInput, ReadFileTool, ReadFileToolInput,
|
||||
};
|
||||
use Role::*;
|
||||
use client::{Client, UserStore};
|
||||
use client::{Client, RefreshLlmTokenListener, UserStore};
|
||||
use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind};
|
||||
use fs::FakeFs;
|
||||
use futures::{FutureExt, future::LocalBoxFuture};
|
||||
|
|
@ -1423,7 +1423,8 @@ impl EditAgentTest {
|
|||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
settings::init(cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ use acp_thread::{
|
|||
use agent_client_protocol::{self as acp};
|
||||
use agent_settings::AgentProfileId;
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use client::{Client, RefreshLlmTokenListener, UserStore};
|
||||
use collections::IndexMap;
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
|
|
@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
let clock = Arc::new(clock::FakeSystemClock::new());
|
||||
let client = Client::new(clock, http_client, cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
|
|
@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
cx.set_http_client(Arc::new(http_client));
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
};
|
||||
use Role::*;
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::{Client, UserStore};
|
||||
use client::{Client, RefreshLlmTokenListener, UserStore};
|
||||
use fs::FakeFs;
|
||||
use futures::{FutureExt, StreamExt, future::LocalBoxFuture};
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _};
|
||||
|
|
@ -274,7 +274,8 @@ impl StreamingEditToolTest {
|
|||
cx.set_http_client(http_client);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
language_models::init(user_store, client, cx);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ futures.workspace = true
|
|||
gpui.workspace = true
|
||||
feature_flags.workspace = true
|
||||
gpui_tokio = { workspace = true, optional = true }
|
||||
credentials_provider.workspace = true
|
||||
google_ai.workspace = true
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
|
|
@ -53,6 +52,7 @@ terminal.workspace = true
|
|||
uuid.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
zed_credentials_provider.workspace = true
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
libc.workspace = true
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ use acp_thread::AgentConnection;
|
|||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashSet;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use fs::Fs;
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use language_model::{ApiKey, EnvVar};
|
||||
|
|
@ -392,7 +391,7 @@ fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
|
|||
if let Some(key) = env_var.value {
|
||||
return Task::ready(Ok(key));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let credentials_provider = zed_credentials_provider::global(cx);
|
||||
let api_url = google_ai::API_URL.to_string();
|
||||
cx.spawn(async move |cx| {
|
||||
Ok(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use crate::{AgentServer, AgentServerDelegate};
|
||||
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
|
||||
use agent_client_protocol as acp;
|
||||
use client::RefreshLlmTokenListener;
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::AppContext;
|
||||
use gpui::{Entity, TestAppContext};
|
||||
|
|
@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
|||
cx.set_http_client(Arc::new(http_client));
|
||||
let client = client::Client::production(cx);
|
||||
let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store, client, cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store, cx);
|
||||
|
||||
#[cfg(test)]
|
||||
project::agent_server_store::AllAgentServersSettings::override_global(
|
||||
|
|
|
|||
|
|
@ -815,7 +815,7 @@ mod tests {
|
|||
cx.set_global(store);
|
||||
theme_settings::init(theme::LoadThemes::JustBase, cx);
|
||||
|
||||
language_model::init_settings(cx);
|
||||
language_model::init(cx);
|
||||
editor::init(cx);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1809,7 +1809,7 @@ mod tests {
|
|||
cx.set_global(settings_store);
|
||||
prompt_store::init(cx);
|
||||
theme_settings::init(theme::LoadThemes::JustBase, cx);
|
||||
language_model::init_settings(cx);
|
||||
language_model::init(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
|
@ -1966,7 +1966,7 @@ mod tests {
|
|||
cx.set_global(settings_store);
|
||||
prompt_store::init(cx);
|
||||
theme_settings::init(theme::LoadThemes::JustBase, cx);
|
||||
language_model::init_settings(cx);
|
||||
language_model::init(cx);
|
||||
workspace::register_project_item::<Editor>(cx);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -2025,7 +2025,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
|||
pub mod evals {
|
||||
use crate::InlineAssistant;
|
||||
use agent::ThreadStore;
|
||||
use client::{Client, UserStore};
|
||||
use client::{Client, RefreshLlmTokenListener, UserStore};
|
||||
use editor::{Editor, MultiBuffer, MultiBufferOffset};
|
||||
use eval_utils::{EvalOutput, NoProcessor};
|
||||
use fs::FakeFs;
|
||||
|
|
@ -2091,7 +2091,8 @@ pub mod evals {
|
|||
client::init(&client, cx);
|
||||
workspace::init(app_state.clone(), cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
|
||||
cx.set_global(inline_assistant);
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ base64.workspace = true
|
|||
chrono = { workspace = true, features = ["serde"] }
|
||||
clock.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
|
|
@ -35,6 +36,7 @@ gpui_tokio.workspace = true
|
|||
http_client.workspace = true
|
||||
http_client_tls.workspace = true
|
||||
httparse = "1.10"
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
|
|
@ -60,6 +62,7 @@ tokio.workspace = true
|
|||
url.workspace = true
|
||||
util.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_credentials_provider.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
mod llm_token;
|
||||
mod proxy;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
|
|
@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{
|
|||
http::{HeaderValue, Request, StatusCode},
|
||||
};
|
||||
use clock::SystemClock;
|
||||
use cloud_api_client::CloudApiClient;
|
||||
use cloud_api_client::websocket_protocol::MessageToClient;
|
||||
use cloud_api_client::{ClientApiError, CloudApiClient};
|
||||
use cloud_api_types::OrganizationId;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
use futures::{
|
||||
|
|
@ -24,6 +26,7 @@ use futures::{
|
|||
};
|
||||
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
|
||||
use language_model::LlmApiToken;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postage::watch;
|
||||
use proxy::connect_proxy_stream;
|
||||
|
|
@ -51,6 +54,7 @@ use tokio::net::TcpStream;
|
|||
use url::Url;
|
||||
use util::{ConnectionResult, ResultExt};
|
||||
|
||||
pub use llm_token::*;
|
||||
pub use rpc::*;
|
||||
pub use telemetry_events::Event;
|
||||
pub use user::*;
|
||||
|
|
@ -339,7 +343,7 @@ pub struct ClientCredentialsProvider {
|
|||
impl ClientCredentialsProvider {
|
||||
pub fn new(cx: &App) -> Self {
|
||||
Self {
|
||||
provider: <dyn CredentialsProvider>::global(cx),
|
||||
provider: zed_credentials_provider::global(cx),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -568,6 +572,10 @@ impl Client {
|
|||
self.http.clone()
|
||||
}
|
||||
|
||||
pub fn credentials_provider(&self) -> Arc<dyn CredentialsProvider> {
|
||||
self.credentials_provider.provider.clone()
|
||||
}
|
||||
|
||||
pub fn cloud_client(&self) -> Arc<CloudApiClient> {
|
||||
self.cloud_client.clone()
|
||||
}
|
||||
|
|
@ -1513,6 +1521,66 @@ impl Client {
|
|||
})
|
||||
}
|
||||
|
||||
pub async fn acquire_llm_token(
|
||||
&self,
|
||||
llm_token: &LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
let system_id = self.telemetry().system_id().map(|x| x.to_string());
|
||||
let cloud_client = self.cloud_client();
|
||||
match llm_token
|
||||
.acquire(&cloud_client, system_id, organization_id)
|
||||
.await
|
||||
{
|
||||
Ok(token) => Ok(token),
|
||||
Err(ClientApiError::Unauthorized) => {
|
||||
self.request_sign_out();
|
||||
Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
|
||||
}
|
||||
Err(err) => Err(anyhow::Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh_llm_token(
|
||||
&self,
|
||||
llm_token: &LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
let system_id = self.telemetry().system_id().map(|x| x.to_string());
|
||||
let cloud_client = self.cloud_client();
|
||||
match llm_token
|
||||
.refresh(&cloud_client, system_id, organization_id)
|
||||
.await
|
||||
{
|
||||
Ok(token) => Ok(token),
|
||||
Err(ClientApiError::Unauthorized) => {
|
||||
self.request_sign_out();
|
||||
return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
|
||||
}
|
||||
Err(err) => return Err(anyhow::Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn clear_and_refresh_llm_token(
|
||||
&self,
|
||||
llm_token: &LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
let system_id = self.telemetry().system_id().map(|x| x.to_string());
|
||||
let cloud_client = self.cloud_client();
|
||||
match llm_token
|
||||
.clear_and_refresh(&cloud_client, system_id, organization_id)
|
||||
.await
|
||||
{
|
||||
Ok(token) => Ok(token),
|
||||
Err(ClientApiError::Unauthorized) => {
|
||||
self.request_sign_out();
|
||||
return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
|
||||
}
|
||||
Err(err) => return Err(anyhow::Error::from(err)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
|
||||
self.state.write().credentials = None;
|
||||
self.cloud_client.clear_credentials();
|
||||
|
|
|
|||
116
crates/client/src/llm_token.rs
Normal file
116
crates/client/src/llm_token.rs
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
use super::{Client, UserStore};
|
||||
use cloud_api_types::websocket_protocol::MessageToClient;
|
||||
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
|
||||
use gpui::{
|
||||
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
|
||||
};
|
||||
use language_model::LlmApiToken;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait NeedsLlmTokenRefresh {
|
||||
/// Returns whether the LLM token needs to be refreshed.
|
||||
fn needs_llm_token_refresh(&self) -> bool;
|
||||
}
|
||||
|
||||
impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
|
||||
fn needs_llm_token_refresh(&self) -> bool {
|
||||
self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
|
||||
|| self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
|
||||
}
|
||||
}
|
||||
|
||||
enum TokenRefreshMode {
|
||||
Refresh,
|
||||
ClearAndRefresh,
|
||||
}
|
||||
|
||||
pub fn global_llm_token(cx: &App) -> LlmApiToken {
|
||||
RefreshLlmTokenListener::global(cx)
|
||||
.read(cx)
|
||||
.llm_api_token
|
||||
.clone()
|
||||
}
|
||||
|
||||
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
|
||||
|
||||
impl Global for GlobalRefreshLlmTokenListener {}
|
||||
|
||||
pub struct LlmTokenRefreshedEvent;
|
||||
|
||||
pub struct RefreshLlmTokenListener {
|
||||
client: Arc<Client>,
|
||||
user_store: Entity<UserStore>,
|
||||
llm_api_token: LlmApiToken,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
|
||||
|
||||
impl RefreshLlmTokenListener {
|
||||
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
|
||||
cx.set_global(GlobalRefreshLlmTokenListener(listener));
|
||||
}
|
||||
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
GlobalRefreshLlmTokenListener::global(cx).0.clone()
|
||||
}
|
||||
|
||||
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
|
||||
client.add_message_to_client_handler({
|
||||
let this = cx.weak_entity();
|
||||
move |message, cx| {
|
||||
if let Some(this) = this.upgrade() {
|
||||
Self::handle_refresh_llm_token(this, message, cx);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
|
||||
if matches!(event, super::user::Event::OrganizationChanged) {
|
||||
this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
client,
|
||||
user_store,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
_subscription: subscription,
|
||||
}
|
||||
}
|
||||
|
||||
fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let organization_id = self
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|organization| organization.id.clone());
|
||||
cx.spawn(async move |this, cx| {
|
||||
match mode {
|
||||
TokenRefreshMode::Refresh => {
|
||||
client
|
||||
.refresh_llm_token(&llm_api_token, organization_id)
|
||||
.await?;
|
||||
}
|
||||
TokenRefreshMode::ClearAndRefresh => {
|
||||
client
|
||||
.clear_and_refresh_llm_token(&llm_api_token, organization_id)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
|
||||
match message {
|
||||
MessageToClient::UserUpdated => {
|
||||
this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ log.workspace = true
|
|||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
text.workspace = true
|
||||
zed_credentials_provider.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
|||
|
|
@ -48,9 +48,10 @@ pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
|
|||
}
|
||||
|
||||
pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = zed_credentials_provider::global(cx);
|
||||
let api_url = codestral_api_url(cx);
|
||||
codestral_api_key_state(cx).update(cx, |key_state, cx| {
|
||||
key_state.load_if_needed(api_url, |s| s, cx)
|
||||
key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,9 +13,5 @@ path = "src/credentials_provider.rs"
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
paths.workspace = true
|
||||
release_channel.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,26 +1,8 @@
|
|||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::FutureExt as _;
|
||||
use gpui::{App, AsyncApp};
|
||||
use release_channel::ReleaseChannel;
|
||||
|
||||
/// An environment variable whose presence indicates that the system keychain
|
||||
/// should be used in development.
|
||||
///
|
||||
/// By default, running Zed in development uses the development credentials
|
||||
/// provider. Setting this environment variable allows you to interact with the
|
||||
/// system keychain (for instance, if you need to test something).
|
||||
///
|
||||
/// Only works in development. Setting this environment variable in other
|
||||
/// release channels is a no-op.
|
||||
static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
|
||||
std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
|
||||
});
|
||||
use gpui::AsyncApp;
|
||||
|
||||
/// A provider for credentials.
|
||||
///
|
||||
|
|
@ -50,150 +32,3 @@ pub trait CredentialsProvider: Send + Sync {
|
|||
cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
|
||||
}
|
||||
|
||||
impl dyn CredentialsProvider {
|
||||
/// Returns the global [`CredentialsProvider`].
|
||||
pub fn global(cx: &App) -> Arc<Self> {
|
||||
// The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
|
||||
// seems like this is a false positive from Clippy.
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
Self::new(cx)
|
||||
}
|
||||
|
||||
fn new(cx: &App) -> Arc<Self> {
|
||||
let use_development_provider = match ReleaseChannel::try_global(cx) {
|
||||
Some(ReleaseChannel::Dev) => {
|
||||
// In development we default to using the development
|
||||
// credentials provider to avoid getting spammed by relentless
|
||||
// keychain access prompts.
|
||||
//
|
||||
// However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
|
||||
// variable is set, we will use the actual keychain.
|
||||
!*ZED_DEVELOPMENT_USE_KEYCHAIN
|
||||
}
|
||||
Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
|
||||
| None => false,
|
||||
};
|
||||
|
||||
if use_development_provider {
|
||||
Arc::new(DevelopmentCredentialsProvider::new())
|
||||
} else {
|
||||
Arc::new(KeychainCredentialsProvider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A credentials provider that stores credentials in the system keychain.
|
||||
struct KeychainCredentialsProvider;
|
||||
|
||||
impl CredentialsProvider for KeychainCredentialsProvider {
|
||||
fn read_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
|
||||
async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
|
||||
}
|
||||
|
||||
fn write_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
username: &'a str,
|
||||
password: &'a [u8],
|
||||
cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move {
|
||||
cx.update(move |cx| cx.write_credentials(url, username, password))
|
||||
.await
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
||||
fn delete_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
/// A credentials provider that stores credentials in a local file.
|
||||
///
|
||||
/// This MUST only be used in development, as this is not a secure way of storing
|
||||
/// credentials on user machines.
|
||||
///
|
||||
/// Its existence is purely to work around the annoyance of having to constantly
|
||||
/// re-allow access to the system keychain when developing Zed.
|
||||
struct DevelopmentCredentialsProvider {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl DevelopmentCredentialsProvider {
|
||||
fn new() -> Self {
|
||||
let path = paths::config_dir().join("development_credentials");
|
||||
|
||||
Self { path }
|
||||
}
|
||||
|
||||
fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
|
||||
let json = std::fs::read(&self.path)?;
|
||||
let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
|
||||
let json = serde_json::to_string(credentials)?;
|
||||
std::fs::write(&self.path, json)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialsProvider for DevelopmentCredentialsProvider {
|
||||
fn read_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
_cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
|
||||
async move {
|
||||
Ok(self
|
||||
.load_credentials()
|
||||
.unwrap_or_default()
|
||||
.get(url)
|
||||
.cloned())
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
||||
fn write_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
username: &'a str,
|
||||
password: &'a [u8],
|
||||
_cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move {
|
||||
let mut credentials = self.load_credentials().unwrap_or_default();
|
||||
credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
|
||||
|
||||
self.save_credentials(&credentials)
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
||||
fn delete_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
_cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move {
|
||||
let mut credentials = self.load_credentials()?;
|
||||
credentials.remove(url);
|
||||
|
||||
self.save_credentials(&credentials)
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ cloud_llm_client.workspace = true
|
|||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
copilot_ui.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
db.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
|
|
@ -65,6 +66,7 @@ uuid.workspace = true
|
|||
workspace.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zed_credentials_provider.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
zstd.workspace = true
|
||||
|
||||
|
|
|
|||
|
|
@ -258,6 +258,7 @@ fn generate_timestamp_name() -> String {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::EditPredictionStore;
|
||||
use client::RefreshLlmTokenListener;
|
||||
use client::{Client, UserStore};
|
||||
use clock::FakeSystemClock;
|
||||
use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
|
||||
|
|
@ -548,7 +549,8 @@ mod tests {
|
|||
let http_client = FakeHttpClient::with_404_response();
|
||||
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
EditPredictionStore::global(&client, &user_store, cx);
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use anyhow::Result;
|
||||
use client::{Client, EditPredictionUsage, UserStore};
|
||||
use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
|
||||
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
|
||||
|
|
@ -11,6 +11,7 @@ use cloud_llm_client::{
|
|||
};
|
||||
use collections::{HashMap, HashSet};
|
||||
use copilot::{Copilot, Reinstall, SignIn, SignOut};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use db::kvp::{Dismissable, KeyValueStore};
|
||||
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
|
||||
|
|
@ -30,7 +31,7 @@ use heapless::Vec as ArrayVec;
|
|||
use language::language_settings::all_language_settings;
|
||||
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
|
||||
use language::{BufferSnapshot, OffsetRangeExt};
|
||||
use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
|
||||
use language_model::LlmApiToken;
|
||||
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
|
||||
use release_channel::AppVersion;
|
||||
use semver::Version;
|
||||
|
|
@ -150,6 +151,7 @@ pub struct EditPredictionStore {
|
|||
rated_predictions: HashSet<EditPredictionId>,
|
||||
#[cfg(test)]
|
||||
settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
pub(crate) struct EditPredictionRejectionPayload {
|
||||
|
|
@ -746,7 +748,7 @@ impl EditPredictionStore {
|
|||
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
|
||||
let data_collection_choice = Self::load_data_collection_choice(cx);
|
||||
|
||||
let llm_token = LlmApiToken::global(cx);
|
||||
let llm_token = global_llm_token(cx);
|
||||
|
||||
let (reject_tx, reject_rx) = mpsc::unbounded();
|
||||
cx.background_spawn({
|
||||
|
|
@ -787,6 +789,8 @@ impl EditPredictionStore {
|
|||
.log_err();
|
||||
});
|
||||
|
||||
let credentials_provider = zed_credentials_provider::global(cx);
|
||||
|
||||
let this = Self {
|
||||
projects: HashMap::default(),
|
||||
client,
|
||||
|
|
@ -807,6 +811,8 @@ impl EditPredictionStore {
|
|||
shown_predictions: Default::default(),
|
||||
#[cfg(test)]
|
||||
settled_event_callback: None,
|
||||
|
||||
credentials_provider,
|
||||
};
|
||||
|
||||
this
|
||||
|
|
@ -871,7 +877,9 @@ impl EditPredictionStore {
|
|||
let experiments = cx
|
||||
.background_spawn(async move {
|
||||
let http_client = client.http_client();
|
||||
let token = llm_token.acquire(&client, organization_id).await?;
|
||||
let token = client
|
||||
.acquire_llm_token(&llm_token, organization_id.clone())
|
||||
.await?;
|
||||
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
|
|
@ -2315,7 +2323,10 @@ impl EditPredictionStore {
|
|||
zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
|
||||
}
|
||||
EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
|
||||
EditPredictionModel::Mercury => {
|
||||
self.mercury
|
||||
.request_prediction(inputs, self.credentials_provider.clone(), cx)
|
||||
}
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
|
|
@ -2536,12 +2547,15 @@ impl EditPredictionStore {
|
|||
Res: DeserializeOwned,
|
||||
{
|
||||
let http_client = client.http_client();
|
||||
|
||||
let mut token = if require_auth {
|
||||
Some(llm_token.acquire(&client, organization_id.clone()).await?)
|
||||
Some(
|
||||
client
|
||||
.acquire_llm_token(&llm_token, organization_id.clone())
|
||||
.await?,
|
||||
)
|
||||
} else {
|
||||
llm_token
|
||||
.acquire(&client, organization_id.clone())
|
||||
client
|
||||
.acquire_llm_token(&llm_token, organization_id.clone())
|
||||
.await
|
||||
.ok()
|
||||
};
|
||||
|
|
@ -2585,7 +2599,11 @@ impl EditPredictionStore {
|
|||
return Ok((serde_json::from_slice(&body)?, usage));
|
||||
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
|
||||
did_retry = true;
|
||||
token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
|
||||
token = Some(
|
||||
client
|
||||
.refresh_llm_token(&llm_token, organization_id.clone())
|
||||
.await?,
|
||||
);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use super::*;
|
||||
use crate::udiff::apply_diff_to_string;
|
||||
use client::{UserStore, test::FakeServer};
|
||||
use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
|
||||
use clock::FakeSystemClock;
|
||||
use clock::ReplicaId;
|
||||
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
|
||||
|
|
@ -23,7 +23,7 @@ use language::{
|
|||
Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
|
||||
DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
|
||||
};
|
||||
use language_model::RefreshLlmTokenListener;
|
||||
|
||||
use lsp::LanguageServerId;
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::{assert_eq, assert_matches};
|
||||
|
|
@ -2439,7 +2439,8 @@ fn init_test_with_fake_client(
|
|||
client.cloud_client().set_credentials(1, "test".into());
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
|
||||
|
||||
(
|
||||
|
|
@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
|
|||
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
|
||||
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
|
||||
cx.update(|cx| {
|
||||
language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
});
|
||||
|
||||
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use crate::{
|
|||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::EditPredictionRejectReason;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{
|
||||
App, AppContext as _, Context, Entity, Global, SharedString, Task,
|
||||
|
|
@ -51,10 +52,11 @@ impl Mercury {
|
|||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
self.api_token.update(cx, |key_state, cx| {
|
||||
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
|
||||
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
|
||||
});
|
||||
let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
|
||||
return Task::ready(Ok(None));
|
||||
|
|
@ -387,8 +389,9 @@ pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
|
|||
}
|
||||
|
||||
pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
|
||||
let credentials_provider = zed_credentials_provider::global(cx);
|
||||
mercury_api_token(cx).update(cx, |key_state, cx| {
|
||||
key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
|
||||
key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,9 +42,10 @@ pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
|
|||
pub fn load_open_ai_compatible_api_token(
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(), language_model::AuthenticateError>> {
|
||||
let credentials_provider = zed_credentials_provider::global(cx);
|
||||
let api_url = open_ai_compatible_api_url(cx);
|
||||
open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
|
||||
key_state.load_if_needed(api_url, |s| s, cx)
|
||||
key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use client::{Client, ProxySettings, UserStore};
|
||||
use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
|
||||
use db::AppDatabase;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::RealFs;
|
||||
|
|
@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState {
|
|||
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
|
||||
prompt_store::init(cx);
|
||||
|
|
|
|||
15
crates/env_var/Cargo.toml
Normal file
15
crates/env_var/Cargo.toml
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
[package]
|
||||
name = "env_var"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/env_var.rs"
|
||||
|
||||
[dependencies]
|
||||
gpui.workspace = true
|
||||
1
crates/env_var/LICENSE-GPL
Symbolic link
1
crates/env_var/LICENSE-GPL
Symbolic link
|
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
||||
40
crates/env_var/src/env_var.rs
Normal file
40
crates/env_var/src/env_var.rs
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
use gpui::SharedString;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EnvVar {
|
||||
pub name: SharedString,
|
||||
/// Value of the environment variable. Also `None` when set to an empty string.
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
impl EnvVar {
|
||||
pub fn new(name: SharedString) -> Self {
|
||||
let value = std::env::var(name.as_str()).ok();
|
||||
if value.as_ref().is_some_and(|v| v.is_empty()) {
|
||||
Self { name, value: None }
|
||||
} else {
|
||||
Self { name, value }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn or(self, other: EnvVar) -> EnvVar {
|
||||
if self.value.is_some() { self } else { other }
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
|
||||
#[macro_export]
|
||||
macro_rules! env_var {
|
||||
($name:expr) => {
|
||||
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
|
||||
};
|
||||
}
|
||||
|
||||
/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
|
||||
/// environment variable exists and is non-empty.
|
||||
#[macro_export]
|
||||
macro_rules! bool_env_var {
|
||||
($name:expr) => {
|
||||
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
|
||||
};
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::{Client, ProxySettings, UserStore};
|
||||
use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
|
||||
use db::AppDatabase;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::RealFs;
|
||||
|
|
@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
|
|||
let extension_host_proxy = ExtensionHostProxy::global(cx);
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
|
||||
prompt_store::init(cx);
|
||||
|
|
|
|||
|
|
@ -20,11 +20,11 @@ anthropic = { workspace = true, features = ["schemars"] }
|
|||
anyhow.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
base64.workspace = true
|
||||
client.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
env_var.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
|
|
@ -40,7 +40,6 @@ serde_json.workspace = true
|
|||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
zed_env_vars.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use env_var::EnvVar;
|
||||
use futures::{FutureExt, future};
|
||||
use gpui::{AsyncApp, Context, SharedString, Task};
|
||||
use std::{
|
||||
|
|
@ -7,7 +8,6 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use zed_env_vars::EnvVar;
|
||||
|
||||
use crate::AuthenticateError;
|
||||
|
||||
|
|
@ -101,6 +101,7 @@ impl ApiKeyState {
|
|||
url: SharedString,
|
||||
key: Option<String>,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
|
||||
provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &Context<Ent>,
|
||||
) -> Task<Result<()>> {
|
||||
if self.is_from_env_var() {
|
||||
|
|
@ -108,18 +109,14 @@ impl ApiKeyState {
|
|||
"bug: attempted to store API key in system keychain when API key is from env var",
|
||||
)));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn(async move |ent, cx| {
|
||||
if let Some(key) = &key {
|
||||
credentials_provider
|
||||
provider
|
||||
.write_credentials(&url, "Bearer", key.as_bytes(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
} else {
|
||||
credentials_provider
|
||||
.delete_credentials(&url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
provider.delete_credentials(&url, cx).await.log_err();
|
||||
}
|
||||
ent.update(cx, |ent, cx| {
|
||||
let this = get_this(ent);
|
||||
|
|
@ -144,12 +141,13 @@ impl ApiKeyState {
|
|||
&mut self,
|
||||
url: SharedString,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
|
||||
provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut Context<Ent>,
|
||||
) {
|
||||
if url != self.url {
|
||||
if !self.is_from_env_var() {
|
||||
// loading will continue even though this result task is dropped
|
||||
let _task = self.load_if_needed(url, get_this, cx);
|
||||
let _task = self.load_if_needed(url, get_this, provider, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -163,6 +161,7 @@ impl ApiKeyState {
|
|||
&mut self,
|
||||
url: SharedString,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
|
||||
provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut Context<Ent>,
|
||||
) -> Task<Result<(), AuthenticateError>> {
|
||||
if let LoadStatus::Loaded { .. } = &self.load_status
|
||||
|
|
@ -185,7 +184,7 @@ impl ApiKeyState {
|
|||
let task = if let Some(load_task) = &self.load_task {
|
||||
load_task.clone()
|
||||
} else {
|
||||
let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
|
||||
let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared();
|
||||
self.url = url;
|
||||
self.load_status = LoadStatus::NotPresent;
|
||||
self.load_task = Some(load_task.clone());
|
||||
|
|
@ -206,14 +205,13 @@ impl ApiKeyState {
|
|||
fn load<Ent: 'static>(
|
||||
url: SharedString,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
|
||||
provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &Context<Ent>,
|
||||
) -> Task<()> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn({
|
||||
async move |ent, cx| {
|
||||
let load_status =
|
||||
ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
|
||||
.await;
|
||||
ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await;
|
||||
ent.update(cx, |ent, cx| {
|
||||
let this = get_this(ent);
|
||||
this.url = url;
|
||||
|
|
|
|||
|
|
@ -11,12 +11,10 @@ pub mod tool_schema;
|
|||
pub mod fake_provider;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::Client;
|
||||
use client::UserStore;
|
||||
use cloud_llm_client::CompletionRequestStatus;
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
|
||||
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
use http_client::{StatusCode, http};
|
||||
use icons::IconName;
|
||||
use parking_lot::Mutex;
|
||||
|
|
@ -36,15 +34,10 @@ pub use crate::registry::*;
|
|||
pub use crate::request::*;
|
||||
pub use crate::role::*;
|
||||
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
|
||||
pub use env_var::{EnvVar, env_var};
|
||||
pub use provider::*;
|
||||
pub use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
||||
init_settings(cx);
|
||||
RefreshLlmTokenListener::register(client, user_store, cx);
|
||||
}
|
||||
|
||||
pub fn init_settings(cx: &mut App) {
|
||||
pub fn init(cx: &mut App) {
|
||||
registry::init(cx);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,9 @@
|
|||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::Client;
|
||||
use client::UserStore;
|
||||
use cloud_api_client::ClientApiError;
|
||||
use cloud_api_client::CloudApiClient;
|
||||
use cloud_api_types::OrganizationId;
|
||||
use cloud_api_types::websocket_protocol::MessageToClient;
|
||||
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
|
||||
use gpui::{
|
||||
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
|
||||
};
|
||||
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
||||
use thiserror::Error;
|
||||
|
||||
|
|
@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError {
|
|||
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
impl LlmApiToken {
|
||||
pub fn global(cx: &App) -> Self {
|
||||
RefreshLlmTokenListener::global(cx)
|
||||
.read(cx)
|
||||
.llm_api_token
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub async fn acquire(
|
||||
&self,
|
||||
client: &Arc<Client>,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
) -> Result<String, ClientApiError> {
|
||||
let lock = self.0.upgradable_read().await;
|
||||
if let Some(token) = lock.as_ref() {
|
||||
Ok(token.to_string())
|
||||
|
|
@ -49,6 +36,7 @@ impl LlmApiToken {
|
|||
Self::fetch(
|
||||
RwLockUpgradableReadGuard::upgrade(lock).await,
|
||||
client,
|
||||
system_id,
|
||||
organization_id,
|
||||
)
|
||||
.await
|
||||
|
|
@ -57,10 +45,11 @@ impl LlmApiToken {
|
|||
|
||||
pub async fn refresh(
|
||||
&self,
|
||||
client: &Arc<Client>,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
Self::fetch(self.0.write().await, client, organization_id).await
|
||||
) -> Result<String, ClientApiError> {
|
||||
Self::fetch(self.0.write().await, client, system_id, organization_id).await
|
||||
}
|
||||
|
||||
/// Clears the existing token before attempting to fetch a new one.
|
||||
|
|
@ -69,28 +58,22 @@ impl LlmApiToken {
|
|||
/// leave a token for the wrong organization.
|
||||
pub async fn clear_and_refresh(
|
||||
&self,
|
||||
client: &Arc<Client>,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
) -> Result<String, ClientApiError> {
|
||||
let mut lock = self.0.write().await;
|
||||
*lock = None;
|
||||
Self::fetch(lock, client, organization_id).await
|
||||
Self::fetch(lock, client, system_id, organization_id).await
|
||||
}
|
||||
|
||||
async fn fetch(
|
||||
mut lock: RwLockWriteGuard<'_, Option<String>>,
|
||||
client: &Arc<Client>,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
let system_id = client
|
||||
.telemetry()
|
||||
.system_id()
|
||||
.map(|system_id| system_id.to_string());
|
||||
|
||||
let result = client
|
||||
.cloud_client()
|
||||
.create_llm_token(system_id, organization_id)
|
||||
.await;
|
||||
) -> Result<String, ClientApiError> {
|
||||
let result = client.create_llm_token(system_id, organization_id).await;
|
||||
match result {
|
||||
Ok(response) => {
|
||||
*lock = Some(response.token.0.clone());
|
||||
|
|
@ -98,112 +81,7 @@ impl LlmApiToken {
|
|||
}
|
||||
Err(err) => {
|
||||
*lock = None;
|
||||
match err {
|
||||
ClientApiError::Unauthorized => {
|
||||
client.request_sign_out();
|
||||
Err(err).context("Failed to create LLM token")
|
||||
}
|
||||
ClientApiError::Other(err) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NeedsLlmTokenRefresh {
|
||||
/// Returns whether the LLM token needs to be refreshed.
|
||||
fn needs_llm_token_refresh(&self) -> bool;
|
||||
}
|
||||
|
||||
impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
|
||||
fn needs_llm_token_refresh(&self) -> bool {
|
||||
self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
|
||||
|| self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
|
||||
}
|
||||
}
|
||||
|
||||
enum TokenRefreshMode {
|
||||
Refresh,
|
||||
ClearAndRefresh,
|
||||
}
|
||||
|
||||
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
|
||||
|
||||
impl Global for GlobalRefreshLlmTokenListener {}
|
||||
|
||||
pub struct LlmTokenRefreshedEvent;
|
||||
|
||||
pub struct RefreshLlmTokenListener {
|
||||
client: Arc<Client>,
|
||||
user_store: Entity<UserStore>,
|
||||
llm_api_token: LlmApiToken,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
|
||||
|
||||
impl RefreshLlmTokenListener {
|
||||
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
|
||||
cx.set_global(GlobalRefreshLlmTokenListener(listener));
|
||||
}
|
||||
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
GlobalRefreshLlmTokenListener::global(cx).0.clone()
|
||||
}
|
||||
|
||||
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
|
||||
client.add_message_to_client_handler({
|
||||
let this = cx.weak_entity();
|
||||
move |message, cx| {
|
||||
if let Some(this) = this.upgrade() {
|
||||
Self::handle_refresh_llm_token(this, message, cx);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
|
||||
if matches!(event, client::user::Event::OrganizationChanged) {
|
||||
this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
client,
|
||||
user_store,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
_subscription: subscription,
|
||||
}
|
||||
}
|
||||
|
||||
fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let organization_id = self
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|organization| organization.id.clone());
|
||||
cx.spawn(async move |this, cx| {
|
||||
match mode {
|
||||
TokenRefreshMode::Refresh => {
|
||||
llm_api_token.refresh(&client, organization_id).await?;
|
||||
}
|
||||
TokenRefreshMode::ClearAndRefresh => {
|
||||
llm_api_token
|
||||
.clear_and_refresh(&client, organization_id)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
|
||||
match message {
|
||||
MessageToClient::UserUpdated => {
|
||||
this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||
use ::settings::{Settings, SettingsStore};
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashSet;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::{App, Context, Entity};
|
||||
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
|
||||
use provider::deepseek::DeepSeekLanguageModelProvider;
|
||||
|
|
@ -31,9 +32,16 @@ use crate::provider::x_ai::XAiLanguageModelProvider;
|
|||
pub use crate::settings::*;
|
||||
|
||||
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
||||
let credentials_provider = client.credentials_provider();
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_language_model_providers(registry, user_store, client.clone(), cx);
|
||||
register_language_model_providers(
|
||||
registry,
|
||||
user_store,
|
||||
client.clone(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Subscribe to extension store events to track LLM extension installations
|
||||
|
|
@ -104,6 +112,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
|||
&HashSet::default(),
|
||||
&openai_compatible_providers,
|
||||
client.clone(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
|
@ -124,6 +133,7 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
|||
&openai_compatible_providers,
|
||||
&openai_compatible_providers_new,
|
||||
client.clone(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
|
@ -138,6 +148,7 @@ fn register_openai_compatible_providers(
|
|||
old: &HashSet<Arc<str>>,
|
||||
new: &HashSet<Arc<str>>,
|
||||
client: Arc<Client>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut Context<LanguageModelRegistry>,
|
||||
) {
|
||||
for provider_id in old {
|
||||
|
|
@ -152,6 +163,7 @@ fn register_openai_compatible_providers(
|
|||
Arc::new(OpenAiCompatibleLanguageModelProvider::new(
|
||||
provider_id.clone(),
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
|
|
@ -164,6 +176,7 @@ fn register_language_model_providers(
|
|||
registry: &mut LanguageModelRegistry,
|
||||
user_store: Entity<UserStore>,
|
||||
client: Arc<Client>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut Context<LanguageModelRegistry>,
|
||||
) {
|
||||
registry.register_provider(
|
||||
|
|
@ -177,62 +190,105 @@ fn register_language_model_providers(
|
|||
registry.register_provider(
|
||||
Arc::new(AnthropicLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(OpenAiLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(OllamaLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(LmStudioLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(DeepSeekLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(GoogleLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
MistralLanguageModelProvider::global(client.http_client(), cx),
|
||||
MistralLanguageModelProvider::global(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(BedrockLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(OpenRouterLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(VercelLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(VercelAiGatewayLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(XAiLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
|
||||
Arc::new(OpenCodeLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider,
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ use anthropic::{
|
|||
};
|
||||
use anyhow::Result;
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
|
||||
use http_client::HttpClient;
|
||||
|
|
@ -51,6 +52,7 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -59,30 +61,51 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = AnthropicLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = AnthropicLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -195,12 +195,13 @@ pub struct State {
|
|||
settings: Option<AmazonBedrockSettings>,
|
||||
/// Whether credentials came from environment variables (only relevant for static credentials)
|
||||
credentials_from_env: bool,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn reset_auth(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(AMAZON_AWS_URL, cx)
|
||||
|
|
@ -220,7 +221,7 @@ impl State {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let auth = credentials.clone().into_auth();
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(
|
||||
|
|
@ -287,7 +288,7 @@ impl State {
|
|||
&self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
// Try environment variables first
|
||||
let (auth, from_env) = if let Some(bearer_token) = &ZED_BEDROCK_BEARER_TOKEN_VAR.value {
|
||||
|
|
@ -400,11 +401,16 @@ pub struct BedrockLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl BedrockLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
auth: None,
|
||||
settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
|
||||
credentials_from_env: false,
|
||||
credentials_provider,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
use ai_onboarding::YoungAccountBanner;
|
||||
use anthropic::AnthropicModelMode;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use client::{
|
||||
Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls,
|
||||
};
|
||||
use cloud_api_types::{OrganizationId, Plan};
|
||||
use cloud_llm_client::{
|
||||
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
|
||||
|
|
@ -24,10 +26,9 @@ use language_model::{
|
|||
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
|
||||
OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
|
||||
RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
|
||||
ZED_CLOUD_PROVIDER_NAME,
|
||||
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
|
||||
OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
|
||||
ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
|
||||
};
|
||||
use release_channel::AppVersion;
|
||||
use schemars::JsonSchema;
|
||||
|
|
@ -111,7 +112,7 @@ impl State {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||
let llm_api_token = LlmApiToken::global(cx);
|
||||
let llm_api_token = global_llm_token(cx);
|
||||
Self {
|
||||
client: client.clone(),
|
||||
llm_api_token,
|
||||
|
|
@ -226,7 +227,9 @@ impl State {
|
|||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<ListModelsResponse> {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client, organization_id).await?;
|
||||
let token = client
|
||||
.acquire_llm_token(&llm_api_token, organization_id)
|
||||
.await?;
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
|
|
@ -414,8 +417,8 @@ impl CloudLanguageModel {
|
|||
) -> Result<PerformLlmCompletionResponse> {
|
||||
let http_client = &client.http_client();
|
||||
|
||||
let mut token = llm_api_token
|
||||
.acquire(&client, organization_id.clone())
|
||||
let mut token = client
|
||||
.acquire_llm_token(&llm_api_token, organization_id.clone())
|
||||
.await?;
|
||||
let mut refreshed_token = false;
|
||||
|
||||
|
|
@ -447,8 +450,8 @@ impl CloudLanguageModel {
|
|||
}
|
||||
|
||||
if !refreshed_token && response.needs_llm_token_refresh() {
|
||||
token = llm_api_token
|
||||
.refresh(&client, organization_id.clone())
|
||||
token = client
|
||||
.refresh_llm_token(&llm_api_token, organization_id.clone())
|
||||
.await?;
|
||||
refreshed_token = true;
|
||||
continue;
|
||||
|
|
@ -713,7 +716,9 @@ impl LanguageModel for CloudLanguageModel {
|
|||
into_google(request, model_id.clone(), GoogleModelMode::Default);
|
||||
async move {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client, organization_id).await?;
|
||||
let token = client
|
||||
.acquire_llm_token(&llm_api_token, organization_id)
|
||||
.await?;
|
||||
|
||||
let request_body = CountTokensBody {
|
||||
provider: cloud_llm_client::LanguageModelProvider::Google,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use deepseek::DEEPSEEK_API_URL;
|
||||
|
||||
use futures::Stream;
|
||||
|
|
@ -49,6 +50,7 @@ pub struct DeepSeekLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -57,30 +59,51 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl DeepSeekLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::{Context as _, Result};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
||||
use google_ai::{
|
||||
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
|
||||
|
|
@ -60,6 +61,7 @@ pub struct GoogleLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
|
||||
|
|
@ -76,30 +78,51 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = GoogleLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = GoogleLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl GoogleLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use fs::Fs;
|
||||
use futures::Stream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
|
|
@ -52,6 +53,7 @@ pub struct LmStudioLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<lmstudio::Model>,
|
||||
fetch_model_task: Option<Task<Result<()>>>,
|
||||
|
|
@ -64,10 +66,15 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
|
||||
let task = self
|
||||
.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx);
|
||||
let task = self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
self.restart_fetch_models_task(cx);
|
||||
task
|
||||
}
|
||||
|
|
@ -114,10 +121,14 @@ impl State {
|
|||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
|
||||
let _task = self
|
||||
.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
|
||||
let _task = self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
|
|
@ -152,16 +163,29 @@ impl State {
|
|||
}
|
||||
|
||||
impl LmStudioLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let this = Self {
|
||||
http_client: http_client.clone(),
|
||||
state: cx.new(|cx| {
|
||||
let subscription = cx.observe_global::<SettingsStore>({
|
||||
let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
|
||||
move |this: &mut State, cx| {
|
||||
let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
|
||||
if &settings != new_settings {
|
||||
settings = new_settings.clone();
|
||||
let new_settings =
|
||||
AllLanguageModelSettings::get_global(cx).lmstudio.clone();
|
||||
if settings != new_settings {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx).into();
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
settings = new_settings;
|
||||
this.restart_fetch_models_task(cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
|
@ -173,6 +197,7 @@ impl LmStudioLanguageModelProvider {
|
|||
Self::api_url(cx).into(),
|
||||
(*API_KEY_ENV_VAR).clone(),
|
||||
),
|
||||
credentials_provider,
|
||||
http_client,
|
||||
available_models: Default::default(),
|
||||
fetch_model_task: None,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
|
||||
|
|
@ -43,6 +44,7 @@ pub struct MistralLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -51,15 +53,26 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = MistralLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = MistralLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -73,20 +86,30 @@ impl MistralLanguageModelProvider {
|
|||
.map(|this| &this.0)
|
||||
}
|
||||
|
||||
pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
|
||||
pub fn global(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Arc<Self> {
|
||||
if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
|
||||
return this.0.clone();
|
||||
}
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use fs::Fs;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{Stream, TryFutureExt, stream};
|
||||
|
|
@ -54,6 +55,7 @@ pub struct OllamaLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
fetched_models: Vec<ollama::Model>,
|
||||
fetch_model_task: Option<Task<Result<()>>>,
|
||||
|
|
@ -65,10 +67,15 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
let task = self
|
||||
.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx);
|
||||
let task = self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
|
||||
self.fetched_models.clear();
|
||||
cx.spawn(async move |this, cx| {
|
||||
|
|
@ -80,10 +87,14 @@ impl State {
|
|||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
let task = self
|
||||
.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
|
||||
let task = self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
|
||||
// Always try to fetch models - if no API key is needed (local Ollama), it will work
|
||||
// If API key is needed and provided, it will work
|
||||
|
|
@ -157,7 +168,11 @@ impl State {
|
|||
}
|
||||
|
||||
impl OllamaLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let this = Self {
|
||||
http_client: http_client.clone(),
|
||||
state: cx.new(|cx| {
|
||||
|
|
@ -170,6 +185,14 @@ impl OllamaLanguageModelProvider {
|
|||
let url_changed = last_settings.api_url != current_settings.api_url;
|
||||
last_settings = current_settings.clone();
|
||||
if url_changed {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
this.fetched_models.clear();
|
||||
this.authenticate(cx).detach();
|
||||
}
|
||||
|
|
@ -184,6 +207,7 @@ impl OllamaLanguageModelProvider {
|
|||
fetched_models: Default::default(),
|
||||
fetch_model_task: None,
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::Stream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
|
|
@ -55,6 +56,7 @@ pub struct OpenAiLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -63,30 +65,51 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OpenAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OpenAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAiLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use convert_case::{Case, Casing};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
|
|
@ -44,6 +45,7 @@ pub struct State {
|
|||
id: Arc<str>,
|
||||
api_key_state: ApiKeyState,
|
||||
settings: OpenAiCompatibleSettings,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -52,20 +54,36 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = SharedString::new(self.settings.api_url.as_str());
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = SharedString::new(self.settings.api_url.clone());
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAiCompatibleLanguageModelProvider {
|
||||
pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
id: Arc<str>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
|
||||
crate::AllLanguageModelSettings::get_global(cx)
|
||||
.openai_compatible
|
||||
|
|
@ -79,10 +97,12 @@ impl OpenAiCompatibleLanguageModelProvider {
|
|||
return;
|
||||
};
|
||||
if &this.settings != &settings {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = SharedString::new(settings.api_url.as_str());
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
this.settings = settings;
|
||||
|
|
@ -98,6 +118,7 @@ impl OpenAiCompatibleLanguageModelProvider {
|
|||
EnvVar::new(api_key_env_var_name),
|
||||
),
|
||||
settings,
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
|
||||
use http_client::HttpClient;
|
||||
|
|
@ -42,6 +43,7 @@ pub struct OpenRouterLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<open_router::Model>,
|
||||
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
|
||||
|
|
@ -53,16 +55,26 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
let task = self
|
||||
.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
|
||||
let task = self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
|
|
@ -114,7 +126,11 @@ impl State {
|
|||
}
|
||||
|
||||
impl OpenRouterLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>({
|
||||
let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
|
||||
|
|
@ -131,6 +147,7 @@ impl OpenRouterLanguageModelProvider {
|
|||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
http_client: http_client.clone(),
|
||||
available_models: Vec::new(),
|
||||
fetch_models_task: None,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
|
|
@ -43,6 +44,7 @@ pub struct OpenCodeLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -51,30 +53,51 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OpenCodeLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = OpenCodeLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenCodeLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
|
|
@ -38,6 +39,7 @@ pub struct VercelLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -46,30 +48,51 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = VercelLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = VercelLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl VercelLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
|
||||
|
|
@ -41,6 +42,7 @@ pub struct VercelAiGatewayLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<AvailableModel>,
|
||||
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
|
||||
|
|
@ -52,16 +54,26 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
|
||||
let task = self
|
||||
.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
|
||||
let task = self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
|
|
@ -100,7 +112,11 @@ impl State {
|
|||
}
|
||||
|
||||
impl VercelAiGatewayLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>({
|
||||
let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
|
||||
|
|
@ -116,6 +132,7 @@ impl VercelAiGatewayLanguageModelProvider {
|
|||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
http_client: http_client.clone(),
|
||||
available_models: Vec::new(),
|
||||
fetch_models_task: None,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
|
|
@ -39,6 +40,7 @@ pub struct XAiLanguageModelProvider {
|
|||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
|
@ -47,30 +49,51 @@ impl State {
|
|||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = XAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.store(
|
||||
api_url,
|
||||
api_key,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = self.credentials_provider.clone();
|
||||
let api_url = XAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl XAiLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let credentials_provider = this.credentials_provider.clone();
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state
|
||||
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
|this| &mut this.api_key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
|
||||
credentials_provider,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ watch.workspace = true
|
|||
wax.workspace = true
|
||||
which.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_credentials_provider.workspace = true
|
||||
zeroize.workspace = true
|
||||
zlog.workspace = true
|
||||
ztracing.workspace = true
|
||||
|
|
|
|||
|
|
@ -684,7 +684,7 @@ impl ContextServerStore {
|
|||
let server_url = url.clone();
|
||||
let id = id.clone();
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
|
||||
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
|
||||
if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
|
||||
{
|
||||
log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
|
||||
|
|
@ -797,8 +797,7 @@ impl ContextServerStore {
|
|||
if configuration.has_static_auth_header() {
|
||||
None
|
||||
} else {
|
||||
let credentials_provider =
|
||||
cx.update(|cx| <dyn CredentialsProvider>::global(cx));
|
||||
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
|
||||
let http_client = cx.update(|cx| cx.http_client());
|
||||
|
||||
match Self::load_session(&credentials_provider, url, &cx).await {
|
||||
|
|
@ -1070,7 +1069,7 @@ impl ContextServerStore {
|
|||
.context("Failed to start OAuth callback server")?;
|
||||
|
||||
let http_client = cx.update(|cx| cx.http_client());
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
|
||||
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
|
||||
let server_url = match configuration.as_ref() {
|
||||
ContextServerConfiguration::Http { url, .. } => url.clone(),
|
||||
_ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
|
||||
|
|
@ -1233,7 +1232,7 @@ impl ContextServerStore {
|
|||
self.stop_server(&id, cx)?;
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
|
||||
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
|
||||
if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
|
||||
log::error!("{} failed to clear OAuth session: {}", id, err);
|
||||
}
|
||||
|
|
@ -1451,7 +1450,7 @@ async fn resolve_start_failure(
|
|||
// (e.g. timeout because the server rejected the token silently). Clear it
|
||||
// so the next start attempt can get a clean 401 and trigger the auth flow.
|
||||
if www_authenticate.is_none() {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
|
||||
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
|
||||
match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
|
||||
Ok(Some(_)) => {
|
||||
log::info!("{id} start failed with a cached OAuth session present; clearing it");
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ ui.workspace = true
|
|||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zed_credentials_provider.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
fs = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
|||
|
|
@ -185,9 +185,15 @@ fn render_api_key_provider(
|
|||
cx: &mut Context<SettingsWindow>,
|
||||
) -> impl IntoElement {
|
||||
let weak_page = cx.weak_entity();
|
||||
let credentials_provider = zed_credentials_provider::global(cx);
|
||||
_ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
|
||||
let task = api_key_state.update(cx, |key_state, cx| {
|
||||
key_state.load_if_needed(current_url(cx), |state| state, cx)
|
||||
key_state.load_if_needed(
|
||||
current_url(cx),
|
||||
|state| state,
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
cx.spawn(async move |_, cx| {
|
||||
task.await.ok();
|
||||
|
|
@ -208,10 +214,17 @@ fn render_api_key_provider(
|
|||
});
|
||||
|
||||
let write_key = move |api_key: Option<String>, cx: &mut App| {
|
||||
let credentials_provider = zed_credentials_provider::global(cx);
|
||||
api_key_state
|
||||
.update(cx, |key_state, cx| {
|
||||
let url = current_url(cx);
|
||||
key_state.store(url, api_key, |key_state| key_state, cx)
|
||||
key_state.store(
|
||||
url,
|
||||
api_key,
|
||||
|key_state| key_state,
|
||||
credentials_provider,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::{Client, UserStore};
|
||||
use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
|
||||
use cloud_api_types::OrganizationId;
|
||||
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
|
||||
use language_model::LlmApiToken;
|
||||
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||
|
||||
pub struct CloudWebSearchProvider {
|
||||
|
|
@ -30,7 +30,7 @@ pub struct State {
|
|||
|
||||
impl State {
|
||||
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
|
||||
let llm_api_token = LlmApiToken::global(cx);
|
||||
let llm_api_token = global_llm_token(cx);
|
||||
|
||||
Self {
|
||||
client,
|
||||
|
|
@ -73,8 +73,8 @@ async fn perform_web_search(
|
|||
|
||||
let http_client = &client.http_client();
|
||||
let mut retries_remaining = MAX_RETRIES;
|
||||
let mut token = llm_api_token
|
||||
.acquire(&client, organization_id.clone())
|
||||
let mut token = client
|
||||
.acquire_llm_token(&llm_api_token, organization_id.clone())
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
|
|
@ -100,8 +100,8 @@ async fn perform_web_search(
|
|||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Ok(serde_json::from_str(&body)?);
|
||||
} else if response.needs_llm_token_refresh() {
|
||||
token = llm_api_token
|
||||
.refresh(&client, organization_id.clone())
|
||||
token = client
|
||||
.refresh_llm_token(&llm_api_token, organization_id.clone())
|
||||
.await?;
|
||||
retries_remaining -= 1;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use agent_ui::AgentPanel;
|
|||
use anyhow::{Context as _, Error, Result};
|
||||
use clap::Parser;
|
||||
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
|
||||
use client::{Client, ProxySettings, UserStore, parse_zed_link};
|
||||
use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link};
|
||||
use collab_ui::channel_view::ChannelView;
|
||||
use collections::HashMap;
|
||||
use crashes::InitCrashHandler;
|
||||
|
|
@ -664,7 +664,12 @@ fn main() {
|
|||
);
|
||||
|
||||
copilot_ui::init(&app_state, cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
RefreshLlmTokenListener::register(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
);
|
||||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
acp_tools::init(cx);
|
||||
zed::telemetry_log::init(cx);
|
||||
|
|
|
|||
|
|
@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
|
|||
});
|
||||
prompt_store::init(cx);
|
||||
let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
client::RefreshLlmTokenListener::register(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
);
|
||||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
git_ui::init(cx);
|
||||
project::AgentRegistryStore::init_global(
|
||||
|
|
|
|||
|
|
@ -5189,7 +5189,12 @@ mod tests {
|
|||
cx,
|
||||
);
|
||||
image_viewer::init(cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
client::RefreshLlmTokenListener::register(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
);
|
||||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
web_search::init(cx);
|
||||
git_graph::init(cx);
|
||||
|
|
|
|||
|
|
@ -313,7 +313,12 @@ mod tests {
|
|||
let app_state = cx.update(|cx| {
|
||||
let app_state = AppState::test(cx);
|
||||
client::init(&app_state.client, cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
language_model::init(cx);
|
||||
client::RefreshLlmTokenListener::register(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
);
|
||||
editor::init(cx);
|
||||
app_state
|
||||
});
|
||||
|
|
|
|||
22
crates/zed_credentials_provider/Cargo.toml
Normal file
22
crates/zed_credentials_provider/Cargo.toml
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "zed_credentials_provider"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/zed_credentials_provider.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
paths.workspace = true
|
||||
release_channel.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
1
crates/zed_credentials_provider/LICENSE-GPL
Symbolic link
1
crates/zed_credentials_provider/LICENSE-GPL
Symbolic link
|
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
||||
181
crates/zed_credentials_provider/src/zed_credentials_provider.rs
Normal file
181
crates/zed_credentials_provider/src/zed_credentials_provider.rs
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
use anyhow::Result;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::FutureExt as _;
|
||||
use gpui::{App, AsyncApp, Global};
|
||||
use release_channel::ReleaseChannel;
|
||||
|
||||
/// An environment variable whose presence indicates that the system keychain
|
||||
/// should be used in development.
|
||||
///
|
||||
/// By default, running Zed in development uses the development credentials
|
||||
/// provider. Setting this environment variable allows you to interact with the
|
||||
/// system keychain (for instance, if you need to test something).
|
||||
///
|
||||
/// Only works in development. Setting this environment variable in other
|
||||
/// release channels is a no-op.
|
||||
static ZED_DEVELOPMENT_USE_KEYCHAIN: LazyLock<bool> = LazyLock::new(|| {
|
||||
std::env::var("ZED_DEVELOPMENT_USE_KEYCHAIN").is_ok_and(|value| !value.is_empty())
|
||||
});
|
||||
|
||||
pub struct ZedCredentialsProvider(pub Arc<dyn CredentialsProvider>);
|
||||
|
||||
impl Global for ZedCredentialsProvider {}
|
||||
|
||||
/// Returns the global [`CredentialsProvider`].
|
||||
pub fn init_global(cx: &mut App) {
|
||||
// The `CredentialsProvider` trait has `Send + Sync` bounds on it, so it
|
||||
// seems like this is a false positive from Clippy.
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
let provider = new(cx);
|
||||
cx.set_global(ZedCredentialsProvider(provider));
|
||||
}
|
||||
|
||||
pub fn global(cx: &App) -> Arc<dyn CredentialsProvider> {
|
||||
cx.try_global::<ZedCredentialsProvider>()
|
||||
.map(|provider| provider.0.clone())
|
||||
.unwrap_or_else(|| new(cx))
|
||||
}
|
||||
|
||||
fn new(cx: &App) -> Arc<dyn CredentialsProvider> {
|
||||
let use_development_provider = match ReleaseChannel::try_global(cx) {
|
||||
Some(ReleaseChannel::Dev) => {
|
||||
// In development we default to using the development
|
||||
// credentials provider to avoid getting spammed by relentless
|
||||
// keychain access prompts.
|
||||
//
|
||||
// However, if the `ZED_DEVELOPMENT_USE_KEYCHAIN` environment
|
||||
// variable is set, we will use the actual keychain.
|
||||
!*ZED_DEVELOPMENT_USE_KEYCHAIN
|
||||
}
|
||||
Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) | None => {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if use_development_provider {
|
||||
Arc::new(DevelopmentCredentialsProvider::new())
|
||||
} else {
|
||||
Arc::new(KeychainCredentialsProvider)
|
||||
}
|
||||
}
|
||||
|
||||
/// A credentials provider that stores credentials in the system keychain.
|
||||
struct KeychainCredentialsProvider;
|
||||
|
||||
impl CredentialsProvider for KeychainCredentialsProvider {
|
||||
fn read_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
|
||||
async move { cx.update(|cx| cx.read_credentials(url)).await }.boxed_local()
|
||||
}
|
||||
|
||||
fn write_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
username: &'a str,
|
||||
password: &'a [u8],
|
||||
cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move {
|
||||
cx.update(move |cx| cx.write_credentials(url, username, password))
|
||||
.await
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
||||
fn delete_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move { cx.update(move |cx| cx.delete_credentials(url)).await }.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
/// A credentials provider that stores credentials in a local file.
|
||||
///
|
||||
/// This MUST only be used in development, as this is not a secure way of storing
|
||||
/// credentials on user machines.
|
||||
///
|
||||
/// Its existence is purely to work around the annoyance of having to constantly
|
||||
/// re-allow access to the system keychain when developing Zed.
|
||||
struct DevelopmentCredentialsProvider {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl DevelopmentCredentialsProvider {
|
||||
fn new() -> Self {
|
||||
let path = paths::config_dir().join("development_credentials");
|
||||
|
||||
Self { path }
|
||||
}
|
||||
|
||||
fn load_credentials(&self) -> Result<HashMap<String, (String, Vec<u8>)>> {
|
||||
let json = std::fs::read(&self.path)?;
|
||||
let credentials: HashMap<String, (String, Vec<u8>)> = serde_json::from_slice(&json)?;
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
fn save_credentials(&self, credentials: &HashMap<String, (String, Vec<u8>)>) -> Result<()> {
|
||||
let json = serde_json::to_string(credentials)?;
|
||||
std::fs::write(&self.path, json)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialsProvider for DevelopmentCredentialsProvider {
|
||||
fn read_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
_cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
|
||||
async move {
|
||||
Ok(self
|
||||
.load_credentials()
|
||||
.unwrap_or_default()
|
||||
.get(url)
|
||||
.cloned())
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
||||
fn write_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
username: &'a str,
|
||||
password: &'a [u8],
|
||||
_cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move {
|
||||
let mut credentials = self.load_credentials().unwrap_or_default();
|
||||
credentials.insert(url.to_string(), (username.to_string(), password.to_vec()));
|
||||
|
||||
self.save_credentials(&credentials)
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
|
||||
fn delete_credentials<'a>(
|
||||
&'a self,
|
||||
url: &'a str,
|
||||
_cx: &'a AsyncApp,
|
||||
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
||||
async move {
|
||||
let mut credentials = self.load_credentials()?;
|
||||
credentials.remove(url);
|
||||
|
||||
self.save_credentials(&credentials)
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
}
|
||||
|
|
@ -15,4 +15,4 @@ path = "src/zed_env_vars.rs"
|
|||
default = []
|
||||
|
||||
[dependencies]
|
||||
gpui.workspace = true
|
||||
env_var.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,45 +1,6 @@
|
|||
use gpui::SharedString;
|
||||
pub use env_var::{EnvVar, bool_env_var, env_var};
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Whether Zed is running in stateless mode.
|
||||
/// When true, Zed will use in-memory databases instead of persistent storage.
|
||||
pub static ZED_STATELESS: LazyLock<bool> = bool_env_var!("ZED_STATELESS");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EnvVar {
|
||||
pub name: SharedString,
|
||||
/// Value of the environment variable. Also `None` when set to an empty string.
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
impl EnvVar {
|
||||
pub fn new(name: SharedString) -> Self {
|
||||
let value = std::env::var(name.as_str()).ok();
|
||||
if value.as_ref().is_some_and(|v| v.is_empty()) {
|
||||
Self { name, value: None }
|
||||
} else {
|
||||
Self { name, value }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn or(self, other: EnvVar) -> EnvVar {
|
||||
if self.value.is_some() { self } else { other }
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
|
||||
#[macro_export]
|
||||
macro_rules! env_var {
|
||||
($name:expr) => {
|
||||
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()))
|
||||
};
|
||||
}
|
||||
|
||||
/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
|
||||
/// environment variable exists and is non-empty.
|
||||
#[macro_export]
|
||||
macro_rules! bool_env_var {
|
||||
($name:expr) => {
|
||||
::std::sync::LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
|
||||
};
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue