From dc0e41f8342537732887d2e10b8dedad1e9d59bd Mon Sep 17 00:00:00 2001 From: Neel Date: Fri, 6 Mar 2026 19:15:21 +0000 Subject: [PATCH] Refresh LLM API token on organization change (#50931) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Emit client-side organization changed events through `RefreshLlmTokenListener` so it produces the same `RefreshLlmTokenEvent` used for server-pushed `UserUpdated` messages. This keeps token refresh fan-out in one place. Closes CLO-383. Release Notes: - N/A --------- Co-authored-by: Tom Houlé --- crates/agent/src/edit_agent/evals.rs | 2 +- crates/agent/src/tests/mod.rs | 4 +-- crates/agent_servers/src/e2e_tests.rs | 4 ++- crates/agent_ui/src/inline_assistant.rs | 2 +- crates/client/src/user.rs | 18 +++++++++++-- crates/edit_prediction/src/capture_example.rs | 2 +- .../src/edit_prediction_tests.rs | 9 ++++--- crates/edit_prediction_cli/src/headless.rs | 2 +- crates/eval/src/eval.rs | 2 +- crates/eval_cli/src/headless.rs | 2 +- crates/language_model/src/language_model.rs | 7 +++--- .../language_model/src/model/cloud_model.rs | 25 ++++++++++++++----- crates/title_bar/src/title_bar.rs | 4 +-- crates/zed/src/main.rs | 2 +- crates/zed/src/visual_test_runner.rs | 2 +- crates/zed/src/zed.rs | 2 +- .../zed/src/zed/edit_prediction_registry.rs | 2 +- 17 files changed, 61 insertions(+), 30 deletions(-) diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs index 2e8818b1019..e7b67e37bf4 100644 --- a/crates/agent/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -1423,7 +1423,7 @@ impl EditAgentTest { let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); settings::init(cx); - language_model::init(client.clone(), cx); + language_model::init(user_store.clone(), client.clone(), cx); language_models::init(user_store, client.clone(), cx); }); diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 79e8a5e2459..d33c80a435e 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -3167,7 +3167,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let clock = Arc::new(clock::FakeSystemClock::new()); let client = Client::new(clock, http_client, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(client.clone(), cx); + language_model::init(user_store.clone(), client.clone(), cx); language_models::init(user_store, client.clone(), cx); LanguageModelRegistry::test(cx); }); @@ -3791,7 +3791,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { cx.set_http_client(Arc::new(http_client)); let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(client.clone(), cx); + language_model::init(user_store.clone(), client.clone(), cx); language_models::init(user_store, client.clone(), cx); } }; diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index c5754bcd761..a0150d41726 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -2,6 +2,7 @@ use crate::{AgentServer, AgentServerDelegate}; use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; use agent_client_protocol as acp; use futures::{FutureExt, StreamExt, channel::mpsc, select}; +use gpui::AppContext; use gpui::{Entity, TestAppContext}; use indoc::indoc; use project::{FakeFs, Project}; @@ -408,7 +409,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap(); cx.set_http_client(Arc::new(http_client)); let client = client::Client::production(cx); - language_model::init(client, cx); + let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx)); + language_model::init(user_store, client, cx); #[cfg(test)] project::agent_server_store::AllAgentServersSettings::override_global( diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 9ac84addcc8..4e7eecfe07a 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -2120,7 +2120,7 @@ pub mod test { client::init(&client, cx); workspace::init(app_state.clone(), cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(client.clone(), cx); + language_model::init(user_store.clone(), client.clone(), cx); language_models::init(user_store, client.clone(), cx); cx.set_global(inline_assistant); diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index d27bf3387a7..5d38569cfd8 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -140,6 +140,7 @@ pub enum Event { ParticipantIndicesChanged, PrivateUserInfoUpdated, PlanUpdated, + OrganizationChanged, } #[derive(Clone, Copy)] @@ -694,8 +695,21 @@ impl UserStore { self.current_organization.clone() } - pub fn set_current_organization(&mut self, organization: Arc) { - self.current_organization.replace(organization); + pub fn set_current_organization( + &mut self, + organization: Arc, + cx: &mut Context, + ) { + let is_same_organization = self + .current_organization + .as_ref() + .is_some_and(|current| current.id == organization.id); + + if !is_same_organization { + self.current_organization.replace(organization); + cx.emit(Event::OrganizationChanged); + cx.notify(); + } } pub fn organizations(&self) -> &Vec> { diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 0fbece74780..e0df8cf9577 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -533,8 +533,8 @@ mod tests { zlog::init_test(); let http_client = FakeHttpClient::with_404_response(); let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); - language_model::init(client.clone(), cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(user_store.clone(), client.clone(), cx); EditPredictionStore::global(&client, &user_store, cx); }) } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index bbad3c104e6..1ff77fd900d 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1850,9 +1850,8 @@ fn init_test_with_fake_client( let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); client.cloud_client().set_credentials(1, "test".into()); - language_model::init(client.clone(), cx); - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(user_store.clone(), client.clone(), cx); let ep_store = EditPredictionStore::global(&client, &user_store, cx); ( @@ -2218,8 +2217,9 @@ async fn make_test_ep_store( }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); + let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx))); cx.update(|cx| { - RefreshLlmTokenListener::register(client.clone(), cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); }); let _server = FakeServer::for_client(42, &client, cx).await; @@ -2301,8 +2301,9 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut let client = cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); + let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx))); cx.update(|cx| { - language_model::RefreshLlmTokenListener::register(client.clone(), cx); + language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); }); let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx)); diff --git a/crates/edit_prediction_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs index f78903b705a..eb2895b06f2 100644 --- a/crates/edit_prediction_cli/src/headless.rs +++ b/crates/edit_prediction_cli/src/headless.rs @@ -105,7 +105,7 @@ pub fn init(cx: &mut App) -> EpAppState { debug_adapter_extension::init(extension_host_proxy.clone(), cx); language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); - language_model::init(client.clone(), cx); + language_model::init(user_store.clone(), client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx); prompt_store::init(cx); diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 4e9a0cb7915..a621cb0dedb 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -429,7 +429,7 @@ pub fn init(cx: &mut App) -> Arc { let extension_host_proxy = ExtensionHostProxy::global(cx); debug_adapter_extension::init(extension_host_proxy.clone(), cx); language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); - language_model::init(client.clone(), cx); + language_model::init(user_store.clone(), client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx); prompt_store::init(cx); diff --git a/crates/eval_cli/src/headless.rs b/crates/eval_cli/src/headless.rs index 1448cbeb7a7..54f14ee1938 100644 --- a/crates/eval_cli/src/headless.rs +++ b/crates/eval_cli/src/headless.rs @@ -104,7 +104,7 @@ pub fn init(cx: &mut App) -> Arc { let extension_host_proxy = ExtensionHostProxy::global(cx); debug_adapter_extension::init(extension_host_proxy.clone(), cx); language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); - language_model::init(client.clone(), cx); + language_model::init(user_store.clone(), client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx); prompt_store::init(cx); diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index c403774499c..0452c494a2a 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -13,10 +13,11 @@ pub mod fake_provider; use anthropic::{AnthropicError, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::Client; +use client::UserStore; use cloud_llm_client::CompletionRequestStatus; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window}; use http_client::{StatusCode, http}; use icons::IconName; use open_router::OpenRouterError; @@ -61,9 +62,9 @@ pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProvider pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Zed"); -pub fn init(client: Arc, cx: &mut App) { +pub fn init(user_store: Entity, client: Arc, cx: &mut App) { init_settings(cx); - RefreshLlmTokenListener::register(client, cx); + RefreshLlmTokenListener::register(client, user_store, cx); } pub fn init_settings(cx: &mut App) { diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index b2af80a3c29..e64cc43edd8 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,11 +3,14 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; +use client::UserStore; use cloud_api_client::ClientApiError; use cloud_api_types::OrganizationId; use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; -use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; +use gpui::{ + App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, +}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -101,13 +104,15 @@ impl Global for GlobalRefreshLlmTokenListener {} pub struct RefreshLlmTokenEvent; -pub struct RefreshLlmTokenListener; +pub struct RefreshLlmTokenListener { + _subscription: Subscription, +} impl EventEmitter for RefreshLlmTokenListener {} impl RefreshLlmTokenListener { - pub fn register(client: Arc, cx: &mut App) { - let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx)); + pub fn register(client: Arc, user_store: Entity, cx: &mut App) { + let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx)); cx.set_global(GlobalRefreshLlmTokenListener(listener)); } @@ -115,7 +120,7 @@ impl RefreshLlmTokenListener { GlobalRefreshLlmTokenListener::global(cx).0.clone() } - fn new(client: Arc, cx: &mut Context) -> Self { + fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { client.add_message_to_client_handler({ let this = cx.entity(); move |message, cx| { @@ -123,7 +128,15 @@ impl RefreshLlmTokenListener { } }); - Self + let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| { + if matches!(event, client::user::Event::OrganizationChanged) { + cx.emit(RefreshLlmTokenEvent); + } + }); + + Self { + _subscription: subscription, + } } fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 05ede406b91..3566d621076 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -1014,9 +1014,9 @@ impl TitleBar { let user_store = user_store.clone(); let organization = organization.clone(); move |_window, cx| { - user_store.update(cx, |user_store, _cx| { + user_store.update(cx, |user_store, cx| { user_store - .set_current_organization(organization.clone()); + .set_current_organization(organization.clone(), cx); }); } }, diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 109b79ff06b..eccf6b51e01 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -657,7 +657,7 @@ fn main() { ); copilot_ui::init(&app_state, cx); - language_model::init(app_state.client.clone(), cx); + language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); acp_tools::init(cx); zed::telemetry_log::init(cx); diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index 57d2f4462b9..ead16b911e3 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -200,7 +200,7 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()> }); prompt_store::init(cx); let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx); - language_model::init(app_state.client.clone(), cx); + language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); git_ui::init(cx); project::AgentRegistryStore::init_global( diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 0cb93bbc4c9..562786fb3f0 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -5024,7 +5024,7 @@ mod tests { cx, ); image_viewer::init(cx); - language_model::init(app_state.client.clone(), cx); + language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); web_search::init(cx); git_graph::init(cx); diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 9f05c5795e6..952c840d4ab 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -316,7 +316,7 @@ mod tests { let app_state = cx.update(|cx| { let app_state = AppState::test(cx); client::init(&app_state.client, cx); - language_model::init(app_state.client.clone(), cx); + language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); editor::init(cx); app_state });