Refresh LLM API token on organization change (#50931)

Emit client-side organization changed events through
`RefreshLlmTokenListener` so it produces the same `RefreshLlmTokenEvent`
used for server-pushed `UserUpdated` messages.

This keeps token refresh fan-out in one place.

Closes CLO-383.

Release Notes:

- N/A

---------

Co-authored-by: Tom Houlé <tom@tomhoule.com>
This commit is contained in:
Neel 2026-03-06 19:15:21 +00:00 committed by GitHub
parent 21e202ef0c
commit dc0e41f834
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 61 additions and 30 deletions

View file

@ -1423,7 +1423,7 @@ impl EditAgentTest {
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
settings::init(cx);
language_model::init(client.clone(), cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
});

View file

@ -3167,7 +3167,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
LanguageModelRegistry::test(cx);
});
@ -3791,7 +3791,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
}
};

View file

@ -2,6 +2,7 @@ use crate::{AgentServer, AgentServerDelegate};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::AppContext;
use gpui::{Entity, TestAppContext};
use indoc::indoc;
use project::{FakeFs, Project};
@ -408,7 +409,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
let client = client::Client::production(cx);
language_model::init(client, cx);
let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
language_model::init(user_store, client, cx);
#[cfg(test)]
project::agent_server_store::AllAgentServersSettings::override_global(

View file

@ -2120,7 +2120,7 @@ pub mod test {
client::init(&client, cx);
workspace::init(app_state.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
cx.set_global(inline_assistant);

View file

@ -140,6 +140,7 @@ pub enum Event {
ParticipantIndicesChanged,
PrivateUserInfoUpdated,
PlanUpdated,
OrganizationChanged,
}
#[derive(Clone, Copy)]
@ -694,8 +695,21 @@ impl UserStore {
self.current_organization.clone()
}
pub fn set_current_organization(&mut self, organization: Arc<Organization>) {
self.current_organization.replace(organization);
pub fn set_current_organization(
&mut self,
organization: Arc<Organization>,
cx: &mut Context<Self>,
) {
let is_same_organization = self
.current_organization
.as_ref()
.is_some_and(|current| current.id == organization.id);
if !is_same_organization {
self.current_organization.replace(organization);
cx.emit(Event::OrganizationChanged);
cx.notify();
}
}
pub fn organizations(&self) -> &Vec<Arc<Organization>> {

View file

@ -533,8 +533,8 @@ mod tests {
zlog::init_test();
let http_client = FakeHttpClient::with_404_response();
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
language_model::init(client.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
EditPredictionStore::global(&client, &user_store, cx);
})
}

View file

@ -1850,9 +1850,8 @@ fn init_test_with_fake_client(
let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
client.cloud_client().set_credentials(1, "test".into());
language_model::init(client.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(user_store.clone(), client.clone(), cx);
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
(
@ -2218,8 +2217,9 @@ async fn make_test_ep_store(
});
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
RefreshLlmTokenListener::register(client.clone(), cx);
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let _server = FakeServer::for_client(42, &client, cx).await;
@ -2301,8 +2301,9 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
let client =
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
language_model::RefreshLlmTokenListener::register(client.clone(), cx);
language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

View file

@ -105,7 +105,7 @@ pub fn init(cx: &mut App) -> EpAppState {
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
language_model::init(client.clone(), cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);

View file

@ -429,7 +429,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
language_model::init(client.clone(), cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);

View file

@ -104,7 +104,7 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
language_model::init(client.clone(), cx);
language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);

View file

@ -13,10 +13,11 @@ pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use client::UserStore;
use cloud_llm_client::CompletionRequestStatus;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
@ -61,9 +62,9 @@ pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProvider
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Zed");
pub fn init(client: Arc<Client>, cx: &mut App) {
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
init_settings(cx);
RefreshLlmTokenListener::register(client, cx);
RefreshLlmTokenListener::register(client, user_store, cx);
}
pub fn init_settings(cx: &mut App) {

View file

@ -3,11 +3,14 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::Client;
use client::UserStore;
use cloud_api_client::ClientApiError;
use cloud_api_types::OrganizationId;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
@ -101,13 +104,15 @@ impl Global for GlobalRefreshLlmTokenListener {}
pub struct RefreshLlmTokenEvent;
pub struct RefreshLlmTokenListener;
pub struct RefreshLlmTokenListener {
_subscription: Subscription,
}
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, cx: &mut App) {
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
@ -115,7 +120,7 @@ impl RefreshLlmTokenListener {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
client.add_message_to_client_handler({
let this = cx.entity();
move |message, cx| {
@ -123,7 +128,15 @@ impl RefreshLlmTokenListener {
}
});
Self
let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
if matches!(event, client::user::Event::OrganizationChanged) {
cx.emit(RefreshLlmTokenEvent);
}
});
Self {
_subscription: subscription,
}
}
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {

View file

@ -1014,9 +1014,9 @@ impl TitleBar {
let user_store = user_store.clone();
let organization = organization.clone();
move |_window, cx| {
user_store.update(cx, |user_store, _cx| {
user_store.update(cx, |user_store, cx| {
user_store
.set_current_organization(organization.clone());
.set_current_organization(organization.clone(), cx);
});
}
},

View file

@ -657,7 +657,7 @@ fn main() {
);
copilot_ui::init(&app_state, cx);
language_model::init(app_state.client.clone(), cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
zed::telemetry_log::init(cx);

View file

@ -200,7 +200,7 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
});
prompt_store::init(cx);
let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
language_model::init(app_state.client.clone(), cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
git_ui::init(cx);
project::AgentRegistryStore::init_global(

View file

@ -5024,7 +5024,7 @@ mod tests {
cx,
);
image_viewer::init(cx);
language_model::init(app_state.client.clone(), cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
web_search::init(cx);
git_graph::init(cx);

View file

@ -316,7 +316,7 @@ mod tests {
let app_state = cx.update(|cx| {
let app_state = AppState::test(cx);
client::init(&app_state.client, cx);
language_model::init(app_state.client.clone(), cx);
language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
editor::init(cx);
app_state
});