use anyhow::Result; use collections::BTreeMap; use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, TaskExt, Window}; use http_client::HttpClient; use language_model::{ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, env_var, }; use open_ai::ResponseStreamEvent; pub use settings::XaiAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; use x_ai::XAI_API_URL; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY"; static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default, Clone, Debug, PartialEq)] pub struct XAiSettings { pub api_url: String, pub available_models: Vec, } pub struct XAiLanguageModelProvider { http_client: Arc, state: Entity, } pub struct State { api_key_state: ApiKeyState, credentials_provider: Arc, } impl State { fn is_authenticated(&self) -> bool { self.api_key_state.has_key() } fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { let credentials_provider = self.credentials_provider.clone(); let api_url = XAiLanguageModelProvider::api_url(cx); self.api_key_state.store( api_url, api_key, |this| &mut this.api_key_state, credentials_provider, cx, ) } fn authenticate(&mut self, cx: &mut Context) -> Task> { let credentials_provider = self.credentials_provider.clone(); let api_url = XAiLanguageModelProvider::api_url(cx); self.api_key_state.load_if_needed( api_url, |this| &mut this.api_key_state, credentials_provider, cx, ) } } impl XAiLanguageModelProvider { pub fn new( http_client: Arc, credentials_provider: Arc, cx: &mut App, ) -> Self { let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { let credentials_provider = this.credentials_provider.clone(); let api_url = Self::api_url(cx); this.api_key_state.handle_url_change( api_url, |this| &mut this.api_key_state, credentials_provider, cx, ); cx.notify(); }) .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()), credentials_provider, } }); Self { http_client, state } } fn create_language_model(&self, model: x_ai::Model) -> Arc { Arc::new(XAiLanguageModel { id: LanguageModelId::from(model.id().to_string()), model, state: self.state.clone(), http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), }) } fn settings(cx: &App) -> &XAiSettings { &crate::AllLanguageModelSettings::get_global(cx).x_ai } fn api_url(cx: &App) -> SharedString { let api_url = &Self::settings(cx).api_url; if api_url.is_empty() { XAI_API_URL.into() } else { SharedString::new(api_url.as_str()) } } } impl LanguageModelProviderState for XAiLanguageModelProvider { type ObservableEntity = State; fn observable_entity(&self) -> Option> { Some(self.state.clone()) } } impl LanguageModelProvider for XAiLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { PROVIDER_NAME } fn icon(&self) -> IconOrSvg { IconOrSvg::Icon(IconName::AiXAi) } fn default_model(&self, _cx: &App) -> Option> { Some(self.create_language_model(x_ai::Model::default())) } fn default_fast_model(&self, _cx: &App) -> Option> { Some(self.create_language_model(x_ai::Model::default_fast())) } fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); for model in x_ai::Model::iter() { if !matches!(model, x_ai::Model::Custom { .. }) { models.insert(model.id().to_string(), model); } } for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), x_ai::Model::Custom { name: model.name.clone(), display_name: model.display_name.clone(), max_tokens: model.max_tokens, max_output_tokens: model.max_output_tokens, max_completion_tokens: model.max_completion_tokens, supports_images: model.supports_images, supports_tools: model.supports_tools, parallel_tool_calls: model.parallel_tool_calls, }, ); } models .into_values() .map(|model| self.create_language_model(model)) .collect() } fn is_authenticated(&self, cx: &App) -> bool { self.state.read(cx).is_authenticated() } fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } fn configuration_view( &self, _target_agent: language_model::ConfigurationViewTargetAgent, window: &mut Window, cx: &mut App, ) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } fn reset_credentials(&self, cx: &mut App) -> Task> { self.state .update(cx, |state, cx| state.set_api_key(None, cx)) } } pub struct XAiLanguageModel { id: LanguageModelId, model: x_ai::Model, state: Entity, http_client: Arc, request_limiter: RateLimiter, } impl XAiLanguageModel { fn stream_completion( &self, request: open_ai::Request, cx: &AsyncApp, ) -> BoxFuture< 'static, Result< futures::stream::BoxStream<'static, Result>, LanguageModelCompletionError, >, > { let http_client = self.http_client.clone(); let (api_key, api_url) = self.state.read_with(cx, |state, cx| { let api_url = XAiLanguageModelProvider::api_url(cx); (state.api_key_state.key(&api_url), api_url) }); let future = self.request_limiter.stream(async move { let provider = PROVIDER_NAME; let Some(api_key) = api_key else { return Err(LanguageModelCompletionError::NoApiKey { provider }); }; let request = open_ai::stream_completion( http_client.as_ref(), provider.0.as_str(), &api_url, &api_key, request, ); let response = request.await?; Ok(response) }); async move { Ok(future.await?.boxed()) }.boxed() } } fn x_ai_reasoning_efforts(model: &x_ai::Model) -> &'static [open_ai::ReasoningEffort] { if model.supports_reasoning_effort() { &[ open_ai::ReasoningEffort::None, open_ai::ReasoningEffort::Low, open_ai::ReasoningEffort::Medium, open_ai::ReasoningEffort::High, ] } else { &[] } } fn default_thinking_reasoning_effort(model: &x_ai::Model) -> Option { if model.supports_reasoning_effort() { Some(open_ai::ReasoningEffort::Low) } else { None } } fn reasoning_effort_for_request( request: &LanguageModelRequest, model: &x_ai::Model, ) -> Option { let supported_efforts = x_ai_reasoning_efforts(model); if supported_efforts.is_empty() { return None; } if request.thinking_allowed { request .thinking_effort .as_deref() .and_then(|effort| effort.parse::().ok()) .filter(|effort| supported_efforts.contains(effort)) .filter(|effort| *effort != open_ai::ReasoningEffort::None) .or_else(|| default_thinking_reasoning_effort(model)) } else if supported_efforts.contains(&open_ai::ReasoningEffort::None) { Some(open_ai::ReasoningEffort::None) } else { None } } fn supported_thinking_effort_levels(model: &x_ai::Model) -> Vec { let default_effort = default_thinking_reasoning_effort(model); x_ai_reasoning_efforts(model) .iter() .copied() .filter_map(|effort| { let (name, value) = match effort { open_ai::ReasoningEffort::None => return None, open_ai::ReasoningEffort::Minimal => ("Minimal", "minimal"), open_ai::ReasoningEffort::Low => ("Low", "low"), open_ai::ReasoningEffort::Medium => ("Medium", "medium"), open_ai::ReasoningEffort::High => ("High", "high"), open_ai::ReasoningEffort::XHigh => ("Extra High", "xhigh"), }; Some(LanguageModelEffortLevel { name: name.into(), value: value.into(), is_default: Some(effort) == default_effort, }) }) .collect() } impl LanguageModel for XAiLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() } fn name(&self) -> LanguageModelName { LanguageModelName::from(self.model.display_name().to_string()) } fn provider_id(&self) -> LanguageModelProviderId { PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { PROVIDER_NAME } fn supports_tools(&self) -> bool { self.model.supports_tool() } fn supports_images(&self) -> bool { self.model.supports_images() } fn supports_streaming_tools(&self) -> bool { true } fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any | LanguageModelToolChoice::None => true, } } fn supports_thinking(&self) -> bool { self.model.supports_reasoning_effort() } fn supported_effort_levels(&self) -> Vec { supported_thinking_effort_levels(&self.model) } fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { if self.model.requires_json_schema_subset() { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema } } fn telemetry_id(&self) -> String { format!("x_ai/{}", self.model.id()) } fn max_token_count(&self) -> u64 { self.model.max_token_count() } fn max_output_tokens(&self) -> Option { self.model.max_output_tokens() } fn supports_split_token_display(&self) -> bool { true } fn stream_completion( &self, request: LanguageModelRequest, cx: &AsyncApp, ) -> BoxFuture< 'static, Result< futures::stream::BoxStream< 'static, Result, >, LanguageModelCompletionError, >, > { let reasoning_effort = reasoning_effort_for_request(&request, &self.model); let request = crate::provider::open_ai::into_open_ai( request, self.model.id(), self.model.supports_parallel_tool_calls(), self.model.supports_prompt_cache_key(), self.max_output_tokens(), reasoning_effort, false, ); let completions = self.stream_completion(request, cx); async move { let mapper = crate::provider::open_ai::OpenAiEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) } .boxed() } } struct ConfigurationView { api_key_editor: Entity, state: Entity, load_credentials_task: Option>, } impl ConfigurationView { fn new(state: Entity, window: &mut Window, cx: &mut Context) -> Self { let api_key_editor = cx.new(|cx| { InputField::new( window, cx, "xai-0000000000000000000000000000000000000000000000000", ) .label("API key") }); cx.observe(&state, |_, _, cx| { cx.notify(); }) .detach(); let load_credentials_task = Some(cx.spawn_in(window, { let state = state.clone(); async move |this, cx| { if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) { // We don't log an error, because "not signed in" is also an error. let _ = task.await; } this.update(cx, |this, cx| { this.load_credentials_task = None; cx.notify(); }) .log_err(); } })); Self { api_key_editor, state, load_credentials_task, } } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } // url changes can cause the editor to be displayed again self.api_key_editor .update(cx, |editor, cx| editor.set_text("", window, cx)); let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state .update(cx, |state, cx| state.set_api_key(Some(api_key), cx)) .await }) .detach_and_log_err(cx); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { self.api_key_editor .update(cx, |input, cx| input.set_text("", window, cx)); let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state .update(cx, |state, cx| state.set_api_key(None, cx)) .await }) .detach_and_log_err(cx); } fn should_render_editor(&self, cx: &mut Context) -> bool { !self.state.read(cx).is_authenticated() } } #[cfg(test)] mod tests { use super::*; #[test] fn grok_43_supports_selectable_thinking_effort_levels() { let effort_levels = supported_thinking_effort_levels(&x_ai::Model::Grok43); let values = effort_levels .iter() .map(|level| level.value.as_ref()) .collect::>(); assert_eq!(values, ["low", "medium", "high"]); assert_eq!( effort_levels .iter() .find(|level| level.is_default) .map(|level| level.value.as_ref()), Some("low") ); } #[test] fn grok_43_request_uses_selected_reasoning_effort() { let request = LanguageModelRequest { thinking_allowed: true, thinking_effort: Some("high".to_string()), ..Default::default() }; assert_eq!( reasoning_effort_for_request(&request, &x_ai::Model::Grok43), Some(open_ai::ReasoningEffort::High) ); } #[test] fn grok_43_request_uses_none_when_thinking_is_disabled() { let request = LanguageModelRequest { thinking_allowed: false, ..Default::default() }; assert_eq!( reasoning_effort_for_request(&request, &x_ai::Model::Grok43), Some(open_ai::ReasoningEffort::None) ); } } impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let configured_card_label = if env_var_set { format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { let api_url = XAiLanguageModelProvider::api_url(cx); if api_url == XAI_API_URL { "API key configured".to_string() } else { format!("API key configured for {}", api_url) } }; let api_key_section = if self.should_render_editor(cx) { v_flex() .on_action(cx.listener(Self::save_api_key)) .child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:")) .child( List::new() .child( ListBulletItem::new("") .child(Label::new("Create one by visiting")) .child(ButtonLink::new("xAI console", "https://console.x.ai/team/default/api-keys")) ) .child( ListBulletItem::new("Paste your API key below and hit enter to start using the agent") ), ) .child(self.api_key_editor.clone()) .child( Label::new(format!( "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." )) .size(LabelSize::Small) .color(Color::Muted), ) .child( Label::new("Note that xAI is a custom OpenAI-compatible provider.") .size(LabelSize::Small) .color(Color::Muted), ) .into_any_element() } else { ConfiguredApiCard::new(configured_card_label) .disabled(env_var_set) .when(env_var_set, |this| { this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))) .into_any_element() }; if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials…")).into_any() } else { v_flex().size_full().child(api_key_section).into_any() } } }