mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
cloud_api_client: Send the organization ID in LLM token requests (#50517)
This is already expected on the cloud side. This lets us know under which organization the user is logged in when requesting an llm_api token. Closes CLO-337 Release Notes: - N/A
This commit is contained in:
parent
5641ccf250
commit
a1d40370cf
14 changed files with 247 additions and 61 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -19851,6 +19851,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ use futures::AsyncReadExt as _;
|
|||
use gpui::{App, Task};
|
||||
use gpui_tokio::Tokio;
|
||||
use http_client::http::request;
|
||||
use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode};
|
||||
use http_client::{
|
||||
AsyncBody, HttpClientWithUrl, HttpRequestExt, Json, Method, Request, StatusCode,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use thiserror::Error;
|
||||
use yawc::WebSocket;
|
||||
|
|
@ -141,6 +143,7 @@ impl CloudApiClient {
|
|||
pub async fn create_llm_token(
|
||||
&self,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<CreateLlmTokenResponse, ClientApiError> {
|
||||
let request_builder = Request::builder()
|
||||
.method(Method::POST)
|
||||
|
|
@ -153,7 +156,10 @@ impl CloudApiClient {
|
|||
builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id)
|
||||
});
|
||||
|
||||
let request = self.build_request(request_builder, AsyncBody::default())?;
|
||||
let request = self.build_request(
|
||||
request_builder,
|
||||
Json(CreateLlmTokenBody { organization_id }),
|
||||
)?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,12 @@ pub struct AcceptTermsOfServiceResponse {
|
|||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToken(pub String);
|
||||
|
||||
#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateLlmTokenBody {
|
||||
#[serde(default)]
|
||||
pub organization_id: Option<OrganizationId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateLlmTokenResponse {
|
||||
pub token: LlmToken,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use anyhow::Result;
|
||||
use arrayvec::ArrayVec;
|
||||
use client::{Client, EditPredictionUsage, UserStore};
|
||||
use cloud_api_types::SubmitEditPredictionFeedbackBody;
|
||||
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
|
||||
};
|
||||
|
|
@ -143,7 +143,7 @@ pub struct EditPredictionStore {
|
|||
pub sweep_ai: SweepAi,
|
||||
pub mercury: Mercury,
|
||||
data_collection_choice: DataCollectionChoice,
|
||||
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
|
||||
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
|
||||
settled_predictions_tx: mpsc::UnboundedSender<Instant>,
|
||||
shown_predictions: VecDeque<EditPrediction>,
|
||||
rated_predictions: HashSet<EditPredictionId>,
|
||||
|
|
@ -151,6 +151,11 @@ pub struct EditPredictionStore {
|
|||
settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
|
||||
}
|
||||
|
||||
pub(crate) struct EditPredictionRejectionPayload {
|
||||
rejection: EditPredictionRejection,
|
||||
organization_id: Option<OrganizationId>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum EditPredictionModel {
|
||||
Zeta,
|
||||
|
|
@ -719,8 +724,13 @@ impl EditPredictionStore {
|
|||
|this, _listener, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_token = this.llm_token.clone();
|
||||
let organization_id = this
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|organization| organization.id.clone());
|
||||
cx.spawn(async move |_this, _cx| {
|
||||
llm_token.refresh(&client).await?;
|
||||
llm_token.refresh(&client, organization_id).await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
|
@ -781,11 +791,17 @@ impl EditPredictionStore {
|
|||
let client = self.client.clone();
|
||||
let llm_token = self.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
let organization_id = self
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|organization| organization.id.clone());
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let experiments = cx
|
||||
.background_spawn(async move {
|
||||
let http_client = client.http_client();
|
||||
let token = llm_token.acquire(&client).await?;
|
||||
let token = llm_token.acquire(&client, organization_id).await?;
|
||||
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
|
|
@ -1424,7 +1440,7 @@ impl EditPredictionStore {
|
|||
}
|
||||
|
||||
async fn handle_rejected_predictions(
|
||||
rx: UnboundedReceiver<EditPredictionRejection>,
|
||||
rx: UnboundedReceiver<EditPredictionRejectionPayload>,
|
||||
client: Arc<Client>,
|
||||
llm_token: LlmApiToken,
|
||||
app_version: Version,
|
||||
|
|
@ -1433,7 +1449,11 @@ impl EditPredictionStore {
|
|||
let mut rx = std::pin::pin!(rx.peekable());
|
||||
let mut batched = Vec::new();
|
||||
|
||||
while let Some(rejection) = rx.next().await {
|
||||
while let Some(EditPredictionRejectionPayload {
|
||||
rejection,
|
||||
organization_id,
|
||||
}) = rx.next().await
|
||||
{
|
||||
batched.push(rejection);
|
||||
|
||||
if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
|
||||
|
|
@ -1471,6 +1491,7 @@ impl EditPredictionStore {
|
|||
},
|
||||
client.clone(),
|
||||
llm_token.clone(),
|
||||
organization_id,
|
||||
app_version.clone(),
|
||||
true,
|
||||
)
|
||||
|
|
@ -1676,13 +1697,23 @@ impl EditPredictionStore {
|
|||
all_language_settings(None, cx).edit_predictions.provider,
|
||||
EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
|
||||
);
|
||||
|
||||
if is_cloud {
|
||||
let organization_id = self
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|organization| organization.id.clone());
|
||||
|
||||
self.reject_predictions_tx
|
||||
.unbounded_send(EditPredictionRejection {
|
||||
request_id: prediction_id.to_string(),
|
||||
reason,
|
||||
was_shown,
|
||||
model_version,
|
||||
.unbounded_send(EditPredictionRejectionPayload {
|
||||
rejection: EditPredictionRejection {
|
||||
request_id: prediction_id.to_string(),
|
||||
reason,
|
||||
was_shown,
|
||||
model_version,
|
||||
},
|
||||
organization_id,
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
|
@ -2337,6 +2368,7 @@ impl EditPredictionStore {
|
|||
client: Arc<Client>,
|
||||
custom_url: Option<Arc<Url>>,
|
||||
llm_token: LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
app_version: Version,
|
||||
) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
|
||||
let url = if let Some(custom_url) = custom_url {
|
||||
|
|
@ -2356,6 +2388,7 @@ impl EditPredictionStore {
|
|||
},
|
||||
client,
|
||||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
true,
|
||||
)
|
||||
|
|
@ -2366,6 +2399,7 @@ impl EditPredictionStore {
|
|||
input: ZetaPromptInput,
|
||||
client: Arc<Client>,
|
||||
llm_token: LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
app_version: Version,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
|
||||
|
|
@ -2388,6 +2422,7 @@ impl EditPredictionStore {
|
|||
},
|
||||
client,
|
||||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
true,
|
||||
)
|
||||
|
|
@ -2441,6 +2476,7 @@ impl EditPredictionStore {
|
|||
build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
|
||||
client: Arc<Client>,
|
||||
llm_token: LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
app_version: Version,
|
||||
require_auth: bool,
|
||||
) -> Result<(Res, Option<EditPredictionUsage>)>
|
||||
|
|
@ -2450,9 +2486,12 @@ impl EditPredictionStore {
|
|||
let http_client = client.http_client();
|
||||
|
||||
let mut token = if require_auth {
|
||||
Some(llm_token.acquire(&client).await?)
|
||||
Some(llm_token.acquire(&client, organization_id.clone()).await?)
|
||||
} else {
|
||||
llm_token.acquire(&client).await.ok()
|
||||
llm_token
|
||||
.acquire(&client, organization_id.clone())
|
||||
.await
|
||||
.ok()
|
||||
};
|
||||
let mut did_retry = false;
|
||||
|
||||
|
|
@ -2494,7 +2533,7 @@ impl EditPredictionStore {
|
|||
return Ok((serde_json::from_slice(&body)?, usage));
|
||||
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
|
||||
did_retry = true;
|
||||
token = Some(llm_token.refresh(&client).await?);
|
||||
token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
|
|
|||
|
|
@ -66,6 +66,11 @@ pub fn request_prediction_with_zeta(
|
|||
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let organization_id = store
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|organization| organization.id.clone());
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let request_task = cx.background_spawn({
|
||||
|
|
@ -201,6 +206,7 @@ pub fn request_prediction_with_zeta(
|
|||
client,
|
||||
None,
|
||||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
)
|
||||
.await?;
|
||||
|
|
@ -219,6 +225,7 @@ pub fn request_prediction_with_zeta(
|
|||
prompt_input.clone(),
|
||||
client,
|
||||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
trigger,
|
||||
)
|
||||
|
|
@ -430,6 +437,11 @@ pub(crate) fn edit_prediction_accepted(
|
|||
let require_auth = custom_accept_url.is_none();
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let organization_id = store
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|organization| organization.id.clone());
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
|
|
@ -454,6 +466,7 @@ pub(crate) fn edit_prediction_accepted(
|
|||
},
|
||||
client,
|
||||
llm_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
require_auth,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ use std::{
|
|||
use bytes::Bytes;
|
||||
use futures::AsyncRead;
|
||||
use http_body::{Body, Frame};
|
||||
use serde::Serialize;
|
||||
|
||||
/// Based on the implementation of AsyncBody in
|
||||
/// <https://github.com/sagebind/isahc/blob/5c533f1ef4d6bdf1fd291b5103c22110f41d0bf0/src/body/mod.rs>.
|
||||
|
|
@ -88,6 +89,19 @@ impl From<&'static str> for AsyncBody {
|
|||
}
|
||||
}
|
||||
|
||||
/// Newtype wrapper that serializes a value as JSON into an `AsyncBody`.
|
||||
pub struct Json<T: Serialize>(pub T);
|
||||
|
||||
impl<T: Serialize> From<Json<T>> for AsyncBody {
|
||||
fn from(json: Json<T>) -> Self {
|
||||
Self::from_bytes(
|
||||
serde_json::to_vec(&json.0)
|
||||
.expect("failed to serialize JSON")
|
||||
.into(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Into<Self>> From<Option<T>> for AsyncBody {
|
||||
fn from(body: Option<T>) -> Self {
|
||||
match body {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ pub mod github;
|
|||
pub mod github_download;
|
||||
|
||||
pub use anyhow::{Result, anyhow};
|
||||
pub use async_body::{AsyncBody, Inner};
|
||||
pub use async_body::{AsyncBody, Inner, Json};
|
||||
use derive_more::Deref;
|
||||
use http::HeaderValue;
|
||||
pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ use std::sync::Arc;
|
|||
use anyhow::{Context as _, Result};
|
||||
use client::Client;
|
||||
use cloud_api_client::ClientApiError;
|
||||
use cloud_api_types::OrganizationId;
|
||||
use cloud_api_types::websocket_protocol::MessageToClient;
|
||||
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
|
||||
|
|
@ -26,29 +27,46 @@ impl fmt::Display for PaymentRequiredError {
|
|||
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
impl LlmApiToken {
|
||||
pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
||||
pub async fn acquire(
|
||||
&self,
|
||||
client: &Arc<Client>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
let lock = self.0.upgradable_read().await;
|
||||
if let Some(token) = lock.as_ref() {
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
|
||||
Self::fetch(
|
||||
RwLockUpgradableReadGuard::upgrade(lock).await,
|
||||
client,
|
||||
organization_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
|
||||
Self::fetch(self.0.write().await, client).await
|
||||
pub async fn refresh(
|
||||
&self,
|
||||
client: &Arc<Client>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
Self::fetch(self.0.write().await, client, organization_id).await
|
||||
}
|
||||
|
||||
async fn fetch(
|
||||
mut lock: RwLockWriteGuard<'_, Option<String>>,
|
||||
client: &Arc<Client>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String> {
|
||||
let system_id = client
|
||||
.telemetry()
|
||||
.system_id()
|
||||
.map(|system_id| system_id.to_string());
|
||||
|
||||
let result = client.cloud_client().create_llm_token(system_id).await;
|
||||
let result = client
|
||||
.cloud_client()
|
||||
.create_llm_token(system_id, organization_id)
|
||||
.await;
|
||||
match result {
|
||||
Ok(response) => {
|
||||
*lock = Some(response.token.0.clone());
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use anthropic::AnthropicModelMode;
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use cloud_api_types::Plan;
|
||||
use cloud_api_types::{OrganizationId, Plan};
|
||||
use cloud_llm_client::{
|
||||
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
|
||||
CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
|
||||
|
|
@ -122,15 +122,25 @@ impl State {
|
|||
recommended_models: Vec::new(),
|
||||
_fetch_models_task: cx.spawn(async move |this, cx| {
|
||||
maybe!(async move {
|
||||
let (client, llm_api_token) = this
|
||||
.read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
|
||||
let (client, llm_api_token, organization_id) =
|
||||
this.read_with(cx, |this, cx| {
|
||||
(
|
||||
client.clone(),
|
||||
this.llm_api_token.clone(),
|
||||
this.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|o| o.id.clone()),
|
||||
)
|
||||
})?;
|
||||
|
||||
while current_user.borrow().is_none() {
|
||||
current_user.next().await;
|
||||
}
|
||||
|
||||
let response =
|
||||
Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
|
||||
Self::fetch_models(client.clone(), llm_api_token.clone(), organization_id)
|
||||
.await?;
|
||||
this.update(cx, |this, cx| this.update_models(response, cx))?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
|
|
@ -146,9 +156,17 @@ impl State {
|
|||
move |this, _listener, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_api_token = this.llm_api_token.clone();
|
||||
let organization_id = this
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|o| o.id.clone());
|
||||
cx.spawn(async move |this, cx| {
|
||||
llm_api_token.refresh(&client).await?;
|
||||
let response = Self::fetch_models(client, llm_api_token).await?;
|
||||
llm_api_token
|
||||
.refresh(&client, organization_id.clone())
|
||||
.await?;
|
||||
let response =
|
||||
Self::fetch_models(client, llm_api_token, organization_id).await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_models(response, cx);
|
||||
})
|
||||
|
|
@ -209,9 +227,10 @@ impl State {
|
|||
async fn fetch_models(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<ListModelsResponse> {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
let token = llm_api_token.acquire(&client, organization_id).await?;
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
|
|
@ -273,11 +292,13 @@ impl CloudLanguageModelProvider {
|
|||
&self,
|
||||
model: Arc<cloud_llm_client::LanguageModel>,
|
||||
llm_api_token: LlmApiToken,
|
||||
user_store: Entity<UserStore>,
|
||||
) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(CloudLanguageModel {
|
||||
id: LanguageModelId(SharedString::from(model.id.0.clone())),
|
||||
model,
|
||||
llm_api_token,
|
||||
user_store,
|
||||
client: self.client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
|
|
@ -306,36 +327,46 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
}
|
||||
|
||||
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
let default_model = self.state.read(cx).default_model.clone()?;
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
Some(self.create_language_model(default_model, llm_api_token))
|
||||
let state = self.state.read(cx);
|
||||
let default_model = state.default_model.clone()?;
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let user_store = state.user_store.clone();
|
||||
Some(self.create_language_model(default_model, llm_api_token, user_store))
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
Some(self.create_language_model(default_fast_model, llm_api_token))
|
||||
let state = self.state.read(cx);
|
||||
let default_fast_model = state.default_fast_model.clone()?;
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let user_store = state.user_store.clone();
|
||||
Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
|
||||
}
|
||||
|
||||
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
self.state
|
||||
.read(cx)
|
||||
let state = self.state.read(cx);
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let user_store = state.user_store.clone();
|
||||
state
|
||||
.recommended_models
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|model| self.create_language_model(model, llm_api_token.clone()))
|
||||
.map(|model| {
|
||||
self.create_language_model(model, llm_api_token.clone(), user_store.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
self.state
|
||||
.read(cx)
|
||||
let state = self.state.read(cx);
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let user_store = state.user_store.clone();
|
||||
state
|
||||
.models
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|model| self.create_language_model(model, llm_api_token.clone()))
|
||||
.map(|model| {
|
||||
self.create_language_model(model, llm_api_token.clone(), user_store.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
|
@ -367,6 +398,7 @@ pub struct CloudLanguageModel {
|
|||
id: LanguageModelId,
|
||||
model: Arc<cloud_llm_client::LanguageModel>,
|
||||
llm_api_token: LlmApiToken,
|
||||
user_store: Entity<UserStore>,
|
||||
client: Arc<Client>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
|
@ -380,12 +412,15 @@ impl CloudLanguageModel {
|
|||
async fn perform_llm_completion(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
app_version: Option<Version>,
|
||||
body: CompletionBody,
|
||||
) -> Result<PerformLlmCompletionResponse> {
|
||||
let http_client = &client.http_client();
|
||||
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut token = llm_api_token
|
||||
.acquire(&client, organization_id.clone())
|
||||
.await?;
|
||||
let mut refreshed_token = false;
|
||||
|
||||
loop {
|
||||
|
|
@ -416,7 +451,9 @@ impl CloudLanguageModel {
|
|||
}
|
||||
|
||||
if !refreshed_token && response.needs_llm_token_refresh() {
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
token = llm_api_token
|
||||
.refresh(&client, organization_id.clone())
|
||||
.await?;
|
||||
refreshed_token = true;
|
||||
continue;
|
||||
}
|
||||
|
|
@ -670,12 +707,17 @@ impl LanguageModel for CloudLanguageModel {
|
|||
cloud_llm_client::LanguageModelProvider::Google => {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let organization_id = self
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|o| o.id.clone());
|
||||
let model_id = self.model.id.to_string();
|
||||
let generate_content_request =
|
||||
into_google(request, model_id.clone(), GoogleModelMode::Default);
|
||||
async move {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
let token = llm_api_token.acquire(&client, organization_id).await?;
|
||||
|
||||
let request_body = CountTokensBody {
|
||||
provider: cloud_llm_client::LanguageModelProvider::Google,
|
||||
|
|
@ -736,6 +778,13 @@ impl LanguageModel for CloudLanguageModel {
|
|||
let prompt_id = request.prompt_id.clone();
|
||||
let intent = request.intent;
|
||||
let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
|
||||
let user_store = self.user_store.clone();
|
||||
let organization_id = cx.update(|cx| {
|
||||
user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|o| o.id.clone())
|
||||
});
|
||||
let thinking_allowed = request.thinking_allowed;
|
||||
let enable_thinking = thinking_allowed && self.model.supports_thinking;
|
||||
let provider_name = provider_name(&self.model.provider);
|
||||
|
|
@ -767,6 +816,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let organization_id = organization_id.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let PerformLlmCompletionResponse {
|
||||
response,
|
||||
|
|
@ -774,6 +824,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
} = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
CompletionBody {
|
||||
thread_id,
|
||||
|
|
@ -803,6 +854,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
cloud_llm_client::LanguageModelProvider::OpenAi => {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let organization_id = organization_id.clone();
|
||||
let effort = request
|
||||
.thinking_effort
|
||||
.as_ref()
|
||||
|
|
@ -828,6 +880,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
} = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
CompletionBody {
|
||||
thread_id,
|
||||
|
|
@ -861,6 +914,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
None,
|
||||
);
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let organization_id = organization_id.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let PerformLlmCompletionResponse {
|
||||
response,
|
||||
|
|
@ -868,6 +922,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
} = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
CompletionBody {
|
||||
thread_id,
|
||||
|
|
@ -902,6 +957,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
} = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
organization_id,
|
||||
app_version,
|
||||
CompletionBody {
|
||||
thread_id,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::Client;
|
||||
use client::{Client, UserStore};
|
||||
use cloud_api_types::OrganizationId;
|
||||
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
|
||||
|
|
@ -14,8 +15,8 @@ pub struct CloudWebSearchProvider {
|
|||
}
|
||||
|
||||
impl CloudWebSearchProvider {
|
||||
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State::new(client, cx));
|
||||
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State::new(client, user_store, cx));
|
||||
|
||||
Self { state }
|
||||
}
|
||||
|
|
@ -23,24 +24,31 @@ impl CloudWebSearchProvider {
|
|||
|
||||
pub struct State {
|
||||
client: Arc<Client>,
|
||||
user_store: Entity<UserStore>,
|
||||
llm_api_token: LlmApiToken,
|
||||
_llm_token_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
||||
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
|
||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||
|
||||
Self {
|
||||
client,
|
||||
user_store,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
_llm_token_subscription: cx.subscribe(
|
||||
&refresh_llm_token_listener,
|
||||
|this, _, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_api_token = this.llm_api_token.clone();
|
||||
let organization_id = this
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|o| o.id.clone());
|
||||
cx.spawn(async move |_this, _cx| {
|
||||
llm_api_token.refresh(&client).await?;
|
||||
llm_api_token.refresh(&client, organization_id).await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
|
@ -61,21 +69,31 @@ impl WebSearchProvider for CloudWebSearchProvider {
|
|||
let state = self.state.read(cx);
|
||||
let client = state.client.clone();
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let organization_id = state
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_organization()
|
||||
.map(|o| o.id.clone());
|
||||
let body = WebSearchBody { query };
|
||||
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
|
||||
cx.background_spawn(async move {
|
||||
perform_web_search(client, llm_api_token, organization_id, body).await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_web_search(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
organization_id: Option<OrganizationId>,
|
||||
body: WebSearchBody,
|
||||
) -> Result<WebSearchResponse> {
|
||||
const MAX_RETRIES: usize = 3;
|
||||
|
||||
let http_client = &client.http_client();
|
||||
let mut retries_remaining = MAX_RETRIES;
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut token = llm_api_token
|
||||
.acquire(&client, organization_id.clone())
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
if retries_remaining == 0 {
|
||||
|
|
@ -100,7 +118,9 @@ async fn perform_web_search(
|
|||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Ok(serde_json::from_str(&body)?);
|
||||
} else if response.needs_llm_token_refresh() {
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
token = llm_api_token
|
||||
.refresh(&client, organization_id.clone())
|
||||
.await?;
|
||||
retries_remaining -= 1;
|
||||
} else {
|
||||
// For now we will only retry if the LLM token is expired,
|
||||
|
|
|
|||
|
|
@ -1,26 +1,28 @@
|
|||
mod cloud;
|
||||
|
||||
use client::Client;
|
||||
use client::{Client, UserStore};
|
||||
use gpui::{App, Context, Entity};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use std::sync::Arc;
|
||||
use web_search::{WebSearchProviderId, WebSearchRegistry};
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||
let registry = WebSearchRegistry::global(cx);
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_web_search_providers(registry, client, cx);
|
||||
register_web_search_providers(registry, client, user_store, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn register_web_search_providers(
|
||||
registry: &mut WebSearchRegistry,
|
||||
client: Arc<Client>,
|
||||
user_store: Entity<UserStore>,
|
||||
cx: &mut Context<WebSearchRegistry>,
|
||||
) {
|
||||
register_zed_web_search_provider(
|
||||
registry,
|
||||
client.clone(),
|
||||
user_store.clone(),
|
||||
&LanguageModelRegistry::global(cx),
|
||||
cx,
|
||||
);
|
||||
|
|
@ -29,7 +31,13 @@ fn register_web_search_providers(
|
|||
&LanguageModelRegistry::global(cx),
|
||||
move |this, registry, event, cx| {
|
||||
if let language_model::Event::DefaultModelChanged = event {
|
||||
register_zed_web_search_provider(this, client.clone(), ®istry, cx)
|
||||
register_zed_web_search_provider(
|
||||
this,
|
||||
client.clone(),
|
||||
user_store.clone(),
|
||||
®istry,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
|
@ -39,6 +47,7 @@ fn register_web_search_providers(
|
|||
fn register_zed_web_search_provider(
|
||||
registry: &mut WebSearchRegistry,
|
||||
client: Arc<Client>,
|
||||
user_store: Entity<UserStore>,
|
||||
language_model_registry: &Entity<LanguageModelRegistry>,
|
||||
cx: &mut Context<WebSearchRegistry>,
|
||||
) {
|
||||
|
|
@ -47,7 +56,10 @@ fn register_zed_web_search_provider(
|
|||
.default_model()
|
||||
.is_some_and(|default| default.is_provided_by_zed());
|
||||
if using_zed_provider {
|
||||
registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
|
||||
registry.register_provider(
|
||||
cloud::CloudWebSearchProvider::new(client, user_store, cx),
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
registry.unregister_provider(WebSearchProviderId(
|
||||
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
|
||||
|
|
|
|||
|
|
@ -645,7 +645,7 @@ fn main() {
|
|||
zed::remote_debug::init(cx);
|
||||
edit_prediction_ui::init(cx);
|
||||
web_search::init(cx);
|
||||
web_search_providers::init(app_state.client.clone(), cx);
|
||||
web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
|
||||
snippet_provider::init(cx);
|
||||
edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx);
|
||||
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);
|
||||
|
|
|
|||
|
|
@ -5021,7 +5021,7 @@ mod tests {
|
|||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
web_search::init(cx);
|
||||
git_graph::init(cx);
|
||||
web_search_providers::init(app_state.client.clone(), cx);
|
||||
web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
|
||||
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
|
||||
project::AgentRegistryStore::init_global(
|
||||
cx,
|
||||
|
|
|
|||
Loading…
Reference in a new issue