mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Refresh LLM API token on organization change (#50931)
Emit client-side organization changed events through `RefreshLlmTokenListener` so it produces the same `RefreshLlmTokenEvent` used for server-pushed `UserUpdated` messages. This keeps token refresh fan-out in one place. Closes CLO-383. Release Notes: - N/A --------- Co-authored-by: Tom Houlé <tom@tomhoule.com>
This commit is contained in:
parent
21e202ef0c
commit
dc0e41f834
17 changed files with 61 additions and 30 deletions
|
|
@ -1423,7 +1423,7 @@ impl EditAgentTest {
|
|||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
settings::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -3167,7 +3167,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
let clock = Arc::new(clock::FakeSystemClock::new());
|
||||
let client = Client::new(clock, http_client, cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
|
|
@ -3791,7 +3791,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
cx.set_http_client(Arc::new(http_client));
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use crate::{AgentServer, AgentServerDelegate};
|
|||
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
|
||||
use agent_client_protocol as acp;
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::AppContext;
|
||||
use gpui::{Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use project::{FakeFs, Project};
|
||||
|
|
@ -408,7 +409,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
|||
let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
let client = client::Client::production(cx);
|
||||
language_model::init(client, cx);
|
||||
let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store, client, cx);
|
||||
|
||||
#[cfg(test)]
|
||||
project::agent_server_store::AllAgentServersSettings::override_global(
|
||||
|
|
|
|||
|
|
@ -2120,7 +2120,7 @@ pub mod test {
|
|||
client::init(&client, cx);
|
||||
workspace::init(app_state.clone(), cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_models::init(user_store, client.clone(), cx);
|
||||
|
||||
cx.set_global(inline_assistant);
|
||||
|
|
|
|||
|
|
@ -140,6 +140,7 @@ pub enum Event {
|
|||
ParticipantIndicesChanged,
|
||||
PrivateUserInfoUpdated,
|
||||
PlanUpdated,
|
||||
OrganizationChanged,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
|
|
@ -694,8 +695,21 @@ impl UserStore {
|
|||
self.current_organization.clone()
|
||||
}
|
||||
|
||||
pub fn set_current_organization(&mut self, organization: Arc<Organization>) {
|
||||
self.current_organization.replace(organization);
|
||||
pub fn set_current_organization(
|
||||
&mut self,
|
||||
organization: Arc<Organization>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let is_same_organization = self
|
||||
.current_organization
|
||||
.as_ref()
|
||||
.is_some_and(|current| current.id == organization.id);
|
||||
|
||||
if !is_same_organization {
|
||||
self.current_organization.replace(organization);
|
||||
cx.emit(Event::OrganizationChanged);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn organizations(&self) -> &Vec<Arc<Organization>> {
|
||||
|
|
|
|||
|
|
@ -533,8 +533,8 @@ mod tests {
|
|||
zlog::init_test();
|
||||
let http_client = FakeHttpClient::with_404_response();
|
||||
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
EditPredictionStore::global(&client, &user_store, cx);
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1850,9 +1850,8 @@ fn init_test_with_fake_client(
|
|||
let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
|
||||
client.cloud_client().set_credentials(1, "test".into());
|
||||
|
||||
language_model::init(client.clone(), cx);
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
|
||||
|
||||
(
|
||||
|
|
@ -2218,8 +2217,9 @@ async fn make_test_ep_store(
|
|||
});
|
||||
|
||||
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
|
||||
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
|
||||
cx.update(|cx| {
|
||||
RefreshLlmTokenListener::register(client.clone(), cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
});
|
||||
let _server = FakeServer::for_client(42, &client, cx).await;
|
||||
|
||||
|
|
@ -2301,8 +2301,9 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
|
|||
|
||||
let client =
|
||||
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
|
||||
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
|
||||
cx.update(|cx| {
|
||||
language_model::RefreshLlmTokenListener::register(client.clone(), cx);
|
||||
language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
||||
});
|
||||
|
||||
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ pub fn init(cx: &mut App) -> EpAppState {
|
|||
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
|
||||
language_model::init(client.clone(), cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
|
||||
prompt_store::init(cx);
|
||||
|
|
|
|||
|
|
@ -429,7 +429,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|||
let extension_host_proxy = ExtensionHostProxy::global(cx);
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
|
||||
language_model::init(client.clone(), cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
|
||||
prompt_store::init(cx);
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
|
|||
let extension_host_proxy = ExtensionHostProxy::global(cx);
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
|
||||
language_model::init(client.clone(), cx);
|
||||
language_model::init(user_store.clone(), client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
|
||||
prompt_store::init(cx);
|
||||
|
|
|
|||
|
|
@ -13,10 +13,11 @@ pub mod fake_provider;
|
|||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::Client;
|
||||
use client::UserStore;
|
||||
use cloud_llm_client::CompletionRequestStatus;
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
|
||||
use http_client::{StatusCode, http};
|
||||
use icons::IconName;
|
||||
use open_router::OpenRouterError;
|
||||
|
|
@ -61,9 +62,9 @@ pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProvider
|
|||
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Zed");
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
||||
init_settings(cx);
|
||||
RefreshLlmTokenListener::register(client, cx);
|
||||
RefreshLlmTokenListener::register(client, user_store, cx);
|
||||
}
|
||||
|
||||
pub fn init_settings(cx: &mut App) {
|
||||
|
|
|
|||
|
|
@ -3,11 +3,14 @@ use std::sync::Arc;
|
|||
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::Client;
|
||||
use client::UserStore;
|
||||
use cloud_api_client::ClientApiError;
|
||||
use cloud_api_types::OrganizationId;
|
||||
use cloud_api_types::websocket_protocol::MessageToClient;
|
||||
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
|
||||
use gpui::{
|
||||
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
|
||||
};
|
||||
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
||||
use thiserror::Error;
|
||||
|
||||
|
|
@ -101,13 +104,15 @@ impl Global for GlobalRefreshLlmTokenListener {}
|
|||
|
||||
pub struct RefreshLlmTokenEvent;
|
||||
|
||||
pub struct RefreshLlmTokenListener;
|
||||
pub struct RefreshLlmTokenListener {
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
|
||||
|
||||
impl RefreshLlmTokenListener {
|
||||
pub fn register(client: Arc<Client>, cx: &mut App) {
|
||||
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
|
||||
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
|
||||
cx.set_global(GlobalRefreshLlmTokenListener(listener));
|
||||
}
|
||||
|
||||
|
|
@ -115,7 +120,7 @@ impl RefreshLlmTokenListener {
|
|||
GlobalRefreshLlmTokenListener::global(cx).0.clone()
|
||||
}
|
||||
|
||||
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
||||
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
|
||||
client.add_message_to_client_handler({
|
||||
let this = cx.entity();
|
||||
move |message, cx| {
|
||||
|
|
@ -123,7 +128,15 @@ impl RefreshLlmTokenListener {
|
|||
}
|
||||
});
|
||||
|
||||
Self
|
||||
let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
|
||||
if matches!(event, client::user::Event::OrganizationChanged) {
|
||||
cx.emit(RefreshLlmTokenEvent);
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
_subscription: subscription,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
|
||||
|
|
|
|||
|
|
@ -1014,9 +1014,9 @@ impl TitleBar {
|
|||
let user_store = user_store.clone();
|
||||
let organization = organization.clone();
|
||||
move |_window, cx| {
|
||||
user_store.update(cx, |user_store, _cx| {
|
||||
user_store.update(cx, |user_store, cx| {
|
||||
user_store
|
||||
.set_current_organization(organization.clone());
|
||||
.set_current_organization(organization.clone(), cx);
|
||||
});
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -657,7 +657,7 @@ fn main() {
|
|||
);
|
||||
|
||||
copilot_ui::init(&app_state, cx);
|
||||
language_model::init(app_state.client.clone(), cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
acp_tools::init(cx);
|
||||
zed::telemetry_log::init(cx);
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
|
|||
});
|
||||
prompt_store::init(cx);
|
||||
let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
|
||||
language_model::init(app_state.client.clone(), cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
git_ui::init(cx);
|
||||
project::AgentRegistryStore::init_global(
|
||||
|
|
|
|||
|
|
@ -5024,7 +5024,7 @@ mod tests {
|
|||
cx,
|
||||
);
|
||||
image_viewer::init(cx);
|
||||
language_model::init(app_state.client.clone(), cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
web_search::init(cx);
|
||||
git_graph::init(cx);
|
||||
|
|
|
|||
|
|
@ -316,7 +316,7 @@ mod tests {
|
|||
let app_state = cx.update(|cx| {
|
||||
let app_state = AppState::test(cx);
|
||||
client::init(&app_state.client, cx);
|
||||
language_model::init(app_state.client.clone(), cx);
|
||||
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
editor::init(cx);
|
||||
app_state
|
||||
});
|
||||
|
|
|
|||
Loading…
Reference in a new issue