Fix token refresh for HTTP requests (#56559)

Code had been assuming (erroneously, but understandably) that
LlmApiToken::acquire would give them a valid token.

This is not true, as those tokens expire and you must call refresh
explicitly.

Add some helpers to do the retry for you, and rename acquire to cached
to be
clearer about the intent.

Closes #ISSUE

Release Notes:

- Fixed some rare cases where API requests would fail with Unauthorized
This commit is contained in:
Conrad Irwin 2026-05-12 13:40:00 -06:00 committed by GitHub
parent 64f624773f
commit 54188321be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 200 additions and 217 deletions

1
Cargo.lock generated
View file

@ -9696,6 +9696,7 @@ dependencies = [
"gpui",
"http_client",
"language_model",
"log",
"open_ai",
"schemars 1.0.4",
"semver",

View file

@ -1539,7 +1539,7 @@ impl Client {
})
}
pub async fn acquire_llm_token(
pub async fn cached_llm_token(
&self,
llm_token: &LlmApiToken,
organization_id: Option<OrganizationId>,
@ -1547,7 +1547,7 @@ impl Client {
let system_id = self.telemetry().system_id().map(|x| x.to_string());
let cloud_client = self.cloud_client();
match llm_token
.acquire(&cloud_client, system_id, organization_id)
.cached(&cloud_client, system_id, organization_id)
.await
{
Ok(token) => Ok(token),
@ -1559,6 +1559,31 @@ impl Client {
}
}
/// Sends an authenticated request to the Zed LLM service, retrying once
/// with a refreshed token if the server signals that the cached LLM
/// token is expired or otherwise rejected. Returns the raw response so
/// callers can inspect headers and stream the body.
pub async fn authenticated_llm_request(
&self,
llm_token: &LlmApiToken,
organization_id: Option<OrganizationId>,
build_request: impl Fn(&str) -> Result<http_client::Request<http_client::AsyncBody>>,
) -> Result<http_client::Response<http_client::AsyncBody>> {
let http_client = self.http_client();
let token = self
.cached_llm_token(llm_token, organization_id.clone())
.await?;
let response = http_client.send(build_request(&token)?).await?;
if !response.needs_llm_token_refresh()
&& response.status() != http_client::http::StatusCode::UNAUTHORIZED
{
return Ok(response);
}
log::info!("LLM token rejected; refreshing and retrying request");
let token = self.refresh_llm_token(llm_token, organization_id).await?;
http_client.send(build_request(&token)?).await
}
pub async fn refresh_llm_token(
&self,
llm_token: &LlmApiToken,

View file

@ -9,7 +9,11 @@ use crate::{ClientApiError, CloudApiClient};
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
pub async fn acquire(
/// Returns the cached LLM token, fetching a fresh one only if none has
/// been cached yet. The returned token is not validated; callers must
/// be prepared to refresh it (via [`LlmApiToken::refresh`]) if the
/// server rejects it.
pub async fn cached(
&self,
client: &CloudApiClient,
system_id: Option<String>,

View file

@ -1,5 +1,5 @@
use anyhow::Result;
use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use client::{Client, EditPredictionUsage, UserStore, global_llm_token};
use cloud_api_client::LlmApiToken;
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
@ -889,18 +889,19 @@ impl EditPredictionStore {
cx.spawn(async move |this, cx| {
let experiments = cx
.background_spawn(async move {
let http_client = client.http_client();
let token = client
.acquire_llm_token(&llm_token, organization_id.clone())
let url = client
.http_client()
.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let mut response = client
.authenticated_llm_request(&llm_token, organization_id, |token| {
Ok(http_client::Request::builder()
.method(Method::GET)
.uri(url.as_ref())
.header("Authorization", format!("Bearer {token}"))
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
.body(Default::default())?)
})
.await?;
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let request = http_client::Request::builder()
.method(Method::GET)
.uri(url.as_ref())
.header("Authorization", format!("Bearer {}", token))
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
.body(Default::default())?;
let mut response = http_client.send(request).await?;
if response.status().is_success() {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
@ -1580,7 +1581,6 @@ impl EditPredictionStore {
llm_token.clone(),
organization_id,
app_version.clone(),
true,
)
.await;
@ -2582,7 +2582,6 @@ impl EditPredictionStore {
llm_token,
organization_id,
app_version,
true,
)
.await
}
@ -2627,7 +2626,6 @@ impl EditPredictionStore {
llm_token,
organization_id,
app_version,
true,
)
.await
}
@ -2638,78 +2636,55 @@ impl EditPredictionStore {
llm_token: LlmApiToken,
organization_id: Option<OrganizationId>,
app_version: Version,
require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
Res: DeserializeOwned,
{
let http_client = client.http_client();
let mut token = if require_auth {
Some(
client
.acquire_llm_token(&llm_token, organization_id.clone())
.await?,
)
let response = client
.authenticated_llm_request(&llm_token, organization_id, |token| {
build(
http_client::Request::builder()
.method(Method::POST)
.header("Content-Type", "application/json")
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
.header("Authorization", format!("Bearer {token}")),
)
})
.await?;
Self::process_api_response(response, &app_version).await
}
async fn process_api_response<Res>(
mut response: http_client::Response<AsyncBody>,
app_version: &Version,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
Res: DeserializeOwned,
{
if let Some(minimum_required_version) = response
.headers()
.get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
.and_then(|version| Version::from_str(version.to_str().ok()?).ok())
{
anyhow::ensure!(
*app_version >= minimum_required_version,
ZedUpdateRequiredError {
minimum_version: minimum_required_version
}
);
}
if response.status().is_success() {
let usage = EditPredictionUsage::from_headers(response.headers()).ok();
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
Ok((serde_json::from_slice(&body)?, usage))
} else {
client
.acquire_llm_token(&llm_token, organization_id.clone())
.await
.ok()
};
let mut did_retry = false;
loop {
let request_builder = http_client::Request::builder().method(Method::POST);
let mut request_builder = request_builder
.header("Content-Type", "application/json")
.header(ZED_VERSION_HEADER_NAME, app_version.to_string());
// Only add Authorization header if we have a token
if let Some(ref token_value) = token {
request_builder =
request_builder.header("Authorization", format!("Bearer {}", token_value));
}
let request = build(request_builder)?;
let mut response = http_client.send(request).await?;
if let Some(minimum_required_version) = response
.headers()
.get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
.and_then(|version| Version::from_str(version.to_str().ok()?).ok())
{
anyhow::ensure!(
app_version >= minimum_required_version,
ZedUpdateRequiredError {
minimum_version: minimum_required_version
}
);
}
if response.status().is_success() {
let usage = EditPredictionUsage::from_headers(response.headers()).ok();
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
token = Some(
client
.refresh_llm_token(&llm_token, organization_id.clone())
.await?,
);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
body
);
}
let status = response.status();
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!("Request failed with status: {status:?}\nBody: {body}");
}
}

View file

@ -21,7 +21,7 @@ use ui::SharedString;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use zeta_prompt::{ParsedOutput, ZetaPromptInput};
use std::{env, ops::Range, path::Path, sync::Arc};
use std::{ops::Range, path::Path, sync::Arc};
use zeta_prompt::{
ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, stop_tokens_for_format,
zeta1::{self, EDITABLE_REGION_END_MARKER},
@ -632,15 +632,13 @@ pub(crate) fn edit_prediction_accepted(
current_prediction: CurrentEditPrediction,
cx: &App,
) {
let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
if store.zeta2_raw_config().is_some() {
return;
}
let request_id = current_prediction.prediction.id.to_string();
let model_version = current_prediction.prediction.model_version;
let e2e_latency = current_prediction.e2e_latency;
let require_auth = custom_accept_url.is_none();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let organization_id = store
@ -651,35 +649,23 @@ pub(crate) fn edit_prediction_accepted(
let app_version = AppVersion::global(cx);
cx.background_spawn(async move {
let url = if let Some(accept_edits_url) = custom_accept_url {
gpui::http_client::Url::parse(&accept_edits_url)?
} else {
client
.http_client()
.build_zed_llm_url("/predict_edits/accept", &[])?
};
let body = serde_json::to_string(&AcceptEditPredictionBody {
request_id,
model_version,
e2e_latency_ms: Some(e2e_latency.as_millis()),
})?;
let response = EditPredictionStore::send_api_request::<()>(
move |builder| {
let req = builder.uri(url.as_ref()).body(
serde_json::to_string(&AcceptEditPredictionBody {
request_id: request_id.clone(),
model_version: model_version.clone(),
e2e_latency_ms: Some(e2e_latency.as_millis()),
})?
.into(),
);
Ok(req?)
},
let url = client
.http_client()
.build_zed_llm_url("/predict_edits/accept", &[])?;
EditPredictionStore::send_api_request::<()>(
move |builder| Ok(builder.uri(url.as_ref()).body(body.clone().into())?),
client,
llm_token,
organization_id,
app_version,
require_auth,
)
.await;
response?;
.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);

View file

@ -42,7 +42,7 @@ impl CloudLlmTokenProvider for ClientTokenProvider {
})
}
fn acquire_token(
fn cached_token(
&self,
organization_id: Self::AuthContext,
) -> BoxFuture<'static, Result<String>> {
@ -50,7 +50,7 @@ impl CloudLlmTokenProvider for ClientTokenProvider {
let llm_api_token = self.llm_api_token.clone();
Box::pin(async move {
client
.acquire_llm_token(&llm_api_token, organization_id)
.cached_llm_token(&llm_api_token, organization_id)
.await
})
}

View file

@ -20,6 +20,7 @@ google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
language_model.workspace = true
log.workspace = true
open_ai = { workspace = true, features = ["schemars"] }
schemars.workspace = true
semver.workspace = true

View file

@ -53,10 +53,30 @@ pub trait CloudLlmTokenProvider: Send + Sync {
type AuthContext: Clone + Send + 'static;
fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
fn cached_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}
/// Sends an authenticated request to the Zed LLM service, retrying once with
/// a refreshed token if the server signals that the cached LLM token is
/// expired or otherwise rejected. Returns the raw response so callers can
/// inspect headers and stream the body.
pub async fn authenticated_llm_request<TP: CloudLlmTokenProvider>(
http_client: &HttpClientWithUrl,
token_provider: &TP,
auth_context: TP::AuthContext,
build_request: impl Fn(&str) -> Result<http_client::Request<AsyncBody>>,
) -> Result<Response<AsyncBody>> {
let token = token_provider.cached_token(auth_context.clone()).await?;
let response = http_client.send(build_request(&token)?).await?;
if !needs_llm_token_refresh(&response) && response.status() != StatusCode::UNAUTHORIZED {
return Ok(response);
}
log::info!("LLM token rejected; refreshing and retrying request");
let token = token_provider.refresh_token(auth_context).await?;
http_client.send(build_request(&token)?).await
}
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
@ -99,55 +119,49 @@ impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
app_version: Option<Version>,
body: CompletionBody,
) -> Result<PerformLlmCompletionResponse> {
let mut token = token_provider.acquire_token(auth_context.clone()).await?;
let mut refreshed_token = false;
let url = http_client.build_zed_llm_url("/completions", &[])?;
let body = serde_json::to_string(&body)?;
let mut response =
authenticated_llm_request(http_client, token_provider, auth_context, |token| {
Ok(http_client::Request::builder()
.method(Method::POST)
.uri(url.as_ref())
.when_some(app_version.as_ref(), |builder, app_version| {
builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
})
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
.header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
.body(body.clone().into())?)
})
.await?;
loop {
let request = http_client::Request::builder()
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
.when_some(app_version.as_ref(), |builder, app_version| {
builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
})
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
.header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
.body(serde_json::to_string(&body)?.into())?;
let status = response.status();
if status.is_success() {
let includes_status_messages = response
.headers()
.get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
.is_some();
let mut response = http_client.send(request).await?;
let status = response.status();
if status.is_success() {
let includes_status_messages = response
.headers()
.get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
.is_some();
return Ok(PerformLlmCompletionResponse {
response,
includes_status_messages,
});
}
if !refreshed_token && needs_llm_token_refresh(&response) {
token = token_provider.refresh_token(auth_context.clone()).await?;
refreshed_token = true;
continue;
}
if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError));
}
let mut body = String::new();
let headers = response.headers().clone();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError {
status,
body,
headers
}));
return Ok(PerformLlmCompletionResponse {
response,
includes_status_messages,
});
}
if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError));
}
let mut body = String::new();
let headers = response.headers().clone();
response.body_mut().read_to_string(&mut body).await?;
Err(anyhow!(ApiError {
status,
body,
headers
}))
}
}
@ -642,16 +656,16 @@ impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
token_provider: &TP,
auth_context: TP::AuthContext,
) -> Result<ListModelsResponse> {
let token = token_provider.acquire_token(auth_context).await?;
let request = http_client::Request::builder()
.method(Method::GET)
.header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
.uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
.header("Authorization", format!("Bearer {token}"))
.body(AsyncBody::empty())?;
let mut response = http_client
.send(request)
let url = http_client.build_zed_llm_url("/models", &[])?;
let mut response =
authenticated_llm_request(http_client, token_provider, auth_context, |token| {
Ok(http_client::Request::builder()
.method(Method::GET)
.header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
.uri(url.as_ref())
.header("Authorization", format!("Bearer {token}"))
.body(AsyncBody::empty())?)
})
.await
.context("failed to send list models request")?;

View file

@ -1,13 +1,13 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use anyhow::Result;
use client::{Client, UserStore, global_llm_token};
use cloud_api_client::LlmApiToken;
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
use http_client::Method;
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@ -69,50 +69,27 @@ async fn perform_web_search(
organization_id: Option<OrganizationId>,
body: WebSearchBody,
) -> Result<WebSearchResponse> {
const MAX_RETRIES: usize = 3;
let http_client = &client.http_client();
let mut retries_remaining = MAX_RETRIES;
let mut token = client
.acquire_llm_token(&llm_api_token, organization_id.clone())
let url = client.http_client().build_zed_llm_url("/web_search", &[])?;
let body = serde_json::to_string(&body)?;
let mut response = client
.authenticated_llm_request(&llm_api_token, organization_id, |token| {
Ok(http_client::Request::builder()
.method(Method::POST)
.uri(url.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(body.clone().into())?)
})
.await?;
loop {
if retries_remaining == 0 {
return Err(anyhow::anyhow!(
"error performing web search, max retries exceeded"
));
}
let request = http_client::Request::builder()
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client
.send(request)
.await
.context("failed to send web search request")?;
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else if response.needs_llm_token_refresh() {
token = client
.refresh_llm_token(&llm_api_token, organization_id.clone())
.await?;
retries_remaining -= 1;
} else {
// For now we will only retry if the LLM token is expired,
// not if the request failed for any other reason.
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"error performing web search.\nStatus: {:?}\nBody: {body}",
response.status(),
);
}
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
Ok(serde_json::from_str(&body)?)
} else {
let status = response.status();
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!("error performing web search.\nStatus: {status:?}\nBody: {body}");
}
}