mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Fix token refresh for HTTP requests (#56559)
Code had been assuming (erroneously, but understandably) that LlmApiToken::acquire would give them a valid token. This is not true, as those tokens expire and you must call refresh explicitly. Add some helpers to do the retry for you, and rename acquire to cached to be clearer about the intent. Closes #ISSUE Release Notes: - Fixed some rare cases where API requests would fail with Unauthorized
This commit is contained in:
parent
64f624773f
commit
54188321be
9 changed files with 200 additions and 217 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -9696,6 +9696,7 @@ dependencies = [
|
|||
"gpui",
|
||||
"http_client",
|
||||
"language_model",
|
||||
"log",
|
||||
"open_ai",
|
||||
"schemars 1.0.4",
|
||||
"semver",
|
||||
|
|
|
|||
|
|
@ -1539,7 +1539,7 @@ impl Client {
|
|||
})
|
||||
}
|
||||
|
||||
pub async fn acquire_llm_token(
|
||||
pub async fn cached_llm_token(
|
||||
&self,
|
||||
llm_token: &LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
|
|
@ -1547,7 +1547,7 @@ impl Client {
|
|||
let system_id = self.telemetry().system_id().map(|x| x.to_string());
|
||||
let cloud_client = self.cloud_client();
|
||||
match llm_token
|
||||
.acquire(&cloud_client, system_id, organization_id)
|
||||
.cached(&cloud_client, system_id, organization_id)
|
||||
.await
|
||||
{
|
||||
Ok(token) => Ok(token),
|
||||
|
|
@ -1559,6 +1559,31 @@ impl Client {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sends an authenticated request to the Zed LLM service, retrying once
|
||||
/// with a refreshed token if the server signals that the cached LLM
|
||||
/// token is expired or otherwise rejected. Returns the raw response so
|
||||
/// callers can inspect headers and stream the body.
|
||||
pub async fn authenticated_llm_request(
|
||||
&self,
|
||||
llm_token: &LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
build_request: impl Fn(&str) -> Result<http_client::Request<http_client::AsyncBody>>,
|
||||
) -> Result<http_client::Response<http_client::AsyncBody>> {
|
||||
let http_client = self.http_client();
|
||||
let token = self
|
||||
.cached_llm_token(llm_token, organization_id.clone())
|
||||
.await?;
|
||||
let response = http_client.send(build_request(&token)?).await?;
|
||||
if !response.needs_llm_token_refresh()
|
||||
&& response.status() != http_client::http::StatusCode::UNAUTHORIZED
|
||||
{
|
||||
return Ok(response);
|
||||
}
|
||||
log::info!("LLM token rejected; refreshing and retrying request");
|
||||
let token = self.refresh_llm_token(llm_token, organization_id).await?;
|
||||
http_client.send(build_request(&token)?).await
|
||||
}
|
||||
|
||||
pub async fn refresh_llm_token(
|
||||
&self,
|
||||
llm_token: &LlmApiToken,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,11 @@ use crate::{ClientApiError, CloudApiClient};
|
|||
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
impl LlmApiToken {
|
||||
pub async fn acquire(
|
||||
/// Returns the cached LLM token, fetching a fresh one only if none has
|
||||
/// been cached yet. The returned token is not validated; callers must
|
||||
/// be prepared to refresh it (via [`LlmApiToken::refresh`]) if the
|
||||
/// server rejects it.
|
||||
pub async fn cached(
|
||||
&self,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use anyhow::Result;
|
||||
use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
|
||||
use client::{Client, EditPredictionUsage, UserStore, global_llm_token};
|
||||
use cloud_api_client::LlmApiToken;
|
||||
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
|
|
@ -889,18 +889,19 @@ impl EditPredictionStore {
|
|||
cx.spawn(async move |this, cx| {
|
||||
let experiments = cx
|
||||
.background_spawn(async move {
|
||||
let http_client = client.http_client();
|
||||
let token = client
|
||||
.acquire_llm_token(&llm_token, organization_id.clone())
|
||||
let url = client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/edit_prediction_experiments", &[])?;
|
||||
let mut response = client
|
||||
.authenticated_llm_request(&llm_token, organization_id, |token| {
|
||||
Ok(http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri(url.as_ref())
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
|
||||
.body(Default::default())?)
|
||||
})
|
||||
.await?;
|
||||
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri(url.as_ref())
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
|
||||
.body(Default::default())?;
|
||||
let mut response = http_client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
|
@ -1580,7 +1581,6 @@ impl EditPredictionStore {
|
|||
llm_token.clone(),
|
||||
organization_id,
|
||||
app_version.clone(),
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -2582,7 +2582,6 @@ impl EditPredictionStore {
|
|||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -2627,7 +2626,6 @@ impl EditPredictionStore {
|
|||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -2638,78 +2636,55 @@ impl EditPredictionStore {
|
|||
llm_token: LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
app_version: Version,
|
||||
require_auth: bool,
|
||||
) -> Result<(Res, Option<EditPredictionUsage>)>
|
||||
where
|
||||
Res: DeserializeOwned,
|
||||
{
|
||||
let http_client = client.http_client();
|
||||
let mut token = if require_auth {
|
||||
Some(
|
||||
client
|
||||
.acquire_llm_token(&llm_token, organization_id.clone())
|
||||
.await?,
|
||||
)
|
||||
let response = client
|
||||
.authenticated_llm_request(&llm_token, organization_id, |token| {
|
||||
build(
|
||||
http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.header("Content-Type", "application/json")
|
||||
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
|
||||
.header("Authorization", format!("Bearer {token}")),
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
Self::process_api_response(response, &app_version).await
|
||||
}
|
||||
|
||||
async fn process_api_response<Res>(
|
||||
mut response: http_client::Response<AsyncBody>,
|
||||
app_version: &Version,
|
||||
) -> Result<(Res, Option<EditPredictionUsage>)>
|
||||
where
|
||||
Res: DeserializeOwned,
|
||||
{
|
||||
if let Some(minimum_required_version) = response
|
||||
.headers()
|
||||
.get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
|
||||
.and_then(|version| Version::from_str(version.to_str().ok()?).ok())
|
||||
{
|
||||
anyhow::ensure!(
|
||||
*app_version >= minimum_required_version,
|
||||
ZedUpdateRequiredError {
|
||||
minimum_version: minimum_required_version
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
if response.status().is_success() {
|
||||
let usage = EditPredictionUsage::from_headers(response.headers()).ok();
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
Ok((serde_json::from_slice(&body)?, usage))
|
||||
} else {
|
||||
client
|
||||
.acquire_llm_token(&llm_token, organization_id.clone())
|
||||
.await
|
||||
.ok()
|
||||
};
|
||||
let mut did_retry = false;
|
||||
|
||||
loop {
|
||||
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||
|
||||
let mut request_builder = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header(ZED_VERSION_HEADER_NAME, app_version.to_string());
|
||||
|
||||
// Only add Authorization header if we have a token
|
||||
if let Some(ref token_value) = token {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", token_value));
|
||||
}
|
||||
|
||||
let request = build(request_builder)?;
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
|
||||
if let Some(minimum_required_version) = response
|
||||
.headers()
|
||||
.get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
|
||||
.and_then(|version| Version::from_str(version.to_str().ok()?).ok())
|
||||
{
|
||||
anyhow::ensure!(
|
||||
app_version >= minimum_required_version,
|
||||
ZedUpdateRequiredError {
|
||||
minimum_version: minimum_required_version
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
if response.status().is_success() {
|
||||
let usage = EditPredictionUsage::from_headers(response.headers()).ok();
|
||||
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
return Ok((serde_json::from_slice(&body)?, usage));
|
||||
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
|
||||
did_retry = true;
|
||||
token = Some(
|
||||
client
|
||||
.refresh_llm_token(&llm_token, organization_id.clone())
|
||||
.await?,
|
||||
);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
body
|
||||
);
|
||||
}
|
||||
let status = response.status();
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
anyhow::bail!("Request failed with status: {status:?}\nBody: {body}");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ use ui::SharedString;
|
|||
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
|
||||
use zeta_prompt::{ParsedOutput, ZetaPromptInput};
|
||||
|
||||
use std::{env, ops::Range, path::Path, sync::Arc};
|
||||
use std::{ops::Range, path::Path, sync::Arc};
|
||||
use zeta_prompt::{
|
||||
ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, stop_tokens_for_format,
|
||||
zeta1::{self, EDITABLE_REGION_END_MARKER},
|
||||
|
|
@ -632,15 +632,13 @@ pub(crate) fn edit_prediction_accepted(
|
|||
current_prediction: CurrentEditPrediction,
|
||||
cx: &App,
|
||||
) {
|
||||
let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
|
||||
if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
|
||||
if store.zeta2_raw_config().is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let request_id = current_prediction.prediction.id.to_string();
|
||||
let model_version = current_prediction.prediction.model_version;
|
||||
let e2e_latency = current_prediction.e2e_latency;
|
||||
let require_auth = custom_accept_url.is_none();
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let organization_id = store
|
||||
|
|
@ -651,35 +649,23 @@ pub(crate) fn edit_prediction_accepted(
|
|||
let app_version = AppVersion::global(cx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let url = if let Some(accept_edits_url) = custom_accept_url {
|
||||
gpui::http_client::Url::parse(&accept_edits_url)?
|
||||
} else {
|
||||
client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/predict_edits/accept", &[])?
|
||||
};
|
||||
let body = serde_json::to_string(&AcceptEditPredictionBody {
|
||||
request_id,
|
||||
model_version,
|
||||
e2e_latency_ms: Some(e2e_latency.as_millis()),
|
||||
})?;
|
||||
|
||||
let response = EditPredictionStore::send_api_request::<()>(
|
||||
move |builder| {
|
||||
let req = builder.uri(url.as_ref()).body(
|
||||
serde_json::to_string(&AcceptEditPredictionBody {
|
||||
request_id: request_id.clone(),
|
||||
model_version: model_version.clone(),
|
||||
e2e_latency_ms: Some(e2e_latency.as_millis()),
|
||||
})?
|
||||
.into(),
|
||||
);
|
||||
Ok(req?)
|
||||
},
|
||||
let url = client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/predict_edits/accept", &[])?;
|
||||
EditPredictionStore::send_api_request::<()>(
|
||||
move |builder| Ok(builder.uri(url.as_ref()).body(body.clone().into())?),
|
||||
client,
|
||||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
require_auth,
|
||||
)
|
||||
.await;
|
||||
|
||||
response?;
|
||||
.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ impl CloudLlmTokenProvider for ClientTokenProvider {
|
|||
})
|
||||
}
|
||||
|
||||
fn acquire_token(
|
||||
fn cached_token(
|
||||
&self,
|
||||
organization_id: Self::AuthContext,
|
||||
) -> BoxFuture<'static, Result<String>> {
|
||||
|
|
@ -50,7 +50,7 @@ impl CloudLlmTokenProvider for ClientTokenProvider {
|
|||
let llm_api_token = self.llm_api_token.clone();
|
||||
Box::pin(async move {
|
||||
client
|
||||
.acquire_llm_token(&llm_api_token, organization_id)
|
||||
.cached_llm_token(&llm_api_token, organization_id)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ google_ai = { workspace = true, features = ["schemars"] }
|
|||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
schemars.workspace = true
|
||||
semver.workspace = true
|
||||
|
|
|
|||
|
|
@ -53,10 +53,30 @@ pub trait CloudLlmTokenProvider: Send + Sync {
|
|||
type AuthContext: Clone + Send + 'static;
|
||||
|
||||
fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
|
||||
fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
|
||||
fn cached_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
|
||||
fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
|
||||
}
|
||||
|
||||
/// Sends an authenticated request to the Zed LLM service, retrying once with
|
||||
/// a refreshed token if the server signals that the cached LLM token is
|
||||
/// expired or otherwise rejected. Returns the raw response so callers can
|
||||
/// inspect headers and stream the body.
|
||||
pub async fn authenticated_llm_request<TP: CloudLlmTokenProvider>(
|
||||
http_client: &HttpClientWithUrl,
|
||||
token_provider: &TP,
|
||||
auth_context: TP::AuthContext,
|
||||
build_request: impl Fn(&str) -> Result<http_client::Request<AsyncBody>>,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let token = token_provider.cached_token(auth_context.clone()).await?;
|
||||
let response = http_client.send(build_request(&token)?).await?;
|
||||
if !needs_llm_token_refresh(&response) && response.status() != StatusCode::UNAUTHORIZED {
|
||||
return Ok(response);
|
||||
}
|
||||
log::info!("LLM token rejected; refreshing and retrying request");
|
||||
let token = token_provider.refresh_token(auth_context).await?;
|
||||
http_client.send(build_request(&token)?).await
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ModelMode {
|
||||
|
|
@ -99,55 +119,49 @@ impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
|
|||
app_version: Option<Version>,
|
||||
body: CompletionBody,
|
||||
) -> Result<PerformLlmCompletionResponse> {
|
||||
let mut token = token_provider.acquire_token(auth_context.clone()).await?;
|
||||
let mut refreshed_token = false;
|
||||
let url = http_client.build_zed_llm_url("/completions", &[])?;
|
||||
let body = serde_json::to_string(&body)?;
|
||||
let mut response =
|
||||
authenticated_llm_request(http_client, token_provider, auth_context, |token| {
|
||||
Ok(http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(url.as_ref())
|
||||
.when_some(app_version.as_ref(), |builder, app_version| {
|
||||
builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
|
||||
})
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
|
||||
.header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
|
||||
.body(body.clone().into())?)
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
|
||||
.when_some(app_version.as_ref(), |builder, app_version| {
|
||||
builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
|
||||
})
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
|
||||
.header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
|
||||
.body(serde_json::to_string(&body)?.into())?;
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
let includes_status_messages = response
|
||||
.headers()
|
||||
.get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
|
||||
.is_some();
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
let includes_status_messages = response
|
||||
.headers()
|
||||
.get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
|
||||
.is_some();
|
||||
|
||||
return Ok(PerformLlmCompletionResponse {
|
||||
response,
|
||||
includes_status_messages,
|
||||
});
|
||||
}
|
||||
|
||||
if !refreshed_token && needs_llm_token_refresh(&response) {
|
||||
token = token_provider.refresh_token(auth_context.clone()).await?;
|
||||
refreshed_token = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if status == StatusCode::PAYMENT_REQUIRED {
|
||||
return Err(anyhow!(PaymentRequiredError));
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
let headers = response.headers().clone();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(ApiError {
|
||||
status,
|
||||
body,
|
||||
headers
|
||||
}));
|
||||
return Ok(PerformLlmCompletionResponse {
|
||||
response,
|
||||
includes_status_messages,
|
||||
});
|
||||
}
|
||||
|
||||
if status == StatusCode::PAYMENT_REQUIRED {
|
||||
return Err(anyhow!(PaymentRequiredError));
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
let headers = response.headers().clone();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
Err(anyhow!(ApiError {
|
||||
status,
|
||||
body,
|
||||
headers
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -642,16 +656,16 @@ impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
|
|||
token_provider: &TP,
|
||||
auth_context: TP::AuthContext,
|
||||
) -> Result<ListModelsResponse> {
|
||||
let token = token_provider.acquire_token(auth_context).await?;
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
.header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
|
||||
.uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(AsyncBody::empty())?;
|
||||
let mut response = http_client
|
||||
.send(request)
|
||||
let url = http_client.build_zed_llm_url("/models", &[])?;
|
||||
let mut response =
|
||||
authenticated_llm_request(http_client, token_provider, auth_context, |token| {
|
||||
Ok(http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
.header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
|
||||
.uri(url.as_ref())
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(AsyncBody::empty())?)
|
||||
})
|
||||
.await
|
||||
.context("failed to send list models request")?;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore, global_llm_token};
|
||||
use cloud_api_client::LlmApiToken;
|
||||
use cloud_api_types::OrganizationId;
|
||||
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use http_client::Method;
|
||||
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||
|
||||
pub struct CloudWebSearchProvider {
|
||||
|
|
@ -69,50 +69,27 @@ async fn perform_web_search(
|
|||
organization_id: Option<OrganizationId>,
|
||||
body: WebSearchBody,
|
||||
) -> Result<WebSearchResponse> {
|
||||
const MAX_RETRIES: usize = 3;
|
||||
|
||||
let http_client = &client.http_client();
|
||||
let mut retries_remaining = MAX_RETRIES;
|
||||
let mut token = client
|
||||
.acquire_llm_token(&llm_api_token, organization_id.clone())
|
||||
let url = client.http_client().build_zed_llm_url("/web_search", &[])?;
|
||||
let body = serde_json::to_string(&body)?;
|
||||
let mut response = client
|
||||
.authenticated_llm_request(&llm_api_token, organization_id, |token| {
|
||||
Ok(http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(url.as_ref())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(body.clone().into())?)
|
||||
})
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
if retries_remaining == 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"error performing web search, max retries exceeded"
|
||||
));
|
||||
}
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&body)?.into())?;
|
||||
let mut response = http_client
|
||||
.send(request)
|
||||
.await
|
||||
.context("failed to send web search request")?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Ok(serde_json::from_str(&body)?);
|
||||
} else if response.needs_llm_token_refresh() {
|
||||
token = client
|
||||
.refresh_llm_token(&llm_api_token, organization_id.clone())
|
||||
.await?;
|
||||
retries_remaining -= 1;
|
||||
} else {
|
||||
// For now we will only retry if the LLM token is expired,
|
||||
// not if the request failed for any other reason.
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
anyhow::bail!(
|
||||
"error performing web search.\nStatus: {:?}\nBody: {body}",
|
||||
response.status(),
|
||||
);
|
||||
}
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
} else {
|
||||
let status = response.status();
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
anyhow::bail!("error performing web search.\nStatus: {status:?}\nBody: {body}");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue