zed/crates/language_models/src/provider/x_ai.rs
Ben Brandt 2eafa6e6aa
language_models: Remove unused language model token counting (#54177)
Drop the `count_tokens` API and related implementations across
providers, and remove the unused `tiktoken-rs` dependency.

I was going to update the dependency because they finally released a fix
we needed. But then I realized we only used this api in one place, the
Rules library. And for most models it would have been wildly incorrect
because we use tiktoken, i.e. OpenAI tokenizers, for almost every model,
which is going to give incorrect results.

Given that, I just removed these because the difference in how we get
these has caused plenty of confusion in the past.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- N/A
2026-04-22 13:39:48 +00:00

490 lines
16 KiB
Rust

use anyhow::Result;
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
env_var,
};
use open_ai::ResponseStreamEvent;
pub use settings::XaiAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use x_ai::XAI_API_URL;
/// Provider id returned by both the provider and every model it creates.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
/// Provider name returned alongside the id; also used as the `provider` of `NoApiKey` errors
/// and passed (as a string) to `open_ai::stream_completion`.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
/// Name of the environment variable the configuration UI tells the user they can set
/// to supply the API key.
const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY";
/// Lazily-initialized `EnvVar` built from `API_KEY_ENV_VAR_NAME`; a clone of it is handed
/// to `ApiKeyState::new` when the provider state is created.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
/// Settings for the xAI provider, read from the `x_ai` field of `AllLanguageModelSettings`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct XAiSettings {
/// API base URL override. When empty, `XAI_API_URL` is used instead.
pub api_url: String,
/// Additional models declared in settings; each is exposed as an `x_ai::Model::Custom`.
pub available_models: Vec<AvailableModel>,
}
/// The xAI implementation of `LanguageModelProvider`.
pub struct XAiLanguageModelProvider {
/// HTTP client cloned into every language model this provider creates.
http_client: Arc<dyn HttpClient>,
/// API-key state entity, shared with created models and the configuration view.
state: Entity<State>,
}
/// Authentication state shared by the provider, its models, and the configuration view.
pub struct State {
/// Tracks the API key; this file queries it through `has_key`, `key` and `is_from_env_var`.
api_key_state: ApiKeyState,
/// Passed to every `ApiKeyState` store / load / URL-change call.
/// NOTE(review): presumably the backing credential store — confirm against `ApiKeyState`.
credentials_provider: Arc<dyn CredentialsProvider>,
}
impl State {
    /// Reports whether the API-key state currently holds a key.
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores `api_key` for the currently configured API URL.
    ///
    /// Passing `None` is how `reset_credentials` and the configuration view clear the key.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let url = XAiLanguageModelProvider::api_url(cx);
        let provider = self.credentials_provider.clone();
        self.api_key_state
            .store(url, api_key, |state| &mut state.api_key_state, provider, cx)
    }

    /// Asks the API-key state to load the key for the currently configured API URL
    /// if it has not been loaded yet.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let url = XAiLanguageModelProvider::api_url(cx);
        let provider = self.credentials_provider.clone();
        self.api_key_state
            .load_if_needed(url, |state| &mut state.api_key_state, provider, cx)
    }
}
impl XAiLanguageModelProvider {
    /// Creates the provider and its shared [`State`] entity.
    ///
    /// The state observes the global `SettingsStore`; on every settings change it forwards
    /// the (possibly new) API URL to the API-key state and notifies its own observers.
    pub fn new(
        http_client: Arc<dyn HttpClient>,
        credentials_provider: Arc<dyn CredentialsProvider>,
        cx: &mut App,
    ) -> Self {
        let state = cx.new(|cx| {
            let settings_subscription =
                cx.observe_global::<SettingsStore>(|state: &mut State, cx| {
                    let url = Self::api_url(cx);
                    let provider = state.credentials_provider.clone();
                    state.api_key_state.handle_url_change(
                        url,
                        |state| &mut state.api_key_state,
                        provider,
                        cx,
                    );
                    cx.notify();
                });
            // Keep the observer alive for as long as the entity exists.
            settings_subscription.detach();

            let initial_url = Self::api_url(cx);
            let env_var = (*API_KEY_ENV_VAR).clone();
            State {
                api_key_state: ApiKeyState::new(initial_url, env_var),
                credentials_provider,
            }
        });

        Self { http_client, state }
    }

    /// Wraps `model` in an [`XAiLanguageModel`] that shares this provider's state and HTTP
    /// client and gets its own rate limiter (constructed with a limit of 4).
    fn create_language_model(&self, model: x_ai::Model) -> Arc<dyn LanguageModel> {
        // The id must be derived before `model` is moved into the struct.
        let id = LanguageModelId::from(model.id().to_string());
        let language_model = XAiLanguageModel {
            id,
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        };
        Arc::new(language_model)
    }

    /// Returns the xAI section of the global language model settings.
    fn settings(cx: &App) -> &XAiSettings {
        let all_settings = crate::AllLanguageModelSettings::get_global(cx);
        &all_settings.x_ai
    }

    /// Returns the API URL from settings, falling back to `XAI_API_URL` when it is empty.
    fn api_url(cx: &App) -> SharedString {
        match Self::settings(cx).api_url.as_str() {
            "" => XAI_API_URL.into(),
            configured => SharedString::new(configured),
        }
    }
}
impl LanguageModelProviderState for XAiLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the shared [`State`] entity so callers can observe it.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        let entity = self.state.clone();
        Some(entity)
    }
}
impl LanguageModelProvider for XAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiXAi)
    }

    /// Always offers the model `x_ai::Model::default()`.
    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let model = x_ai::Model::default();
        Some(self.create_language_model(model))
    }

    /// Always offers the model `x_ai::Model::default_fast()`.
    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let model = x_ai::Model::default_fast();
        Some(self.create_language_model(model))
    }

    /// Lists every non-custom built-in model plus the models declared in settings.
    ///
    /// Models are keyed by id (built-ins) or name (settings). Settings entries are inserted
    /// last, so they replace a built-in that uses the same key.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models_by_key = BTreeMap::default();

        let built_in_models =
            x_ai::Model::iter().filter(|model| !matches!(model, x_ai::Model::Custom { .. }));
        for built_in in built_in_models {
            models_by_key.insert(built_in.id().to_string(), built_in);
        }

        for available in &Self::settings(cx).available_models {
            let custom = x_ai::Model::Custom {
                name: available.name.clone(),
                display_name: available.display_name.clone(),
                max_tokens: available.max_tokens,
                max_output_tokens: available.max_output_tokens,
                max_completion_tokens: available.max_completion_tokens,
                supports_images: available.supports_images,
                supports_tools: available.supports_tools,
                parallel_tool_calls: available.parallel_tool_calls,
            };
            models_by_key.insert(available.name.clone(), custom);
        }

        models_by_key
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        state.is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    /// Builds the settings UI ([`ConfigurationView`]) backed by this provider's state.
    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        let state = self.state.clone();
        let view = cx.new(|cx| ConfigurationView::new(state, window, cx));
        view.into()
    }

    /// Clears the API key by storing `None` through the shared state.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}
/// A single xAI model exposed through the `LanguageModel` trait.
pub struct XAiLanguageModel {
/// Id derived from `model.id()` when the provider creates this value.
id: LanguageModelId,
/// The underlying model description; capability and limit queries delegate to it.
model: x_ai::Model,
/// Shared provider state, read for the API key when a completion is requested.
state: Entity<State>,
/// HTTP client used to issue completion requests.
http_client: Arc<dyn HttpClient>,
/// Completion requests are routed through this limiter's `stream` method.
request_limiter: RateLimiter,
}
impl XAiLanguageModel {
    /// Sends an already-converted OpenAI-style `request` to the configured xAI endpoint.
    ///
    /// The API key and URL are read from the shared state up front. The request itself runs
    /// through the rate limiter and fails with `NoApiKey` when no key is available for the URL.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let client = self.http_client.clone();
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let url = XAiLanguageModelProvider::api_url(cx);
            let key = state.api_key_state.key(&url);
            (key, url)
        });

        let limited = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            let api_key = match api_key {
                Some(key) => key,
                None => return Err(LanguageModelCompletionError::NoApiKey { provider }),
            };
            let pending = open_ai::stream_completion(
                client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            Ok(pending.await?)
        });

        async move {
            let events = limited.await?;
            Ok(events.boxed())
        }
        .boxed()
    }
}
impl LanguageModel for XAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        let display_name = self.model.display_name().to_string();
        LanguageModelName::from(display_name)
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tool()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn supports_streaming_tools(&self) -> bool {
        true
    }

    /// Every tool-choice mode is accepted.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        // Kept as an exhaustive match so that a newly added variant forces a decision here.
        match choice {
            LanguageModelToolChoice::Auto => true,
            LanguageModelToolChoice::Any => true,
            LanguageModelToolChoice::None => true,
        }
    }

    /// Picks the JSON-schema subset format when the model requires it, full JSON schema otherwise.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        if self.model.requires_json_schema_subset() {
            return LanguageModelToolSchemaFormat::JsonSchemaSubset;
        }
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn telemetry_id(&self) -> String {
        let model_id = self.model.id();
        format!("x_ai/{}", model_id)
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn supports_split_token_display(&self) -> bool {
        true
    }

    /// Converts `request` into an OpenAI-style request, streams it through the inherent
    /// `stream_completion`, and maps the raw events with `OpenAiEventMapper`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let model_id = self.model.id();
        let parallel_tool_calls = self.model.supports_parallel_tool_calls();
        let prompt_cache_key = self.model.supports_prompt_cache_key();
        let max_output_tokens = self.max_output_tokens();
        let open_ai_request = crate::provider::open_ai::into_open_ai(
            request,
            model_id,
            parallel_tool_calls,
            prompt_cache_key,
            max_output_tokens,
            None,
            false,
        );

        // Resolves to the inherent method, which takes an `open_ai::Request`.
        let raw_events = self.stream_completion(open_ai_request, cx);
        async move {
            let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
            let events = raw_events.await?;
            Ok(mapper.map_stream(events).boxed())
        }
        .boxed()
    }
}
/// Settings UI for the xAI provider: an API-key editor or a "configured" card.
struct ConfigurationView {
/// Input field in which the user types the API key.
api_key_editor: Entity<InputField>,
/// Shared provider state that keys are stored to and read from.
state: Entity<State>,
/// `Some` while the initial credential load is in flight; set back to `None` when it ends.
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
    /// Creates the view and immediately kicks off a credential load.
    ///
    /// The view re-renders whenever `state` notifies. While the spawned load is running,
    /// `load_credentials_task` is `Some`; the task clears it and notifies once it finishes.
    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            InputField::new(
                window,
                cx,
                "xai-0000000000000000000000000000000000000000000000000",
            )
            .label("API key")
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                // We don't log an error, because "not signed in" is also an error.
                let _ = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .await;

                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    /// Handles `menu::Confirm`: stores the trimmed editor text as the API key.
    ///
    /// Does nothing when the trimmed text is empty.
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
                .await
        })
        .detach_and_log_err(cx);
    }

    /// Clears the editor text and removes the stored API key.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |input, cx| input.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))
                .await
        })
        .detach_and_log_err(cx);
    }

    /// The key editor is shown only while no API key is available.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
impl Render for ConfigurationView {
    /// Renders a loading label while credentials are being loaded; otherwise either the
    /// API-key entry form (no key available) or a card describing the configured key.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();

        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = XAiLanguageModelProvider::api_url(cx);
            let uses_default_url = api_url == XAI_API_URL;
            if uses_default_url {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        let show_editor = self.should_render_editor(cx);
        let api_key_section = if show_editor {
            let intro = Label::new(
                "To use Zed's agent with xAI, you need to add an API key. Follow these steps:",
            );

            let create_key_step = ListBulletItem::new("")
                .child(Label::new("Create one by visiting"))
                .child(ButtonLink::new(
                    "xAI console",
                    "https://console.x.ai/team/default/api-keys",
                ));
            let paste_key_step = ListBulletItem::new(
                "Paste your API key below and hit enter to start using the agent",
            );
            let steps = List::new().child(create_key_step).child(paste_key_step);

            let env_var_hint = Label::new(format!(
                "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
            ))
            .size(LabelSize::Small)
            .color(Color::Muted);

            let provider_note = Label::new("Note that xAI is a custom OpenAI-compatible provider.")
                .size(LabelSize::Small)
                .color(Color::Muted);

            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(intro)
                .child(steps)
                .child(self.api_key_editor.clone())
                .child(env_var_hint)
                .child(provider_note)
                .into_any_element()
        } else {
            // A key coming from the environment cannot be reset from the UI, so the card is
            // disabled and explains how to remove it instead.
            let reset_on_click =
                cx.listener(|this, _, window, cx| this.reset_api_key(window, cx));

            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .on_click(reset_on_click)
                .into_any_element()
        };

        let still_loading = self.load_credentials_task.is_some();
        if still_loading {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex().size_full().child(api_key_section).into_any()
        }
    }
}