language_models: Refactor deps and extract cloud (#53270)

- `language_model` no longer depends on provider-specific crates such as
`anthropic` and `open_ai` (inverted dependency)
- `language_model_core` was extracted from `language_model`; it contains the
types that the provider-specific crates convert to/from.
- `gpui::SharedString` has been extracted into its own crate (still
exposed by `gpui`), so `language_model_core` and provider API crates
don't have to depend on `gpui`.
- Removes some unnecessary `&'static str` | `SharedString` -> `String`
-> `SharedString` conversions across the codebase.
- Extracts the core logic of the cloud `LanguageModelProvider` into its
own crate with simpler dependencies.


Release Notes:

- N/A

---------

Co-authored-by: John Tur <john-tur@outlook.com>
This commit is contained in:
Agus Zubiaga 2026-04-07 12:28:19 -03:00 committed by GitHub
parent a856093cca
commit 98c17ca160
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
95 changed files with 5895 additions and 5995 deletions

89
Cargo.lock generated
View file

@ -629,13 +629,17 @@ version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"collections",
"futures 0.3.32",
"http_client",
"language_model_core",
"log",
"schemars",
"serde",
"serde_json",
"strum 0.27.2",
"thiserror 2.0.17",
"tiktoken-rs",
]
[[package]]
@ -2903,7 +2907,6 @@ dependencies = [
"http_client",
"http_client_tls",
"httparse",
"language_model",
"log",
"objc2-foundation",
"parking_lot",
@ -2959,6 +2962,7 @@ dependencies = [
"http_client",
"parking_lot",
"serde_json",
"smol",
"thiserror 2.0.17",
"yawc",
]
@ -5162,6 +5166,7 @@ dependencies = [
"buffer_diff",
"client",
"clock",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
@ -5641,7 +5646,7 @@ dependencies = [
name = "env_var"
version = "0.1.0"
dependencies = [
"gpui",
"gpui_shared_string",
]
[[package]]
@ -7468,11 +7473,13 @@ dependencies = [
"anyhow",
"futures 0.3.32",
"http_client",
"language_model_core",
"log",
"schemars",
"serde",
"serde_json",
"settings",
"strum 0.27.2",
"tiktoken-rs",
]
[[package]]
@ -7541,6 +7548,7 @@ dependencies = [
"getrandom 0.3.4",
"gpui_macros",
"gpui_platform",
"gpui_shared_string",
"gpui_util",
"gpui_web",
"http_client",
@ -7710,6 +7718,16 @@ dependencies = [
"gpui_windows",
]
[[package]]
name = "gpui_shared_string"
version = "0.1.0"
dependencies = [
"derive_more",
"gpui_util",
"schemars",
"serde",
]
[[package]]
name = "gpui_tokio"
version = "0.1.0"
@ -9358,7 +9376,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"gpui",
"gpui_shared_string",
"log",
"lsp",
"parking_lot",
@ -9397,12 +9415,8 @@ dependencies = [
name = "language_model"
version = "0.1.0"
dependencies = [
"anthropic",
"anyhow",
"base64 0.22.1",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
"env_var",
@ -9411,16 +9425,31 @@ dependencies = [
"http_client",
"icons",
"image",
"language_model_core",
"log",
"open_ai",
"open_router",
"parking_lot",
"serde",
"serde_json",
"thiserror 2.0.17",
"util",
]
[[package]]
name = "language_model_core"
version = "0.1.0"
dependencies = [
"anyhow",
"cloud_llm_client",
"futures 0.3.32",
"gpui_shared_string",
"http_client",
"partial-json-fixer",
"schemars",
"serde",
"serde_json",
"smol",
"strum 0.27.2",
"thiserror 2.0.17",
"util",
]
[[package]]
@ -9436,8 +9465,8 @@ dependencies = [
"base64 0.22.1",
"bedrock",
"client",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
"component",
"convert_case 0.8.0",
@ -9456,6 +9485,7 @@ dependencies = [
"http_client",
"language",
"language_model",
"language_models_cloud",
"lmstudio",
"log",
"menu",
@ -9464,17 +9494,14 @@ dependencies = [
"open_ai",
"open_router",
"opencode",
"partial-json-fixer",
"pretty_assertions",
"release_channel",
"schemars",
"semver",
"serde",
"serde_json",
"settings",
"smol",
"strum 0.27.2",
"thiserror 2.0.17",
"tiktoken-rs",
"tokio",
"ui",
@ -9484,6 +9511,28 @@ dependencies = [
"x_ai",
]
[[package]]
name = "language_models_cloud"
version = "0.1.0"
dependencies = [
"anthropic",
"anyhow",
"cloud_llm_client",
"futures 0.3.32",
"google_ai",
"gpui",
"http_client",
"language_model",
"open_ai",
"schemars",
"semver",
"serde",
"serde_json",
"smol",
"thiserror 2.0.17",
"x_ai",
]
[[package]]
name = "language_onboarding"
version = "0.1.0"
@ -11631,16 +11680,19 @@ name = "open_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"futures 0.3.32",
"http_client",
"language_model_core",
"log",
"pretty_assertions",
"rand 0.9.2",
"schemars",
"serde",
"serde_json",
"settings",
"strum 0.27.2",
"thiserror 2.0.17",
"tiktoken-rs",
]
[[package]]
@ -11672,6 +11724,7 @@ dependencies = [
"anyhow",
"futures 0.3.32",
"http_client",
"language_model_core",
"schemars",
"serde",
"serde_json",
@ -15801,6 +15854,7 @@ dependencies = [
"collections",
"derive_more",
"gpui",
"language_model_core",
"log",
"schemars",
"serde",
@ -20180,6 +20234,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"futures 0.3.32",
@ -21783,9 +21838,11 @@ name = "x_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"language_model_core",
"schemars",
"serde",
"strum 0.27.2",
"tiktoken-rs",
]
[[package]]

View file

@ -87,6 +87,7 @@ members = [
"crates/google_ai",
"crates/grammars",
"crates/gpui",
"crates/gpui_shared_string",
"crates/gpui_linux",
"crates/gpui_macos",
"crates/gpui_macros",
@ -110,7 +111,9 @@ members = [
"crates/language_core",
"crates/language_extension",
"crates/language_model",
"crates/language_model_core",
"crates/language_models",
"crates/language_models_cloud",
"crates/language_onboarding",
"crates/language_selector",
"crates/language_tools",
@ -335,6 +338,7 @@ go_to_line = { path = "crates/go_to_line" }
google_ai = { path = "crates/google_ai" }
grammars = { path = "crates/grammars" }
gpui = { path = "crates/gpui", default-features = false }
gpui_shared_string = { path = "crates/gpui_shared_string" }
gpui_linux = { path = "crates/gpui_linux", default-features = false }
gpui_macos = { path = "crates/gpui_macos", default-features = false }
gpui_macros = { path = "crates/gpui_macros" }
@ -361,7 +365,9 @@ language = { path = "crates/language" }
language_core = { path = "crates/language_core" }
language_extension = { path = "crates/language_extension" }
language_model = { path = "crates/language_model" }
language_model_core = { path = "crates/language_model_core" }
language_models = { path = "crates/language_models" }
language_models_cloud = { path = "crates/language_models_cloud" }
language_onboarding = { path = "crates/language_onboarding" }
language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" }

View file

@ -5,7 +5,7 @@ use futures::FutureExt as _;
use gpui::{App, Entity, SharedString, Task};
use indoc::formatdoc;
use language::Point;
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
use language_model::{LanguageModelImage, LanguageModelImageExt, LanguageModelToolResultContent};
use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

View file

@ -325,7 +325,7 @@ impl AcpConnection {
// Use the one the agent provides if we have one
.map(|info| info.name.into())
// Otherwise, just use the name
.unwrap_or_else(|| agent_id.0.to_string().into());
.unwrap_or_else(|| agent_id.0.clone());
let session_list = if response
.agent_capabilities

View file

@ -382,7 +382,7 @@ impl AgentRegistryPage {
self.install_button(agent, install_status, supports_current_platform, cx);
let repository_button = agent.repository().map(|repository| {
let repository_for_tooltip: SharedString = repository.to_string().into();
let repository_for_tooltip = repository.clone();
let repository_for_click = repository.to_string();
IconButton::new(

View file

@ -18,7 +18,7 @@ use gpui::{
use http_client::{AsyncBody, HttpClientWithUrl};
use itertools::Either;
use language::Buffer;
use language_model::LanguageModelImage;
use language_model::{LanguageModelImage, LanguageModelImageExt};
use multi_buffer::MultiBufferRow;
use postage::stream::Stream as _;
use project::{Project, ProjectItem, ProjectPath, Worktree};

View file

@ -18,12 +18,16 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
collections.workspace = true
futures.workspace = true
http_client.workspace = true
language_model_core.workspace = true
log.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
tiktoken-rs.workspace = true

View file

@ -12,6 +12,7 @@ use strum::{EnumIter, EnumString};
use thiserror::Error;
pub mod batches;
pub mod completion;
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
@ -1026,6 +1027,89 @@ pub async fn count_tokens(
}
}
// -- Conversions from/to `language_model_core` types --
impl From<language_model_core::Speed> for Speed {
fn from(speed: language_model_core::Speed) -> Self {
match speed {
language_model_core::Speed::Standard => Speed::Standard,
language_model_core::Speed::Fast => Speed::Fast,
}
}
}
/// Converts an [`AnthropicError`] into the provider-agnostic completion error.
///
/// Every variant except `ApiError` is tagged with the Anthropic provider name;
/// `ApiError` is delegated to the `From<ApiError>` conversion.
impl From<AnthropicError> for language_model_core::LanguageModelCompletionError {
    fn from(anthropic_error: AnthropicError) -> Self {
        let provider = language_model_core::ANTHROPIC_PROVIDER_NAME;

        match anthropic_error {
            // Request construction and transport failures.
            AnthropicError::SerializeRequest(source) => Self::SerializeRequest {
                provider,
                error: source,
            },
            AnthropicError::BuildRequestBody(source) => Self::BuildRequestBody {
                provider,
                error: source,
            },
            AnthropicError::HttpSend(source) => Self::HttpSend {
                provider,
                error: source,
            },

            // Failures while consuming the response.
            AnthropicError::DeserializeResponse(source) => Self::DeserializeResponse {
                provider,
                error: source,
            },
            AnthropicError::ReadResponse(source) => Self::ApiReadResponseError {
                provider,
                error: source,
            },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },

            // Throttling: a rate limit always carries a retry delay here,
            // while the overloaded variant forwards whatever it was given.
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },

            // Structured API errors are mapped by their error code.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
impl From<ApiError> for language_model_core::LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
use ApiErrorCode::*;
let provider = language_model_core::ANTHROPIC_PROVIDER_NAME;
match error.code() {
Some(code) => match code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
NotFoundError => Self::ApiEndpointNotFound { provider },
RequestTooLarge => Self::PromptTooLarge {
tokens: language_model_core::parse_prompt_too_long(&error.message),
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
},
None => Self::Other(error.into()),
}
}
}
#[test]
fn test_match_window_exceeded() {
let error = ApiError {

View file

@ -0,0 +1,765 @@
use anyhow::Result;
use collections::HashMap;
use futures::{Stream, StreamExt};
use language_model_core::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
Role, StopReason, TokenUsage,
util::{fix_streamed_json, parse_tool_arguments},
};
use std::pin::Pin;
use std::str::FromStr;
use crate::{
AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta,
CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent,
StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage,
};
fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
match content {
MessageContent::Text(text) => {
let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
text.trim_end().to_string()
} else {
text
};
if !text.is_empty() {
Some(RequestContent::Text {
text,
cache_control: None,
})
} else {
None
}
}
MessageContent::Thinking {
text: thinking,
signature,
} => {
if let Some(signature) = signature
&& !thinking.is_empty()
{
Some(RequestContent::Thinking {
thinking,
signature,
cache_control: None,
})
} else {
None
}
}
MessageContent::RedactedThinking(data) => {
if !data.is_empty() {
Some(RequestContent::RedactedThinking { data })
} else {
None
}
}
MessageContent::Image(image) => Some(RequestContent::Image {
source: ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
cache_control: None,
}),
MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse {
id: tool_use.id.to_string(),
name: tool_use.name.to_string(),
input: tool_use.input,
cache_control: None,
}),
MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult {
tool_use_id: tool_result.tool_use_id.to_string(),
is_error: tool_result.is_error,
content: match tool_result.content {
LanguageModelToolResultContent::Text(text) => {
ToolResultContent::Plain(text.to_string())
}
LanguageModelToolResultContent::Image(image) => {
ToolResultContent::Multipart(vec![ToolResultPart::Image {
source: ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
}])
}
},
cache_control: None,
}),
}
}
/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
///
/// System messages are joined (separated by a blank line) into the `system`
/// field. Consecutive user/assistant messages with the same role are merged
/// into a single Anthropic message, and messages that end up with no content
/// are dropped. Thinking is only configured when the request allows it.
pub fn into_anthropic_count_tokens_request(
    request: LanguageModelRequest,
    model: String,
    mode: AnthropicModelMode,
) -> CountTokensRequest {
    let mut messages: Vec<Message> = Vec::new();
    let mut system_prompt = String::new();

    for message in request.messages {
        if message.contents_empty() {
            continue;
        }

        // System messages feed the system prompt; everything else becomes a
        // chat message with the matching Anthropic role.
        let role = match message.role {
            Role::System => {
                if !system_prompt.is_empty() {
                    system_prompt.push_str("\n\n");
                }
                system_prompt.push_str(&message.string_contents());
                continue;
            }
            Role::User => crate::Role::User,
            Role::Assistant => crate::Role::Assistant,
        };

        let content: Vec<RequestContent> = message
            .content
            .into_iter()
            .filter_map(to_anthropic_content)
            .collect();
        if content.is_empty() {
            continue;
        }

        // Fold into the previous message when the role repeats.
        if let Some(previous) = messages.last_mut()
            && previous.role == role
        {
            previous.content.extend(content);
            continue;
        }

        messages.push(Message { role, content });
    }

    let system = (!system_prompt.is_empty()).then(|| StringOrContents::String(system_prompt));

    let thinking = match (request.thinking_allowed, mode) {
        (false, _) | (true, AnthropicModelMode::Default) => None,
        (true, AnthropicModelMode::Thinking { budget_tokens }) => {
            Some(Thinking::Enabled { budget_tokens })
        }
        (true, AnthropicModelMode::AdaptiveThinking) => Some(Thinking::Adaptive),
    };

    CountTokensRequest {
        model,
        messages,
        system,
        thinking,
        tools: request
            .tools
            .into_iter()
            .map(|tool| Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
                eager_input_streaming: tool.use_input_streaming,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => ToolChoice::Auto,
            LanguageModelToolChoice::Any => ToolChoice::Any,
            LanguageModelToolChoice::None => ToolChoice::None,
        }),
    }
}
/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
///
/// Text and textual tool results are counted with the GPT-4 tokenizer; images
/// contribute their own `estimate_tokens()` value. Thinking blocks and tool
/// uses contribute nothing.
pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
    let messages = request.messages;
    let mut image_tokens = 0;
    let mut chat_messages = Vec::with_capacity(messages.len());

    for message in messages {
        let mut text_buffer = String::new();

        for content in message.content {
            match content {
                MessageContent::Text(text) => text_buffer.push_str(&text),
                // Thinking blocks are not included in the input token count.
                MessageContent::Thinking { .. } | MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => image_tokens += image.estimate_tokens(),
                // TODO: Estimate token usage from tool uses.
                MessageContent::ToolUse(_) => {}
                MessageContent::ToolResult(tool_result) => match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => text_buffer.push_str(text),
                    LanguageModelToolResultContent::Image(image) => {
                        image_tokens += image.estimate_tokens()
                    }
                },
            }
        }

        // Messages without any text are not passed to the tokenizer.
        if text_buffer.is_empty() {
            continue;
        }

        let role = match message.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        chat_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
            role: role.into(),
            content: Some(text_buffer),
            name: None,
            function_call: None,
        });
    }

    // Tiktoken doesn't yet support these models, so we manually use the
    // same tokenizer as GPT-4.
    let text_tokens = tiktoken_rs::num_tokens_from_messages("gpt-4", &chat_messages)?;
    Ok((text_tokens + image_tokens) as u64)
}
/// Converts a `LanguageModelRequest` into an Anthropic completion request.
///
/// System messages are joined (separated by a blank line) into the `system`
/// field. Consecutive user/assistant messages with the same role are merged,
/// and messages that end up with no content are dropped. When a message is
/// flagged `cache`, its last cacheable content segment is marked ephemeral.
/// `default_temperature` is used only when the request carries no temperature.
pub fn into_anthropic(
    request: LanguageModelRequest,
    model: String,
    default_temperature: f32,
    max_output_tokens: u64,
    mode: AnthropicModelMode,
) -> crate::Request {
    let mut messages: Vec<Message> = Vec::new();
    let mut system_prompt = String::new();

    for message in request.messages {
        if message.contents_empty() {
            continue;
        }

        // System messages feed the system prompt; everything else becomes a
        // chat message with the matching Anthropic role.
        let role = match message.role {
            Role::System => {
                if !system_prompt.is_empty() {
                    system_prompt.push_str("\n\n");
                }
                system_prompt.push_str(&message.string_contents());
                continue;
            }
            Role::User => crate::Role::User,
            Role::Assistant => crate::Role::Assistant,
        };

        let mut content: Vec<RequestContent> = message
            .content
            .into_iter()
            .filter_map(to_anthropic_content)
            .collect();
        if content.is_empty() {
            continue;
        }

        // Fold into the previous message when the role repeats. Merged content
        // is appended as-is, without any cache marking.
        if let Some(previous) = messages.last_mut()
            && previous.role == role
        {
            previous.content.extend(content);
            continue;
        }

        // Mark the last segment of the message as cached
        if message.cache {
            let last_cacheable = content.iter_mut().rev().find_map(|segment| match segment {
                // Redacted thinking cannot carry cache control, so keep
                // walking backwards to the preceding segment.
                RequestContent::RedactedThinking { .. } => None,
                RequestContent::Text { cache_control, .. }
                | RequestContent::Thinking { cache_control, .. }
                | RequestContent::Image { cache_control, .. }
                | RequestContent::ToolUse { cache_control, .. }
                | RequestContent::ToolResult { cache_control, .. } => Some(cache_control),
            });
            if let Some(cache_control) = last_cacheable {
                *cache_control = Some(CacheControl {
                    cache_type: CacheControlType::Ephemeral,
                });
            }
        }

        messages.push(Message { role, content });
    }

    let system = (!system_prompt.is_empty()).then(|| StringOrContents::String(system_prompt));

    let thinking = if request.thinking_allowed {
        match mode {
            AnthropicModelMode::Thinking { budget_tokens } => {
                Some(Thinking::Enabled { budget_tokens })
            }
            AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
            AnthropicModelMode::Default => None,
        }
    } else {
        None
    };

    // An effort level is only sent for adaptive thinking, and only when the
    // requested level is one of the recognized names.
    let output_config =
        if request.thinking_allowed && matches!(mode, AnthropicModelMode::AdaptiveThinking) {
            request
                .thinking_effort
                .as_deref()
                .and_then(|level| match level {
                    "low" => Some(crate::Effort::Low),
                    "medium" => Some(crate::Effort::Medium),
                    "high" => Some(crate::Effort::High),
                    "max" => Some(crate::Effort::Max),
                    _ => None,
                })
                .map(|effort| crate::OutputConfig {
                    effort: Some(effort),
                })
        } else {
            None
        };

    crate::Request {
        model,
        messages,
        max_tokens: max_output_tokens,
        system,
        thinking,
        tools: request
            .tools
            .into_iter()
            .map(|tool| Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
                eager_input_streaming: tool.use_input_streaming,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => ToolChoice::Auto,
            LanguageModelToolChoice::Any => ToolChoice::Any,
            LanguageModelToolChoice::None => ToolChoice::None,
        }),
        metadata: None,
        output_config,
        stop_sequences: Vec::new(),
        speed: request.speed.map(Into::into),
        temperature: request.temperature.or(Some(default_temperature)),
        top_k: None,
        top_p: None,
    }
}
pub struct AnthropicEventMapper {
tool_uses_by_index: HashMap<usize, RawToolUse>,
usage: Usage,
stop_reason: StopReason,
}
impl AnthropicEventMapper {
pub fn new() -> Self {
Self {
tool_uses_by_index: HashMap::default(),
usage: Usage::default(),
stop_reason: StopReason::EndTurn,
}
}
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(error.into())],
})
})
}
pub fn map_event(
&mut self,
event: Event,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
match event {
Event::ContentBlockStart {
index,
content_block,
} => match content_block {
ResponseContent::Text { text } => {
vec![Ok(LanguageModelCompletionEvent::Text(text))]
}
ResponseContent::Thinking { thinking } => {
vec![Ok(LanguageModelCompletionEvent::Thinking {
text: thinking,
signature: None,
})]
}
ResponseContent::RedactedThinking { data } => {
vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
}
ResponseContent::ToolUse { id, name, .. } => {
self.tool_uses_by_index.insert(
index,
RawToolUse {
id,
name,
input_json: String::new(),
},
);
Vec::new()
}
},
Event::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => {
vec![Ok(LanguageModelCompletionEvent::Text(text))]
}
ContentDelta::ThinkingDelta { thinking } => {
vec![Ok(LanguageModelCompletionEvent::Thinking {
text: thinking,
signature: None,
})]
}
ContentDelta::SignatureDelta { signature } => {
vec![Ok(LanguageModelCompletionEvent::Thinking {
text: "".to_string(),
signature: Some(signature),
})]
}
ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
// Try to convert invalid (incomplete) JSON into
// valid JSON that serde can accept, e.g. by closing
// unclosed delimiters. This way, we can update the
// UI with whatever has been streamed back so far.
if let Ok(input) =
serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
{
return vec![Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.clone().into(),
name: tool_use.name.clone().into(),
is_input_complete: false,
raw_input: tool_use.input_json.clone(),
input,
thought_signature: None,
},
))];
}
}
vec![]
}
},
Event::ContentBlockStop { index } => {
if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
let input_json = tool_use.input_json.trim();
let event_result = match parse_tool_arguments(input_json) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input,
raw_input: tool_use.input_json.clone(),
thought_signature: None,
},
)),
Err(json_parse_err) => {
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
})
}
};
vec![event_result]
} else {
Vec::new()
}
}
Event::MessageStart { message } => {
update_usage(&mut self.usage, &message.usage);
vec![
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
&self.usage,
))),
Ok(LanguageModelCompletionEvent::StartMessage {
message_id: message.id,
}),
]
}
Event::MessageDelta { delta, usage } => {
update_usage(&mut self.usage, &usage);
if let Some(stop_reason) = delta.stop_reason.as_deref() {
self.stop_reason = match stop_reason {
"end_turn" => StopReason::EndTurn,
"max_tokens" => StopReason::MaxTokens,
"tool_use" => StopReason::ToolUse,
"refusal" => StopReason::Refusal,
_ => {
log::error!("Unexpected anthropic stop_reason: {stop_reason}");
StopReason::EndTurn
}
};
}
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
))]
}
Event::MessageStop => {
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
}
Event::Error { error } => {
vec![Err(error.into())]
}
_ => Vec::new(),
}
}
}
/// A tool call whose input is still being streamed.
///
/// Created when a tool-use content block starts and removed when that block
/// stops.
struct RawToolUse {
    /// Tool-use id taken from the content block that started the call.
    id: String,
    /// Tool name taken from the content block that started the call.
    name: String,
    /// Input JSON text accumulated so far from `InputJsonDelta` events; may be
    /// incomplete (and therefore invalid JSON) until the block stops.
    input_json: String,
}
/// Updates usage data by preferring counts from `new`.
///
/// A counter that is `None` in `new` leaves the corresponding counter in
/// `usage` untouched.
fn update_usage(usage: &mut Usage, new: &Usage) {
    usage.input_tokens = new.input_tokens.or(usage.input_tokens);
    usage.output_tokens = new.output_tokens.or(usage.output_tokens);
    usage.cache_creation_input_tokens = new
        .cache_creation_input_tokens
        .or(usage.cache_creation_input_tokens);
    usage.cache_read_input_tokens = new
        .cache_read_input_tokens
        .or(usage.cache_read_input_tokens);
}
/// Converts Anthropic usage counters into a `TokenUsage`, treating any
/// counter that has not been reported as zero.
fn convert_usage(usage: &Usage) -> TokenUsage {
    let input_tokens = usage.input_tokens.unwrap_or_default();
    let output_tokens = usage.output_tokens.unwrap_or_default();
    let cache_creation_input_tokens = usage.cache_creation_input_tokens.unwrap_or_default();
    let cache_read_input_tokens = usage.cache_read_input_tokens.unwrap_or_default();

    TokenUsage {
        input_tokens,
        output_tokens,
        cache_creation_input_tokens,
        cache_read_input_tokens,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AnthropicModelMode;
    use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// With `cache: true`, only the final content segment of the message may
    /// carry cache control; every earlier segment must stay uncached.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                ],
                cache: true,
                reasoning_details: None,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);

        assert!(matches!(
            message.content[0],
            RequestContent::Text {
                cache_control: None,
                ..
            }
        ));
        // Indices 1 through 3 are the images before the last one; all of them
        // must be uncached. (The range is inclusive so index 3 is checked too.)
        for i in 1..=3 {
            assert!(matches!(
                message.content[i],
                RequestContent::Image {
                    cache_control: None,
                    ..
                }
            ));
        }

        assert!(matches!(
            message.content[4],
            RequestContent::Image {
                cache_control: Some(CacheControl {
                    cache_type: CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }

    /// Builds a two-message request (a user "Hello" followed by an assistant
    /// message with `assistant_content`) and converts it with thinking enabled.
    fn request_with_assistant_content(assistant_content: Vec<MessageContent>) -> crate::Request {
        let mut request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("Hello".to_string())],
                cache: false,
                reasoning_details: None,
            }],
            thinking_effort: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            speed: None,
        };
        request.messages.push(LanguageModelRequestMessage {
            role: Role::Assistant,
            content: assistant_content,
            cache: false,
            reasoning_details: None,
        });
        into_anthropic(
            request,
            "claude-sonnet-4-5".to_string(),
            1.0,
            16000,
            AnthropicModelMode::Thinking {
                budget_tokens: Some(10000),
            },
        )
    }

    /// A thinking block without a signature is dropped, but the rest of the
    /// assistant message survives.
    #[test]
    fn test_unsigned_thinking_blocks_stripped() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Cancelled mid-think, no signature".to_string(),
                signature: None,
            },
            MessageContent::Text("Some response text".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should still exist");

        assert_eq!(
            assistant_message.content.len(),
            1,
            "Only the text content should remain; unsigned thinking block should be stripped"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Text { text, .. } if text == "Some response text"
        ));
    }

    /// A signed thinking block is forwarded unchanged, ahead of the text.
    #[test]
    fn test_signed_thinking_blocks_preserved() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Completed thinking".to_string(),
                signature: Some("valid-signature".to_string()),
            },
            MessageContent::Text("Response".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should exist");

        assert_eq!(
            assistant_message.content.len(),
            2,
            "Both the signed thinking block and text should be preserved"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Thinking { thinking, signature, .. }
                if thinking == "Completed thinking" && signature == "valid-signature"
        ));
    }

    /// An assistant message that contained nothing but an unsigned thinking
    /// block leaves no assistant message in the converted request.
    #[test]
    fn test_only_unsigned_thinking_block_omits_entire_message() {
        let result = request_with_assistant_content(vec![MessageContent::Thinking {
            text: "Cancelled before any text or signature".to_string(),
            signature: None,
        }]);

        let assistant_messages: Vec<_> = result
            .messages
            .iter()
            .filter(|m| m.role == crate::Role::Assistant)
            .collect();

        assert_eq!(
            assistant_messages.len(),
            0,
            "An assistant message whose only content was an unsigned thinking block \
             should be omitted entirely"
        );
    }
}

View file

@ -36,7 +36,6 @@ gpui_tokio.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
paths.workspace = true

View file

@ -14,6 +14,7 @@ use async_tungstenite::tungstenite::{
http::{HeaderValue, Request, StatusCode},
};
use clock::SystemClock;
use cloud_api_client::LlmApiToken;
use cloud_api_client::websocket_protocol::MessageToClient;
use cloud_api_client::{ClientApiError, CloudApiClient};
use cloud_api_types::OrganizationId;
@ -26,7 +27,6 @@ use futures::{
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
use language_model::LlmApiToken;
use parking_lot::{Mutex, RwLock};
use postage::watch;
use proxy::connect_proxy_stream;

View file

@ -1,10 +1,10 @@
use super::{Client, UserStore};
use cloud_api_client::LlmApiToken;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
};
use language_model::LlmApiToken;
use std::sync::Arc;
pub trait NeedsLlmTokenRefresh {

View file

@ -20,5 +20,6 @@ gpui_tokio.workspace = true
http_client.workspace = true
parking_lot.workspace = true
serde_json.workspace = true
smol.workspace = true
thiserror.workspace = true
yawc.workspace = true

View file

@ -1,3 +1,4 @@
mod llm_token;
mod websocket;
use std::sync::Arc;
@ -18,6 +19,8 @@ use yawc::WebSocket;
use crate::websocket::Connection;
pub use llm_token::LlmApiToken;
struct Credentials {
user_id: u32,
access_token: String,

View file

@ -0,0 +1,74 @@
use std::sync::Arc;
use cloud_api_types::OrganizationId;
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use crate::{ClientApiError, CloudApiClient};
/// Cached LLM API token behind an async read/write lock.
///
/// The token slot lives inside an `Arc`, so clones of this handle refer to the
/// same slot. The slot is `None` until a token has been fetched, and again
/// after a failed fetch.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
    /// Returns the cached token, fetching a new one only when none is cached.
    ///
    /// # Errors
    ///
    /// Returns the `ClientApiError` from the token request if a fetch was
    /// needed and failed.
    pub async fn acquire(
        &self,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        let guard = self.0.upgradable_read().await;
        if let Some(token) = guard.as_ref() {
            return Ok(token.clone());
        }

        // Nothing cached yet: upgrade to a write lock and fetch.
        let write_guard = RwLockUpgradableReadGuard::upgrade(guard).await;
        Self::fetch(write_guard, client, system_id, organization_id).await
    }

    /// Fetches a new token unconditionally, replacing whatever is cached.
    ///
    /// # Errors
    ///
    /// Returns the `ClientApiError` from the token request; the cached token
    /// is cleared in that case.
    pub async fn refresh(
        &self,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        let write_guard = self.0.write().await;
        Self::fetch(write_guard, client, system_id, organization_id).await
    }

    /// Clears the existing token before attempting to fetch a new one.
    ///
    /// Used when switching organizations so that a failed refresh doesn't
    /// leave a token for the wrong organization.
    pub async fn clear_and_refresh(
        &self,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        let mut write_guard = self.0.write().await;
        *write_guard = None;
        Self::fetch(write_guard, client, system_id, organization_id).await
    }

    /// Requests a new token while holding the write lock, storing it on
    /// success and clearing the slot on failure.
    async fn fetch(
        mut lock: RwLockWriteGuard<'_, Option<String>>,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        match client.create_llm_token(system_id, organization_id).await {
            Ok(response) => {
                let token = response.token.0;
                *lock = Some(token.clone());
                Ok(token)
            }
            Err(err) => {
                *lock = None;
                Err(err)
            }
        }
    }
}

View file

@ -7,6 +7,7 @@ license = "Apache-2.0"
[features]
test-support = []
predict-edits = ["dep:zeta_prompt"]
[lints]
workspace = true
@ -20,6 +21,6 @@ serde = { workspace = true, features = ["derive", "rc"] }
serde_json.workspace = true
strum = { workspace = true, features = ["derive"] }
uuid = { workspace = true, features = ["serde"] }
zeta_prompt.workspace = true
zeta_prompt = { workspace = true, optional = true }

View file

@ -1,3 +1,4 @@
#[cfg(feature = "predict-edits")]
pub mod predict_edits_v3;
use std::str::FromStr;

View file

@ -2846,11 +2846,11 @@ impl CollabPanel {
}
};
Some(channel.name.as_ref())
Some(channel.name.clone())
});
if let Some(name) = channel_name {
SharedString::from(name.to_string())
name
} else {
SharedString::from("Current Call")
}

View file

@ -21,8 +21,9 @@ heapless.workspace = true
buffer_diff.workspace = true
client.workspace = true
clock.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
cloud_llm_client = { workspace = true, features = ["predict-edits"] }
collections.workspace = true
copilot.workspace = true
copilot_ui.workspace = true

View file

@ -1,5 +1,6 @@
use anyhow::Result;
use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use cloud_api_client::LlmApiToken;
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@ -31,7 +32,6 @@ use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::LlmApiToken;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;

View file

@ -57,7 +57,7 @@ pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
let mut models: Vec<SharedString> = provider
.provided_models(cx)
.into_iter()
.map(|model| SharedString::from(model.id().0.to_string()))
.map(|model| model.id().0)
.collect();
models.sort();
models

View file

@ -177,7 +177,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
return Some(edit_prediction_types::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
id: Some(prediction.id.0.clone()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
});
@ -228,7 +228,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
}
Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
id: Some(prediction.id.0.clone()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
cursor_position: prediction.cursor_position,
edit_preview: Some(prediction.edit_preview.clone()),

View file

@ -22,7 +22,7 @@ http_client.workspace = true
chrono.workspace = true
clap = "4"
client.workspace = true
cloud_llm_client.workspace= true
cloud_llm_client = { workspace = true, features = ["predict-edits"] }
collections.workspace = true
db.workspace = true
debug_adapter_extension.workspace = true

View file

@ -12,4 +12,4 @@ workspace = true
path = "src/env_var.rs"
[dependencies]
gpui.workspace = true
gpui_shared_string.workspace = true

View file

@ -1,4 +1,4 @@
use gpui::SharedString;
use gpui_shared_string::SharedString;
#[derive(Clone)]
pub struct EnvVar {

View file

@ -1906,7 +1906,7 @@ mod tests {
assert_eq!(
remotes,
vec![Remote {
name: SharedString::from("my_new_remote".to_string())
name: SharedString::from("my_new_remote")
}]
);
}

View file

@ -18,8 +18,10 @@ schemars = ["dep:schemars"]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
language_model_core.workspace = true
log.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
strum.workspace = true
tiktoken-rs.workspace = true

View file

@ -0,0 +1,492 @@
use anyhow::Result;
use futures::{Stream, StreamExt};
use language_model_core::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
StopReason, TokenUsage,
};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{self, AtomicU64};
use crate::{
Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
UsageMetadata,
};
pub fn into_google(
mut request: LanguageModelRequest,
model_id: String,
mode: GoogleModelMode,
) -> crate::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content
.into_iter()
.flat_map(|content| match content {
MessageContent::Text(text) => {
if !text.is_empty() {
vec![Part::TextPart(TextPart { text })]
} else {
vec![]
}
}
MessageContent::Thinking {
text: _,
signature: Some(signature),
} => {
if !signature.is_empty() {
vec![Part::ThoughtPart(crate::ThoughtPart {
thought: true,
thought_signature: signature,
})]
} else {
vec![]
}
}
MessageContent::Thinking { .. } => {
vec![]
}
MessageContent::RedactedThinking(_) => vec![],
MessageContent::Image(image) => {
vec![Part::InlineDataPart(InlineDataPart {
inline_data: GenerativeContentBlob {
mime_type: "image/png".to_string(),
data: image.source.to_string(),
},
})]
}
MessageContent::ToolUse(tool_use) => {
// Normalize empty string signatures to None
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
vec![Part::FunctionCallPart(crate::FunctionCallPart {
function_call: crate::FunctionCall {
name: tool_use.name.to_string(),
args: tool_use.input,
},
thought_signature,
})]
}
MessageContent::ToolResult(tool_result) => {
match tool_result.content {
language_model_core::LanguageModelToolResultContent::Text(text) => {
vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
function_response: crate::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": text
}),
},
})]
}
language_model_core::LanguageModelToolResultContent::Image(image) => {
vec![
Part::FunctionResponsePart(crate::FunctionResponsePart {
function_response: crate::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": "Tool responded with an image"
}),
},
}),
Part::InlineDataPart(InlineDataPart {
inline_data: GenerativeContentBlob {
mime_type: "image/png".to_string(),
data: image.source.to_string(),
},
}),
]
}
}
}
})
.collect()
}
let system_instructions = if request
.messages
.first()
.is_some_and(|msg| matches!(msg.role, Role::System))
{
let message = request.messages.remove(0);
Some(SystemInstruction {
parts: map_content(message.content),
})
} else {
None
};
crate::GenerateContentRequest {
model: ModelName { model_id },
system_instruction: system_instructions,
contents: request
.messages
.into_iter()
.filter_map(|message| {
let parts = map_content(message.content);
if parts.is_empty() {
None
} else {
Some(Content {
parts,
role: match message.role {
Role::User => crate::Role::User,
Role::Assistant => crate::Role::Model,
Role::System => crate::Role::User, // Google AI doesn't have a system role
},
})
}
})
.collect(),
generation_config: Some(GenerationConfig {
candidate_count: Some(1),
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
thinking_config: match (request.thinking_allowed, mode) {
(true, GoogleModelMode::Thinking { budget_tokens }) => {
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
}
_ => None,
},
top_p: None,
top_k: None,
}),
safety_settings: None,
tools: (!request.tools.is_empty()).then(|| {
vec![crate::Tool {
function_declarations: request
.tools
.into_iter()
.map(|tool| FunctionDeclaration {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
})
.collect(),
}]
}),
tool_config: request.tool_choice.map(|choice| ToolConfig {
function_calling_config: FunctionCallingConfig {
mode: match choice {
LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
LanguageModelToolChoice::Any => FunctionCallingMode::Any,
LanguageModelToolChoice::None => FunctionCallingMode::None,
},
allowed_function_names: None,
},
}),
}
}
/// Stateful translator from Google `GenerateContentResponse` chunks to
/// `LanguageModelCompletionEvent`s.
pub struct GoogleEventMapper {
// Usage counters merged across the responses seen so far.
usage: UsageMetadata,
// Most recently determined stop reason; emitted when the stream ends.
stop_reason: StopReason,
}
impl GoogleEventMapper {
pub fn new() -> Self {
Self {
usage: UsageMetadata::default(),
stop_reason: StopReason::EndTurn,
}
}
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events
.map(Some)
.chain(futures::stream::once(async { None }))
.flat_map(move |event| {
futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => {
vec![Err(LanguageModelCompletionError::from(error))]
}
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
})
})
}
pub fn map_event(
&mut self,
event: GenerateContentResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut self.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
)))
}
if let Some(prompt_feedback) = event.prompt_feedback
&& let Some(block_reason) = prompt_feedback.block_reason.as_deref()
{
self.stop_reason = match block_reason {
"SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
StopReason::Refusal
}
_ => {
log::error!("Unexpected Google block_reason: {block_reason}");
StopReason::Refusal
}
};
events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
return events;
}
if let Some(candidates) = event.candidates {
for candidate in candidates {
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
self.stop_reason = match finish_reason {
"STOP" => StopReason::EndTurn,
"MAX_TOKENS" => StopReason::MaxTokens,
_ => {
log::error!("Unexpected google finish_reason: {finish_reason}");
StopReason::EndTurn
}
};
}
candidate
.content
.parts
.into_iter()
.for_each(|part| match part {
Part::TextPart(text_part) => {
events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
}
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
let name: Arc<str> = function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
// Normalize empty string signatures to None
let thought_signature = function_call_part
.thought_signature
.filter(|s| !s.is_empty());
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id,
name,
is_input_complete: true,
raw_input: function_call_part.function_call.args.to_string(),
input: function_call_part.function_call.args,
thought_signature,
},
)));
}
Part::FunctionResponsePart(_) => {}
Part::ThoughtPart(part) => {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
signature: Some(part.thought_signature),
}));
}
});
}
}
// Even when Gemini wants to use a Tool, the API
// responds with `finish_reason: STOP`
if wants_to_use_tool {
self.stop_reason = StopReason::ToolUse;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
}
events
}
}
/// Estimates the token count of `request` for a Google AI model via tiktoken.
///
/// This function is synchronous; callers should run it on a background thread
/// if needed.
pub fn count_google_tokens(request: LanguageModelRequest) -> Result<u64> {
    let mut messages = Vec::new();
    for message in request.messages {
        let role = match message.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        messages.push(tiktoken_rs::ChatCompletionRequestMessage {
            role: role.into(),
            content: Some(message.string_contents()),
            name: None,
            function_call: None,
        });
    }

    // Tiktoken doesn't yet support these models, so we manually use the
    // same tokenizer as GPT-4.
    let token_count = tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)?;
    Ok(token_count as u64)
}
/// Merges `new` into `usage`: every counter present in `new` overwrites the
/// stored one, while counters absent from `new` keep their previous value.
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
    /// Replaces `current` only when `incoming` carries a value.
    fn overwrite_if_present<T>(current: &mut Option<T>, incoming: Option<T>) {
        if incoming.is_some() {
            *current = incoming;
        }
    }

    overwrite_if_present(&mut usage.prompt_token_count, new.prompt_token_count);
    overwrite_if_present(
        &mut usage.cached_content_token_count,
        new.cached_content_token_count,
    );
    overwrite_if_present(
        &mut usage.candidates_token_count,
        new.candidates_token_count,
    );
    overwrite_if_present(
        &mut usage.tool_use_prompt_token_count,
        new.tool_use_prompt_token_count,
    );
    overwrite_if_present(&mut usage.thoughts_token_count, new.thoughts_token_count);
    overwrite_if_present(&mut usage.total_token_count, new.total_token_count);
}
/// Converts accumulated Google usage counters into a [`TokenUsage`].
///
/// Missing counters are treated as zero. Input tokens are reported as the prompt
/// count minus the cached count, with the cached count reported separately as
/// `cache_read_input_tokens`. Cache-creation tokens are always reported as zero.
fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
    // These counters come from the API and are merged field-by-field across
    // responses, so the cached count is not guaranteed to be <= the prompt count.
    // Saturate instead of panicking (debug) or wrapping (release) on underflow.
    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
    let output_tokens = usage.candidates_token_count.unwrap_or(0);

    TokenUsage {
        input_tokens,
        output_tokens,
        cache_read_input_tokens: cached_tokens,
        cache_creation_input_tokens: 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a response with a single candidate containing one call to
    /// `test_function`, carrying the given thought signature.
    fn function_call_response(thought_signature: Option<&str>) -> GenerateContentResponse {
        let function_call_part = FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args: json!({"arg": "value"}),
            },
            thought_signature: thought_signature.map(String::from),
        };
        let candidate = GenerateContentCandidate {
            index: Some(0),
            content: Content {
                parts: vec![Part::FunctionCallPart(function_call_part)],
                role: GoogleRole::Model,
            },
            finish_reason: None,
            finish_message: None,
            safety_ratings: None,
            citation_metadata: None,
        };
        GenerateContentResponse {
            candidates: Some(vec![candidate]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    /// Returns the tool use carried by the first event, panicking if it is anything else.
    fn first_tool_use(
        events: &[Result<LanguageModelCompletionEvent, LanguageModelCompletionError>],
    ) -> &LanguageModelToolUse {
        match &events[0] {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => tool_use,
            _ => panic!("Expected ToolUse event"),
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(Some("test_signature_123")));

        assert_eq!(events.len(), 2);
        let tool_use = first_tool_use(&events);
        assert_eq!(tool_use.name.as_ref(), "test_function");
        assert_eq!(
            tool_use.thought_signature.as_deref(),
            Some("test_signature_123")
        );
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(None));

        assert_eq!(events.len(), 2);
        assert!(first_tool_use(&events).thought_signature.is_none());
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(Some("")));

        assert!(first_tool_use(&events).thought_signature.is_none());
    }
}

View file

@ -3,8 +3,9 @@ use std::mem;
use anyhow::{Result, anyhow, bail};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
pub use language_model_core::ModelMode as GoogleModelMode;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub use settings::ModelMode as GoogleModelMode;
pub mod completion;
pub const API_URL: &str = "https://generativelanguage.googleapis.com";

View file

@ -56,6 +56,7 @@ etagere = "0.2"
futures.workspace = true
futures-concurrency.workspace = true
gpui_macros.workspace = true
gpui_shared_string.workspace = true
http_client.workspace = true
image.workspace = true
inventory.workspace = true

View file

@ -39,7 +39,6 @@ pub mod profiler;
#[expect(missing_docs)]
pub mod queue;
mod scene;
mod shared_string;
mod shared_uri;
mod style;
mod styled;
@ -92,6 +91,7 @@ pub use global::*;
pub use gpui_macros::{
AppContext, IntoElement, Render, VisualContext, property_test, register_action, test,
};
pub use gpui_shared_string::*;
pub use gpui_util::arc_cow::ArcCow;
pub use http_client;
pub use input::*;
@ -106,7 +106,6 @@ pub use profiler::*;
pub use queue::{PriorityQueueReceiver, PriorityQueueSender};
pub use refineable::*;
pub use scene::*;
pub use shared_string::*;
pub use shared_uri::*;
use std::{any::Any, future::Future};
pub use style::*;

View file

@ -882,7 +882,7 @@ mod tests {
],
len: 6,
}),
text: SharedString::new("abcdef".to_string()),
text: "abcdef".into(),
decoration_runs: SmallVec::new(),
};

View file

@ -0,0 +1,17 @@
[package]
name = "gpui_shared_string"
version = "0.1.0"
publish.workspace = true
edition.workspace = true
[lib]
path = "gpui_shared_string.rs"
[dependencies]
derive_more.workspace = true
gpui_util.workspace = true
schemars.workspace = true
serde.workspace = true
[lints]
workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-APACHE

View file

@ -10,7 +10,7 @@ path = "src/language_core.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
gpui.workspace = true
gpui_shared_string.workspace = true
log.workspace = true
lsp.workspace = true
parking_lot.workspace = true
@ -22,8 +22,6 @@ toml.workspace = true
tree-sitter.workspace = true
util.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
[features]
test-support = []

View file

@ -1,4 +1,4 @@
use gpui::SharedString;
use gpui_shared_string::SharedString;
use lsp::{DiagnosticSeverity, NumberOrString};
use serde::{Deserialize, Serialize};
use serde_json::Value;

View file

@ -4,7 +4,7 @@ use crate::{
};
use anyhow::{Context as _, Result};
use collections::HashMap;
use gpui::SharedString;
use gpui_shared_string::SharedString;
use lsp::LanguageServerName;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

View file

@ -1,6 +1,6 @@
use crate::LanguageName;
use collections::{HashMap, HashSet, IndexSet};
use gpui::SharedString;
use gpui_shared_string::SharedString;
use lsp::LanguageServerName;
use regex::Regex;
use schemars::{JsonSchema, SchemaGenerator, json_schema};

View file

@ -1,4 +1,4 @@
use gpui::SharedString;
use gpui_shared_string::SharedString;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{

View file

@ -1,4 +1,4 @@
use gpui::SharedString;
use gpui_shared_string::SharedString;
use serde::{Deserialize, Serialize};
/// Converts a value into an LSP position.

View file

@ -1,6 +1,6 @@
use std::borrow::Borrow;
use gpui::SharedString;
use gpui_shared_string::SharedString;
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ManifestName(SharedString);

View file

@ -6,7 +6,7 @@
use std::{path::Path, sync::Arc};
use gpui::SharedString;
use gpui_shared_string::SharedString;
use util::rel_path::RelPath;
use crate::{LanguageName, ManifestName};

View file

@ -16,13 +16,9 @@ doctest = false
test-support = []
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
credentials_provider.workspace = true
base64.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
env_var.workspace = true
futures.workspace = true
@ -30,14 +26,11 @@ gpui.workspace = true
http_client.workspace = true
icons.workspace = true
image.workspace = true
language_model_core.workspace = true
log.workspace = true
open_ai = { workspace = true, features = ["schemars"] }
open_router.workspace = true
parking_lot.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
thiserror.workspace = true
util.workspace = true

View file

@ -5,11 +5,10 @@ use crate::{
LanguageModelRequest, LanguageModelToolChoice,
};
use anyhow::anyhow;
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use smol::stream::StreamExt;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering::SeqCst},

View file

@ -1,380 +1,31 @@
mod api_key;
mod model;
mod provider;
mod rate_limiter;
mod registry;
mod request;
mod role;
pub mod tool_schema;
#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;
use anyhow::{Result, anyhow};
use cloud_llm_client::CompletionRequestStatus;
pub use language_model_core::*;
use anyhow::Result;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use gpui::{AnyView, App, AsyncApp, Task, Window};
use icons::IconName;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
pub use crate::api_key::{ApiKey, ApiKeyState};
pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
pub use env_var::{EnvVar, env_var};
pub use provider::*;
/// Crate-level initialization entry point; delegates to `registry::init`.
pub fn init(cx: &mut App) {
registry::init(cx);
}
/// Cache-related configuration values for a language model.
// NOTE(review): the precise meaning of each field is not visible in this chunk —
// confirm against the code that consumes this struct before relying on it.
#[derive(Clone, Debug)]
pub struct LanguageModelCacheConfiguration {
pub max_cache_anchors: usize,
pub should_speculate: bool,
pub min_total_token: u64,
}
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
// Produced from `CompletionRequestStatus::Queued`.
Queued {
position: usize,
},
// Produced from `CompletionRequestStatus::Started`.
Started,
// The completion stopped for the given reason.
Stop(StopReason),
// A chunk of generated text.
Text(String),
// Model reasoning, optionally with a signature.
Thinking {
text: String,
signature: Option<String>,
},
RedactedThinking {
data: String,
},
// The model requested a tool invocation.
ToolUse(LanguageModelToolUse),
// A tool call's input could not be parsed as JSON; the raw input and
// the parse error are carried along.
ToolUseJsonParseError {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
StartMessage {
message_id: String,
},
ReasoningDetails(serde_json::Value),
// Updated token usage counts.
UsageUpdate(TokenUsage),
}
impl LanguageModelCompletionEvent {
    /// Maps a cloud `CompletionRequestStatus` to a completion event.
    ///
    /// Returns `Ok(Some(_))` for `Queued` and `Started`, `Ok(None)` for `Unknown`
    /// and `StreamEnded`, and an error built with
    /// [`LanguageModelCompletionError::from_cloud_failure`] for `Failed`.
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Option<Self>, LanguageModelCompletionError> {
        let event = match status {
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => {
                let retry_after = retry_after.map(Duration::from_secs_f64);
                return Err(LanguageModelCompletionError::from_cloud_failure(
                    upstream_provider,
                    code,
                    message,
                    retry_after,
                ));
            }
            CompletionRequestStatus::Queued { position } => {
                Some(LanguageModelCompletionEvent::Queued { position })
            }
            CompletionRequestStatus::Started => Some(LanguageModelCompletionEvent::Started),
            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => None,
        };
        Ok(event)
    }
}
/// Errors that can occur while requesting a completion from a language model.
///
/// Most variants record the `provider` whose API produced the failure. The
/// rate-limit and overload variants additionally carry an optional
/// `retry_after` duration.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> },
#[error("missing {provider} API key")]
NoApiKey { provider: LanguageModelProviderName },
#[error("{provider}'s API rate limit exceeded")]
RateLimitExceeded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API servers are overloaded right now")]
ServerOverloaded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API server reported an internal server error: {message}")]
ApiInternalServerError {
provider: LanguageModelProviderName,
message: String,
},
#[error("{message}")]
UpstreamProviderError {
message: String,
status: StatusCode,
retry_after: Option<Duration>,
},
// Fallback for HTTP statuses that have no dedicated variant.
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
},
// Client errors
#[error("invalid request format to {provider}'s API: {message}")]
BadRequestFormat {
provider: LanguageModelProviderName,
message: String,
},
#[error("authentication error with {provider}'s API: {message}")]
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("Permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName },
#[error("I/O error reading response from {provider}'s API")]
ApiReadResponseError {
provider: LanguageModelProviderName,
#[source]
error: io::Error,
},
#[error("error serializing request to {provider} API")]
SerializeRequest {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("error building request body to {provider} API")]
BuildRequestBody {
provider: LanguageModelProviderName,
#[source]
error: http::Error,
},
#[error("error sending HTTP request to {provider} API")]
HttpSend {
provider: LanguageModelProviderName,
#[source]
error: anyhow::Error,
},
#[error("error deserializing {provider} API response")]
DeserializeResponse {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("stream from {provider} ended unexpectedly")]
StreamEndedUnexpectedly { provider: LanguageModelProviderName },
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl LanguageModelCompletionError {
    /// Extracts the upstream HTTP status and message from a JSON error payload.
    ///
    /// Returns `None` when `message` is not valid JSON or has no usable
    /// `upstream_status`. When the JSON has no string `message` field, the whole
    /// input is used as the message.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json: serde_json::Value = serde_json::from_str(message).ok()?;

        let raw_status = error_json.get("upstream_status")?.as_u64()?;
        let status_u16 = u16::try_from(raw_status).ok()?;
        let upstream_status = StatusCode::from_u16(status_u16).ok()?;

        let inner_message = match error_json.get("message").and_then(|v| v.as_str()) {
            Some(inner) => inner.to_string(),
            None => message.to_string(),
        };
        Some((upstream_status, inner_message))
    }

    /// Builds an error from a failure reported by the cloud service.
    ///
    /// `code` selects the interpretation: `upstream_http_error` carries a JSON
    /// payload in `message`, `upstream_http_<status>` refers to the upstream
    /// provider, and `http_<status>` refers to the cloud service itself. Anything
    /// else becomes a generic error.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
        // to be reported. This is a temporary workaround to handle this in the case where the
        // token limit has been exceeded.
        if let Some(tokens) = parse_prompt_too_long(&message) {
            return Self::PromptTooLarge {
                tokens: Some(tokens),
            };
        }

        if code == "upstream_http_error" {
            return match Self::parse_upstream_error_json(&message) {
                Some((upstream_status, inner_message)) => Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                ),
                None => {
                    anyhow!("completion request failed, code: {code}, message: {message}").into()
                }
            };
        }

        let upstream_status = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok());
        if let Some(status_code) = upstream_status {
            return Self::from_http_status(upstream_provider, status_code, message, retry_after);
        }

        let cloud_status = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok());
        if let Some(status_code) = cloud_status {
            return Self::from_http_status(
                ZED_CLOUD_PROVIDER_NAME,
                status_code,
                message,
                retry_after,
            );
        }

        anyhow!("completion request failed, code: {code}, message: {message}").into()
    }

    /// Maps an HTTP status returned by `provider`'s API to the matching variant.
    ///
    /// Status 503 and the non-standard 529 are both reported as `ServerOverloaded`.
    /// Statuses without a dedicated variant become `HttpResponseError`.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        let is_overloaded =
            status_code == StatusCode::SERVICE_UNAVAILABLE || status_code.as_u16() == 529;
        if is_overloaded {
            return Self::ServerOverloaded {
                provider,
                retry_after,
            };
        }

        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => {
                let tokens = parse_prompt_too_long(&message);
                Self::PromptTooLarge { tokens }
            }
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
/// The reason a completion stopped, as carried by
/// `LanguageModelCompletionEvent::Stop`. Serialized in `snake_case`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
EndTurn,
MaxTokens,
ToolUse,
Refusal,
}
/// Token counts for a completion, split into input, output, and cache categories.
///
/// Every field defaults to zero when missing during deserialization.
// NOTE(review): `is_default` is defined outside this chunk; presumably it makes
// serde skip fields that are still at their default value — confirm.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
#[serde(default, skip_serializing_if = "is_default")]
pub input_tokens: u64,
#[serde(default, skip_serializing_if = "is_default")]
pub output_tokens: u64,
#[serde(default, skip_serializing_if = "is_default")]
pub cache_creation_input_tokens: u64,
#[serde(default, skip_serializing_if = "is_default")]
pub cache_read_input_tokens: u64,
}
impl TokenUsage {
    /// Returns the sum of all four counters: input, output, cache-read and
    /// cache-creation tokens.
    pub fn total_tokens(&self) -> u64 {
        [
            self.input_tokens,
            self.output_tokens,
            self.cache_read_input_tokens,
            self.cache_creation_input_tokens,
        ]
        .iter()
        .sum()
    }
}
impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    /// Adds two usages counter by counter.
    fn add(self, rhs: Self) -> Self {
        let Self {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
        } = self;

        Self {
            input_tokens: input_tokens + rhs.input_tokens,
            output_tokens: output_tokens + rhs.output_tokens,
            cache_creation_input_tokens: cache_creation_input_tokens
                + rhs.cache_creation_input_tokens,
            cache_read_input_tokens: cache_read_input_tokens + rhs.cache_read_input_tokens,
        }
    }
}
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    /// Subtracts `other` from `self` counter by counter.
    ///
    /// Each counter saturates at zero. The counters are unsigned, so a plain
    /// `-` would panic in debug builds and wrap around to a huge value in
    /// release builds whenever `other` reports more tokens than `self`.
    /// Results are unchanged for every input that did not underflow before.
    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens.saturating_sub(other.input_tokens),
            output_tokens: self.output_tokens.saturating_sub(other.output_tokens),
            cache_creation_input_tokens: self
                .cache_creation_input_tokens
                .saturating_sub(other.cache_creation_input_tokens),
            cache_read_input_tokens: self
                .cache_read_input_tokens
                .saturating_sub(other.cache_read_input_tokens),
        }
    }
}
/// Newtype around `Arc<str>` identifying a single tool use.
///
/// The inner field is private; values are built through the `From`
/// conversion and rendered through `Display`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
impl fmt::Display for LanguageModelToolUseId {
    /// Writes the wrapped id string verbatim; width and fill flags of the
    /// outer formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
impl<T> From<T> for LanguageModelToolUseId
where
T: Into<Arc<str>>,
{
fn from(value: T) -> Self {
Self(value.into())
}
}
/// A tool invocation: its id, tool name, input and completion flag.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
/// Identifier of this tool use.
pub id: LanguageModelToolUseId,
/// Name of the tool being invoked.
pub name: Arc<str>,
/// Input as a string. NOTE(review): presumably the unparsed JSON text as received — confirm.
pub raw_input: String,
/// Input as a JSON value.
pub input: serde_json::Value,
/// NOTE(review): presumably false while the input is still streaming in — confirm.
pub is_input_complete: bool,
/// Thought signature the model sent us. Some models require that this
/// signature be preserved and sent back in conversation history for validation.
pub thought_signature: Option<String>,
}
pub struct LanguageModelTextStream {
pub message_id: Option<String>,
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
@ -392,13 +43,6 @@ impl Default for LanguageModelTextStream {
}
}
/// One selectable effort level, described by a name, a value and a default flag.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
/// NOTE(review): presumably the label shown to the user — confirm.
pub name: SharedString,
/// NOTE(review): presumably the value sent to the provider — confirm.
pub value: SharedString,
/// Whether this level is the default one.
pub is_default: bool,
}
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
@ -605,7 +249,7 @@ pub trait LanguageModel: Send + Sync {
}
impl std::fmt::Debug for dyn LanguageModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("<dyn LanguageModel>")
.field("id", &self.id())
.field("name", &self.name())
@ -619,17 +263,6 @@ impl std::fmt::Debug for dyn LanguageModel {
}
}
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
/// Displayed as `connection refused`.
#[error("connection refused")]
ConnectionRefused,
/// Displayed as `credentials not found`.
#[error("credentials not found")]
CredentialsNotFound,
/// Any other failure. Converts from `anyhow::Error` via `#[from]` and
/// displays the wrapped error transparently.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
@ -692,18 +325,6 @@ pub trait LanguageModelProviderState: 'static {
}
}
/// Newtype over `SharedString` identifying a language model; supports
/// ordering, hashing and serde (de)serialization.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
/// Newtype over `SharedString` holding a language model's name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
/// Newtype over `SharedString` identifying a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
/// Newtype over `SharedString` holding a language model provider's name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
/// Cost per 1,000 input and output tokens
@ -741,245 +362,3 @@ impl LanguageModelCostInfo {
}
}
}
impl LanguageModelProviderId {
    /// Creates a provider id from a `'static` string; callable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        let inner = SharedString::new_static(id);
        Self(inner)
    }
}
impl LanguageModelProviderName {
    /// Creates a provider name from a `'static` string; callable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        let inner = SharedString::new_static(id);
        Self(inner)
    }
}
impl fmt::Display for LanguageModelProviderId {
    /// Formats the id by displaying the wrapped string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.0))
    }
}
impl fmt::Display for LanguageModelProviderName {
    /// Formats the name by displaying the wrapped string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.0))
    }
}
impl From<String> for LanguageModelId {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelName {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelProviderId {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelProviderName {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<Arc<str>> for LanguageModelProviderId {
fn from(value: Arc<str>) -> Self {
Self(SharedString::from(value))
}
}
impl From<Arc<str>> for LanguageModelProviderName {
fn from(value: Arc<str>) -> Self {
Self(SharedString::from(value))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_cloud_failure_with_upstream_http_error() {
    // `upstream_http_error` with a JSON body whose `upstream_status` is 503
    // is expected to come back as `ServerOverloaded` for the upstream provider.
    let overloaded = LanguageModelCompletionError::from_cloud_failure(
        String::from("anthropic").into(),
        "upstream_http_error".to_string(),
        r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
        None,
    );
    if let LanguageModelCompletionError::ServerOverloaded { provider, .. } = &overloaded {
        assert_eq!(provider.0, "anthropic");
    } else {
        panic!(
            "Expected ServerOverloaded error for 503 status, got: {:?}",
            overloaded
        );
    }

    // The same code with `upstream_status` 500 is expected to come back as
    // `ApiInternalServerError` carrying the inner message.
    let internal = LanguageModelCompletionError::from_cloud_failure(
        String::from("anthropic").into(),
        "upstream_http_error".to_string(),
        r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
        None,
    );
    if let LanguageModelCompletionError::ApiInternalServerError { provider, message } = &internal {
        assert_eq!(provider.0, "anthropic");
        assert_eq!(message, "Internal server error");
    } else {
        panic!(
            "Expected ApiInternalServerError for 500 status, got: {:?}",
            internal
        );
    }
}
#[test]
fn test_from_cloud_failure_with_standard_format() {
    // The status is encoded directly in the code (`upstream_http_503`).
    let failure = LanguageModelCompletionError::from_cloud_failure(
        String::from("anthropic").into(),
        "upstream_http_503".to_string(),
        "Service unavailable".to_string(),
        None,
    );

    if let LanguageModelCompletionError::ServerOverloaded { provider, .. } = &failure {
        assert_eq!(provider.0, "anthropic");
    } else {
        panic!("Expected ServerOverloaded error for upstream_http_503");
    }
}
#[test]
fn test_upstream_http_error_connection_timeout() {
    // Connection-timeout message reported with upstream status 503.
    let overloaded = LanguageModelCompletionError::from_cloud_failure(
        String::from("anthropic").into(),
        "upstream_http_error".to_string(),
        r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
        None,
    );
    if let LanguageModelCompletionError::ServerOverloaded { provider, .. } = &overloaded {
        assert_eq!(provider.0, "anthropic");
    } else {
        panic!(
            "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
            overloaded
        );
    }

    // Same message reported with upstream status 500: the message must be
    // carried through unchanged.
    let internal = LanguageModelCompletionError::from_cloud_failure(
        String::from("anthropic").into(),
        "upstream_http_error".to_string(),
        r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
        None,
    );
    if let LanguageModelCompletionError::ApiInternalServerError { provider, message } = &internal {
        assert_eq!(provider.0, "anthropic");
        assert_eq!(
            message,
            "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
        );
    } else {
        panic!(
            "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
            internal
        );
    }
}
#[test]
fn test_language_model_tool_use_serializes_with_signature() {
    use serde_json::json;

    // The same JSON value feeds both the raw string and the parsed input.
    let input = json!({"arg": "value"});
    let tool_use = LanguageModelToolUse {
        id: LanguageModelToolUseId::from("test_id"),
        name: "test_tool".into(),
        raw_input: input.to_string(),
        input,
        is_input_complete: true,
        thought_signature: Some("test_signature".to_string()),
    };

    let value = serde_json::to_value(&tool_use).unwrap();
    assert_eq!(value["id"], "test_id");
    assert_eq!(value["name"], "test_tool");
    assert_eq!(value["thought_signature"], "test_signature");
}
#[test]
fn test_language_model_tool_use_deserializes_with_missing_signature() {
    use serde_json::json;

    // No `thought_signature` key: deserialization must still succeed.
    let payload = json!({
        "id": "test_id",
        "name": "test_tool",
        "raw_input": "{\"arg\":\"value\"}",
        "input": {"arg": "value"},
        "is_input_complete": true
    });

    let parsed: LanguageModelToolUse = serde_json::from_value(payload).unwrap();
    assert_eq!(parsed.id, LanguageModelToolUseId::from("test_id"));
    assert_eq!(parsed.name.as_ref(), "test_tool");
    assert_eq!(parsed.thought_signature, None);
}
#[test]
fn test_language_model_tool_use_round_trip_with_signature() {
    use serde_json::json;

    let input = json!({"key": "value"});
    let before = LanguageModelToolUse {
        id: LanguageModelToolUseId::from("round_trip_id"),
        name: "round_trip_tool".into(),
        raw_input: input.to_string(),
        input,
        is_input_complete: true,
        thought_signature: Some("round_trip_sig".to_string()),
    };

    // Serialize to a JSON value and read it straight back.
    let after: LanguageModelToolUse =
        serde_json::from_value(serde_json::to_value(&before).unwrap()).unwrap();

    assert_eq!(after.id, before.id);
    assert_eq!(after.name, before.name);
    assert_eq!(after.thought_signature, before.thought_signature);
}
#[test]
fn test_language_model_tool_use_round_trip_without_signature() {
    use serde_json::json;

    let input = json!({"arg": "value"});
    let before = LanguageModelToolUse {
        id: LanguageModelToolUseId::from("no_sig_id"),
        name: "no_sig_tool".into(),
        raw_input: input.to_string(),
        input,
        is_input_complete: true,
        thought_signature: None,
    };

    // Serialize to a JSON value and read it straight back.
    let after: LanguageModelToolUse =
        serde_json::from_value(serde_json::to_value(&before).unwrap()).unwrap();

    assert_eq!(after.id, before.id);
    assert_eq!(after.name, before.name);
    assert_eq!(after.thought_signature, None);
}
}

View file

@ -1,10 +1,5 @@
use std::fmt;
use std::sync::Arc;
use cloud_api_client::ClientApiError;
use cloud_api_client::CloudApiClient;
use cloud_api_types::OrganizationId;
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
#[derive(Error, Debug)]
@ -18,71 +13,3 @@ impl fmt::Display for PaymentRequiredError {
)
}
}
/// Handle to an optional token string stored behind an async `RwLock`.
///
/// The slot lives in an `Arc`, so clones of this handle share the same token.
/// The default value holds no token.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
    /// Returns the stored token, fetching a new one only when none is stored.
    ///
    /// Takes an upgradable read lock first and upgrades it to a write lock
    /// only when a fetch is required.
    pub async fn acquire(
        &self,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        let guard = self.0.upgradable_read().await;
        if let Some(cached) = guard.as_ref() {
            return Ok(cached.to_string());
        }

        let writer = RwLockUpgradableReadGuard::upgrade(guard).await;
        Self::fetch(writer, client, system_id, organization_id).await
    }

    /// Fetches a new token under the write lock, replacing whatever is stored.
    pub async fn refresh(
        &self,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        let writer = self.0.write().await;
        Self::fetch(writer, client, system_id, organization_id).await
    }

    /// Drops the stored token first, then tries to fetch a replacement.
    ///
    /// Intended for organization switches: if the fetch fails, no token that
    /// belongs to the previous organization is left behind.
    pub async fn clear_and_refresh(
        &self,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        let mut writer = self.0.write().await;
        *writer = None;
        Self::fetch(writer, client, system_id, organization_id).await
    }

    /// Requests a token from `client` while holding the write guard.
    ///
    /// Stores the token on success and clears the slot on failure, then
    /// hands the outcome back to the caller.
    async fn fetch(
        mut lock: RwLockWriteGuard<'_, Option<String>>,
        client: &CloudApiClient,
        system_id: Option<String>,
        organization_id: Option<OrganizationId>,
    ) -> Result<String, ClientApiError> {
        let outcome = client.create_llm_token(system_id, organization_id).await;
        // `Some(token)` on success, `None` on failure.
        *lock = outcome
            .as_ref()
            .ok()
            .map(|response| response.token.0.clone());
        outcome.map(|response| response.token.0)
    }
}

View file

@ -1,12 +0,0 @@
pub mod anthropic;
pub mod google;
pub mod open_ai;
pub mod open_router;
pub mod x_ai;
pub mod zed;
pub use anthropic::*;
pub use google::*;
pub use open_ai::*;
pub use x_ai::*;
pub use zed::*;

View file

@ -1,80 +0,0 @@
use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
use anthropic::AnthropicError;
pub use anthropic::parse_prompt_too_long;
/// Provider id for Anthropic (`anthropic`).
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
LanguageModelProviderId::new("anthropic");
/// Provider name for Anthropic (`Anthropic`).
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Anthropic");
impl From<AnthropicError> for LanguageModelCompletionError {
    /// Converts an Anthropic client error into a completion error attributed
    /// to the Anthropic provider. API-level errors are delegated to the
    /// `anthropic::ApiError` conversion.
    fn from(anthropic_error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match anthropic_error {
            AnthropicError::ApiError(api_error) => Self::from(api_error),
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::ReadResponse(source) => Self::ApiReadResponseError {
                provider,
                error: source,
            },
            AnthropicError::DeserializeResponse(source) => Self::DeserializeResponse {
                provider,
                error: source,
            },
            AnthropicError::HttpSend(source) => Self::HttpSend {
                provider,
                error: source,
            },
            AnthropicError::BuildRequestBody(source) => Self::BuildRequestBody {
                provider,
                error: source,
            },
            AnthropicError::SerializeRequest(source) => Self::SerializeRequest {
                provider,
                error: source,
            },
        }
    }
}
impl From<anthropic::ApiError> for LanguageModelCompletionError {
    /// Maps an Anthropic API error to a completion error by its error code.
    /// Errors whose code is not recognized (`code()` returns `None`) are
    /// wrapped in `Other`.
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;

        let provider = ANTHROPIC_PROVIDER_NAME;
        let Some(code) = error.code() else {
            return Self::Other(error.into());
        };

        match code {
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            RequestTooLarge => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&error.message),
            },
            NotFoundError => Self::ApiEndpointNotFound { provider },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
        }
    }
}

View file

@ -1,5 +0,0 @@
use crate::{LanguageModelProviderId, LanguageModelProviderName};
/// Provider id for Google (`google`).
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// Provider name for Google (`Google AI`).
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Google AI");

View file

@ -1,28 +0,0 @@
use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
use http_client::http;
use std::time::Duration;
/// Provider id for OpenAI (`openai`).
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// Provider name for OpenAI (`OpenAI`).
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("OpenAI");
impl From<open_ai::RequestError> for LanguageModelCompletionError {
    /// Converts an OpenAI request error into a completion error.
    ///
    /// HTTP failures go through `from_http_status`; a `Retry-After` header is
    /// forwarded only when it parses as a whole number of seconds. Any other
    /// error is wrapped in `Other`.
    fn from(error: open_ai::RequestError) -> Self {
        let (provider, status_code, body, headers) = match error {
            open_ai::RequestError::Other(source) => return Self::Other(source),
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => (provider, status_code, body, headers),
        };

        let retry_after = headers
            .get(http::header::RETRY_AFTER)
            .and_then(|value| value.to_str().ok())
            .and_then(|seconds| seconds.parse::<u64>().ok())
            .map(Duration::from_secs);

        Self::from_http_status(provider.into(), status_code, body, retry_after)
    }
}

View file

@ -1,69 +0,0 @@
use crate::{LanguageModelCompletionError, LanguageModelProviderName};
use http_client::StatusCode;
use open_router::OpenRouterError;
impl From<OpenRouterError> for LanguageModelCompletionError {
    /// Converts an OpenRouter client error into a completion error attributed
    /// to the `OpenRouter` provider. API-level errors are delegated to the
    /// `open_router::ApiError` conversion.
    fn from(open_router_error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match open_router_error {
            OpenRouterError::ApiError(api_error) => Self::from(api_error),
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ReadResponse(source) => Self::ApiReadResponseError {
                provider,
                error: source,
            },
            OpenRouterError::DeserializeResponse(source) => Self::DeserializeResponse {
                provider,
                error: source,
            },
            OpenRouterError::HttpSend(source) => Self::HttpSend {
                provider,
                error: source,
            },
            OpenRouterError::BuildRequestBody(source) => Self::BuildRequestBody {
                provider,
                error: source,
            },
            OpenRouterError::SerializeRequest(source) => Self::SerializeRequest {
                provider,
                error: source,
            },
        }
    }
}
impl From<open_router::ApiError> for LanguageModelCompletionError {
fn from(error: open_router::ApiError) -> Self {
use open_router::ApiErrorCode::*;
let provider = LanguageModelProviderName::new("OpenRouter");
match error.code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PaymentRequiredError => Self::AuthenticationError {
provider,
message: format!("Payment required: {}", error.message),
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
RequestTimedOut => Self::HttpResponseError {
provider,
status_code: StatusCode::REQUEST_TIMEOUT,
message: error.message,
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
}
}
}

View file

@ -1,4 +0,0 @@
use crate::{LanguageModelProviderId, LanguageModelProviderName};
/// Provider id for xAI (`x_ai`).
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
/// Provider name for xAI (`xAI`).
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

View file

@ -1,5 +0,0 @@
use crate::{LanguageModelProviderId, LanguageModelProviderName};
/// Provider id for the Zed cloud provider (`zed.dev`).
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
/// Provider name for the Zed cloud provider (`Zed`).
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Zed");

View file

@ -1,6 +1,6 @@
use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
};
use collections::{BTreeMap, HashSet};
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
@ -101,7 +101,7 @@ impl ConfiguredModel {
}
pub fn is_provided_by_zed(&self) -> bool {
self.provider.id() == crate::provider::ZED_CLOUD_PROVIDER_ID
self.provider.id() == ZED_CLOUD_PROVIDER_ID
}
}

View file

@ -4,78 +4,13 @@ use std::sync::Arc;
use anyhow::Result;
use base64::write::EncoderWriter;
use gpui::{
App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
point, px, size,
App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, Size, Task, point, px, size,
};
use image::GenericImageView as _;
use image::codecs::png::PngEncoder;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
/// An image attached to a language model request: base64 source plus an
/// optional pixel size.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
/// A base64-encoded PNG image.
pub source: SharedString,
/// Image dimensions. Defaults to `None` when absent on deserialization and
/// is omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub size: Option<Size<DevicePixels>>,
}
impl LanguageModelImage {
    /// Length of the base64 `source` string.
    pub fn len(&self) -> usize {
        self.source.len()
    }

    /// Whether the base64 `source` string is empty.
    pub fn is_empty(&self) -> bool {
        self.source.is_empty()
    }

    /// Parses an image from a JSON object, matching field names
    /// case-insensitively.
    ///
    /// Expects a string `source` and an object `size` holding integer `width`
    /// and `height`. Returns `None` when any of these is missing, has the
    /// wrong JSON type, or when a dimension does not fit in an `i32`. The
    /// range check replaces a truncating `as i32` cast that silently turned
    /// oversized values into unrelated dimensions.
    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
        let mut source = None;
        let mut size_obj = None;

        // Find source and size fields (case-insensitive); the last match wins.
        for (k, v) in obj.iter() {
            match k.to_lowercase().as_str() {
                "source" => source = v.as_str(),
                "size" => size_obj = v.as_object(),
                _ => {}
            }
        }

        let source = source?;
        let size_obj = size_obj?;

        let mut width = None;
        let mut height = None;

        // Find width and height in the size object (case-insensitive),
        // rejecting values outside the `i32` range instead of truncating.
        for (k, v) in size_obj.iter() {
            match k.to_lowercase().as_str() {
                "width" => width = v.as_i64().and_then(|w| i32::try_from(w).ok()),
                "height" => height = v.as_i64().and_then(|h| i32::try_from(h).ok()),
                _ => {}
            }
        }

        Some(Self {
            size: Some(size(DevicePixels(width?), DevicePixels(height?))),
            source: SharedString::from(source.to_string()),
        })
    }
}
impl std::fmt::Debug for LanguageModelImage {
    /// Prints the source as a byte count instead of the full base64 payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let source_summary = format!("<{} bytes>", self.source.len());
        let mut debug = f.debug_struct("LanguageModelImage");
        debug.field("source", &source_summary);
        debug.field("size", &self.size);
        debug.finish()
    }
}
use language_model_core::{ImageSize, LanguageModelImage};
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;
@ -90,18 +25,16 @@ const DEFAULT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
/// `DEFAULT_IMAGE_MAX_BYTES`.
const MAX_IMAGE_DOWNSCALE_PASSES: usize = 8;
impl LanguageModelImage {
// All language model images are encoded as PNGs.
pub const FORMAT: ImageFormat = ImageFormat::Png;
/// Extension trait for `LanguageModelImage` that provides GPUI-dependent functionality.
pub trait LanguageModelImageExt {
/// Image format associated with the implementing type.
const FORMAT: ImageFormat;
/// Converts a gpui `Image` into a `LanguageModelImage`, returning a task
/// that resolves to `None` when no image could be produced.
fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>>;
}
/// Creates an image with an empty source string and no size.
pub fn empty() -> Self {
    let source = "".into();
    Self { source, size: None }
}
impl LanguageModelImageExt for LanguageModelImage {
const FORMAT: ImageFormat = ImageFormat::Png;
pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>> {
cx.background_spawn(async move {
let image_bytes = Cursor::new(data.bytes());
let dynamic_image = match data.format() {
@ -186,28 +119,14 @@ impl LanguageModelImage {
let source = unsafe { String::from_utf8_unchecked(base64_image) };
Some(LanguageModelImage {
size: Some(image_size),
size: Some(ImageSize {
width: width as i32,
height: height as i32,
}),
source: source.into(),
})
})
}
/// Roughly estimates the token cost of the image from its pixel area.
///
/// Returns 0 when the size is unknown.
pub fn estimate_tokens(&self) -> usize {
    match self.size.as_ref() {
        None => 0,
        Some(dimensions) => {
            let width = dimensions.width.0.unsigned_abs() as usize;
            let height = dimensions.height.0.unsigned_abs() as usize;

            // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
            // Anthropic's API applies many additional conditions, and OpenAI
            // doesn't use this formula, so treat the result as a rough guess.
            (width * height) / 750
        }
    }
}
/// Builds a `data:` URL that embeds the base64 PNG source.
pub fn to_base64_url(&self) -> String {
    let encoded: &str = self.source.as_ref();
    format!("data:image/png;base64,{}", encoded)
}
}
fn encode_png_bytes(image: &image::DynamicImage) -> Result<Vec<u8>> {
@ -228,280 +147,20 @@ fn encode_bytes_as_base64(bytes: &[u8]) -> Result<Vec<u8>> {
Ok(base64_image)
}
/// Result of a tool invocation, tied to the tool use that produced it.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
/// Id of the tool use this result belongs to.
pub tool_use_id: LanguageModelToolUseId,
/// Name of the tool.
pub tool_name: Arc<str>,
/// Whether the result is an error.
pub is_error: bool,
/// The tool output formatted for presenting to the model
pub content: LanguageModelToolResultContent,
/// The raw tool output, if available, often for debugging or extra state for replay
pub output: Option<serde_json::Value>,
}
/// Content of a tool result: either text or an image.
///
/// Only `Serialize` is derived here; `Deserialize` is implemented by hand.
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
/// Textual output.
Text(Arc<str>),
/// Image output.
Image(LanguageModelImage),
}
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let value = serde_json::Value::deserialize(deserializer)?;
// Models can provide these responses in several styles. Try each in order.
// 1. Try as plain string
if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
return Ok(Self::Text(Arc::from(text)));
}
// 2. Try as object
if let Some(obj) = value.as_object() {
// get a JSON field case-insensitively
fn get_field<'a>(
obj: &'a serde_json::Map<String, serde_json::Value>,
field: &str,
) -> Option<&'a serde_json::Value> {
obj.iter()
.find(|(k, _)| k.to_lowercase() == field.to_lowercase())
.map(|(_, v)| v)
}
// Accept wrapped text format: { "type": "text", "text": "..." }
if let (Some(type_value), Some(text_value)) =
(get_field(obj, "type"), get_field(obj, "text"))
&& let Some(type_str) = type_value.as_str()
&& type_str.to_lowercase() == "text"
&& let Some(text) = text_value.as_str()
{
return Ok(Self::Text(Arc::from(text)));
}
// Check for wrapped Text variant: { "text": "..." }
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
&& obj.len() == 1
{
// Only one field, and it's "text" (case-insensitive)
if let Some(text) = value.as_str() {
return Ok(Self::Text(Arc::from(text)));
}
}
// Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
&& obj.len() == 1
{
// Only one field, and it's "image" (case-insensitive)
// Try to parse the nested image object
if let Some(image_obj) = value.as_object()
&& let Some(image) = LanguageModelImage::from_json(image_obj)
{
return Ok(Self::Image(image));
}
}
// Try as direct Image (object with "source" and "size" fields)
if let Some(image) = LanguageModelImage::from_json(obj) {
return Ok(Self::Image(image));
}
}
// If none of the variants match, return an error with the problematic JSON
Err(D::Error::custom(format!(
"data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
)))
/// Convert a core `ImageSize` to a gpui `Size<DevicePixels>`.
pub fn image_size_to_gpui(size: ImageSize) -> Size<DevicePixels> {
Size {
width: DevicePixels(size.width),
height: DevicePixels(size.height),
}
}
impl LanguageModelToolResultContent {
pub fn to_str(&self) -> Option<&str> {
match self {
Self::Text(text) => Some(text),
Self::Image(_) => None,
}
/// Convert a gpui `Size<DevicePixels>` to a core `ImageSize`.
pub fn gpui_size_to_image_size(size: Size<DevicePixels>) -> ImageSize {
ImageSize {
width: size.width.0,
height: size.height.0,
}
/// Returns true only for text content made up entirely of whitespace
/// (including the empty string); images are never considered empty.
pub fn is_empty(&self) -> bool {
    let Self::Text(text) = self else {
        return false;
    };
    !text.chars().any(|c| !c.is_whitespace())
}
}
impl From<&str> for LanguageModelToolResultContent {
    /// Copies the string slice into text content.
    fn from(value: &str) -> Self {
        let text: Arc<str> = value.into();
        Self::Text(text)
    }
}
impl From<String> for LanguageModelToolResultContent {
    /// Converts an owned string into text content.
    fn from(value: String) -> Self {
        let text: Arc<str> = value.into();
        Self::Text(text)
    }
}
impl From<LanguageModelImage> for LanguageModelToolResultContent {
fn from(image: LanguageModelImage) -> Self {
Self::Image(image)
}
}
/// One piece of content inside a request message.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
/// Plain text.
Text(String),
/// Thinking text with an optional signature.
Thinking {
text: String,
signature: Option<String>,
},
/// NOTE(review): presumably opaque provider data standing in for hidden thinking — confirm.
RedactedThinking(String),
/// An image.
Image(LanguageModelImage),
/// A tool invocation.
ToolUse(LanguageModelToolUse),
/// The result of a tool invocation.
ToolResult(LanguageModelToolResult),
}
impl MessageContent {
pub fn to_str(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(text.as_str()),
MessageContent::Thinking { text, .. } => Some(text.as_str()),
MessageContent::RedactedThinking(_) => None,
MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
}
}
pub fn is_empty(&self) -> bool {
match self {
MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
MessageContent::RedactedThinking(_)
| MessageContent::ToolUse(_)
| MessageContent::Image(_) => false,
}
}
}
impl From<String> for MessageContent {
fn from(value: String) -> Self {
MessageContent::Text(value)
}
}
impl From<&str> for MessageContent {
fn from(value: &str) -> Self {
MessageContent::Text(value.to_string())
}
}
/// A single message of a request: a role plus an ordered list of content parts.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
/// Role of the message author.
pub role: Role,
/// Content parts, in order.
pub content: Vec<MessageContent>,
/// NOTE(review): presumably marks the message as a prompt-cache breakpoint — confirm.
pub cache: bool,
/// Optional JSON payload; defaults to `None` when absent and is omitted
/// from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning_details: Option<serde_json::Value>,
}
impl LanguageModelRequestMessage {
    /// Concatenates the textual form of every content part, skipping parts
    /// that have none. No separator is inserted between parts.
    pub fn string_contents(&self) -> String {
        self.content
            .iter()
            .filter_map(|part| part.to_str())
            .collect::<String>()
    }

    /// Returns true when every content part is empty (also true when there
    /// are no parts at all).
    pub fn contents_empty(&self) -> bool {
        !self.content.iter().any(|part| !part.is_empty())
    }
}
/// Description of a tool offered to the model in a request.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
/// Tool name.
pub name: String,
/// Tool description.
pub description: String,
/// Schema of the tool input, as a JSON value.
pub input_schema: serde_json::Value,
/// NOTE(review): presumably requests that the tool input be streamed incrementally — confirm.
pub use_input_streaming: bool,
}
/// Tool-choice mode for a request.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
// NOTE(review): presumably the model decides whether to call a tool — confirm.
Auto,
// NOTE(review): presumably the model must call some tool — confirm.
Any,
// NOTE(review): presumably the model must not call tools — confirm.
None,
}
/// Why a completion is being requested.
///
/// Variants are (de)serialized using their `snake_case` names (for example
/// `user_prompt`). The derived ordering follows declaration order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
UserPrompt,
Subagent,
ToolResults,
ThreadSummarization,
ThreadContextSummarization,
CreateFile,
EditFile,
InlineAssist,
TerminalInlineAssist,
GenerateGitCommitMessage,
}
/// A completion request: messages, tools and generation settings.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
/// Optional thread identifier.
pub thread_id: Option<String>,
/// Optional prompt identifier.
pub prompt_id: Option<String>,
/// Optional reason the completion is requested.
pub intent: Option<CompletionIntent>,
/// Messages, in order.
pub messages: Vec<LanguageModelRequestMessage>,
/// Tools offered to the model.
pub tools: Vec<LanguageModelRequestTool>,
/// Optional tool-choice mode.
pub tool_choice: Option<LanguageModelToolChoice>,
/// NOTE(review): presumably stop sequences — confirm.
pub stop: Vec<String>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Whether thinking is allowed for this request.
pub thinking_allowed: bool,
/// Optional thinking effort, as a string.
pub thinking_effort: Option<String>,
/// Optional speed setting.
pub speed: Option<Speed>,
}
/// Speed setting with two values; defaults to `Standard`.
///
/// Variants are (de)serialized using their `snake_case` names.
#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Speed {
/// The default speed.
#[default]
Standard,
/// The fast speed.
Fast,
}
impl Speed {
    /// Returns the other speed: `Standard` becomes `Fast` and vice versa.
    pub fn toggle(self) -> Self {
        if matches!(self, Speed::Fast) {
            Speed::Standard
        } else {
            Speed::Fast
        }
    }
}
impl From<Speed> for anthropic::Speed {
    /// Maps each speed to the Anthropic variant of the same name.
    fn from(speed: Speed) -> Self {
        match speed {
            Speed::Fast => Self::Fast,
            Speed::Standard => Self::Standard,
        }
    }
}
/// A response message whose role and content are both optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
/// Role of the message author, if known.
pub role: Option<Role>,
/// Text content, if any.
pub content: Option<String>,
}
#[cfg(test)]
@ -509,231 +168,64 @@ mod tests {
use super::*;
use base64::Engine as _;
use gpui::TestAppContext;
use image::ImageDecoder as _;
fn base64_to_png_bytes(base64_png: &str) -> Vec<u8> {
fn base64_to_png_bytes(base64: &str) -> Vec<u8> {
base64::engine::general_purpose::STANDARD
.decode(base64_png.as_bytes())
.expect("base64 should decode")
.decode(base64)
.expect("valid base64")
}
fn png_dimensions(png_bytes: &[u8]) -> (u32, u32) {
let decoder =
image::codecs::png::PngDecoder::new(Cursor::new(png_bytes)).expect("png should decode");
decoder.dimensions()
let img = image::load_from_memory(png_bytes).expect("valid png");
(img.width(), img.height())
}
fn make_noisy_png_bytes(width: u32, height: u32) -> Vec<u8> {
// Create an RGBA image with per-pixel variance to avoid PNG compressing too well.
let mut img = image::RgbaImage::new(width, height);
for y in 0..height {
for x in 0..width {
let r = ((x ^ y) & 0xFF) as u8;
let g = ((x.wrapping_mul(31) ^ y.wrapping_mul(17)) & 0xFF) as u8;
let b = ((x.wrapping_mul(131) ^ y.wrapping_mul(7)) & 0xFF) as u8;
img.put_pixel(x, y, image::Rgba([r, g, b, 0xFF]));
}
}
use image::{ImageBuffer, Rgba};
use std::hash::{Hash, Hasher};
let mut out = Vec::new();
image::DynamicImage::ImageRgba8(img)
.write_with_encoder(PngEncoder::new(&mut out))
.expect("png encoding should succeed");
out
let img = ImageBuffer::from_fn(width, height, |x, y| {
let mut hasher = std::hash::DefaultHasher::new();
(x, y, width, height).hash(&mut hasher);
let h = hasher.finish();
Rgba([h as u8, (h >> 8) as u8, (h >> 16) as u8, 255])
});
let mut buf = Cursor::new(Vec::new());
img.write_with_encoder(PngEncoder::new(&mut buf))
.expect("encode");
buf.into_inner()
}
#[gpui::test]
async fn test_from_image_downscales_to_default_5mb_limit(cx: &mut TestAppContext) {
// Pick a size that reliably produces a PNG > 5MB when filled with noise.
// If this fails (image is too small), bump dimensions.
let original_png = make_noisy_png_bytes(4096, 4096);
let raw_png = make_noisy_png_bytes(4096, 4096);
assert!(
original_png.len() > DEFAULT_IMAGE_MAX_BYTES,
"precondition failed: noisy PNG must exceed DEFAULT_IMAGE_MAX_BYTES"
raw_png.len() > DEFAULT_IMAGE_MAX_BYTES,
"Test image should exceed the 5 MB limit (actual: {} bytes)",
raw_png.len()
);
let image = gpui::Image::from_bytes(ImageFormat::Png, original_png);
let image = Arc::new(gpui::Image::from_bytes(ImageFormat::Png, raw_png));
let lm_image = cx
.update(|cx| LanguageModelImage::from_image(Arc::new(image), cx))
.update(|cx| LanguageModelImage::from_image(Arc::clone(&image), cx))
.await
.expect("image conversion should succeed");
.expect("from_image should succeed");
let encoded_png = base64_to_png_bytes(lm_image.source.as_ref());
let decoded_png = base64_to_png_bytes(lm_image.source.as_ref());
assert!(
encoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES,
"expected encoded PNG <= DEFAULT_IMAGE_MAX_BYTES, got {} bytes",
encoded_png.len()
decoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES,
"Encoded PNG should be ≤ {} bytes after downscale, but was {} bytes",
DEFAULT_IMAGE_MAX_BYTES,
decoded_png.len()
);
// Ensure we actually downscaled in pixels (not just re-encoded).
let (w, h) = png_dimensions(&encoded_png);
let (w, h) = png_dimensions(&decoded_png);
assert!(
w < 4096 || h < 4096,
"expected image to be downscaled in at least one dimension; got {w}x{h}"
w < 4096 && h < 4096,
"Dimensions should have shrunk: got {}×{}",
w,
h
);
}
// Walks through the JSON shapes that should deserialize into
// `LanguageModelToolResultContent` (plain strings, typed/wrapped text objects,
// direct and wrapped image objects, all with case-insensitive keys) and the
// shapes that must be rejected.
#[test]
fn test_language_model_tool_result_content_deserialization() {
// Plain JSON string.
let json = r#""This is plain text""#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("This is plain text".into())
);
// Object tagged with `"type": "text"`.
let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("This is wrapped text".into())
);
// Keys and the type value are matched without regard to case.
let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Case insensitive".into())
);
// Single-key wrapped variant.
let json = r#"{"Text": "Wrapped variant"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Wrapped variant".into())
);
let json = r#"{"text": "Lowercase wrapped"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Lowercase wrapped".into())
);
// Test image deserialization
let json = r#"{
"source": "base64encodedimagedata",
"size": {
"width": 100,
"height": 200
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "base64encodedimagedata");
let size = image.size.expect("size");
assert_eq!(size.width.0, 100);
assert_eq!(size.height.0, 200);
}
_ => panic!("Expected Image variant"),
}
// Test wrapped Image variant
let json = r#"{
"Image": {
"source": "wrappedimagedata",
"size": {
"width": 50,
"height": 75
}
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "wrappedimagedata");
let size = image.size.expect("size");
assert_eq!(size.width.0, 50);
assert_eq!(size.height.0, 75);
}
_ => panic!("Expected Image variant"),
}
// Test wrapped Image variant with case insensitive
let json = r#"{
"image": {
"Source": "caseinsensitive",
"SIZE": {
"width": 30,
"height": 40
}
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "caseinsensitive");
let size = image.size.expect("size");
assert_eq!(size.width.0, 30);
assert_eq!(size.height.0, 40);
}
_ => panic!("Expected Image variant"),
}
// Test that wrapped text with wrong type fails
let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
// Test that malformed JSON fails
let json = r#"{"invalid": "structure"}"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
// Test edge cases
let json = r#""""#; // Empty string
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(result, LanguageModelToolResultContent::Text("".into()));
// Test with extra fields in wrapped text (should be ignored)
let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));
// Test direct image with case-insensitive fields
let json = r#"{
"SOURCE": "directimage",
"Size": {
"width": 200,
"height": 300
}
}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
match result {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "directimage");
let size = image.size.expect("size");
assert_eq!(size.width.0, 200);
assert_eq!(size.height.0, 300);
}
_ => panic!("Expected Image variant"),
}
// Test that multiple fields prevent wrapped variant interpretation
let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
// Test wrapped text with uppercase TEXT variant
let json = r#"{"TEXT": "Uppercase variant"}"#;
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
assert_eq!(
result,
LanguageModelToolResultContent::Text("Uppercase variant".into())
);
// Test that numbers and other JSON values fail gracefully
let json = r#"123"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
let json = r#"null"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
let json = r#"[1, 2, 3]"#;
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
assert!(result.is_err());
}
}

View file

@ -0,0 +1,27 @@
[package]
name = "language_model_core"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/language_model_core.rs"
doctest = false
[dependencies]
anyhow.workspace = true
cloud_llm_client.workspace = true
futures.workspace = true
gpui_shared_string.workspace = true
http_client.workspace = true
partial-json-fixer.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
strum.workspace = true
thiserror.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,658 @@
mod provider;
mod rate_limiter;
mod request;
mod role;
pub mod tool_schema;
pub mod util;
use anyhow::{Result, anyhow};
use cloud_llm_client::CompletionRequestStatus;
use http_client::{StatusCode, http};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
/// Reports whether `value` equals its type's `Default` value.
///
/// Referenced by name from `#[serde(skip_serializing_if = "is_default")]` attributes.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    value == &default_value
}
pub use crate::provider::*;
pub use crate::rate_limiter::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub use crate::util::{fix_streamed_json, parse_prompt_too_long, parse_tool_arguments};
pub use gpui_shared_string::SharedString;
/// Cache-related configuration values for a language model.
// NOTE(review): field semantics are not visible in this file; the names suggest
// a cap on cache anchors, a speculation toggle, and a minimum token threshold —
// confirm against the code that reads them.
#[derive(Clone, Debug)]
pub struct LanguageModelCacheConfiguration {
pub max_cache_anchors: usize,
pub should_speculate: bool,
pub min_total_token: u64,
}
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
/// The request is queued; produced from `CompletionRequestStatus::Queued`.
Queued {
position: usize,
},
/// The request has started; produced from `CompletionRequestStatus::Started`.
Started,
/// The completion stopped, with the reason why.
Stop(StopReason),
/// A piece of response text.
Text(String),
/// Thinking text, optionally accompanied by a signature.
Thinking {
text: String,
signature: Option<String>,
},
/// Thinking content delivered only as opaque data.
RedactedThinking {
data: String,
},
/// A tool invocation requested by the model.
ToolUse(LanguageModelToolUse),
/// A tool invocation whose raw input failed to parse as JSON.
ToolUseJsonParseError {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
/// Start of a message, carrying its id.
StartMessage {
message_id: String,
},
/// Reasoning details as an arbitrary JSON value.
ReasoningDetails(serde_json::Value),
/// Updated token usage counters.
UsageUpdate(TokenUsage),
}
impl LanguageModelCompletionEvent {
    /// Translates a cloud [`CompletionRequestStatus`] into a completion event.
    ///
    /// * `Queued` and `Started` map to the same-named events.
    /// * `Unknown` and `StreamEnded` produce `Ok(None)` (nothing to emit).
    /// * `Failed` produces an error built by
    ///   [`LanguageModelCompletionError::from_cloud_failure`], attributed to
    ///   `upstream_provider` where applicable.
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Option<Self>, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
            }
            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => {
                // `retry_after` comes from the server. `Duration::from_secs_f64` panics on
                // negative, non-finite, or overflowing input, so use the fallible
                // conversion and treat an unusable value as "no retry hint".
                let retry_after =
                    retry_after.and_then(|secs| Duration::try_from_secs_f64(secs).ok());
                Err(LanguageModelCompletionError::from_cloud_failure(
                    upstream_provider,
                    code,
                    message,
                    retry_after,
                ))
            }
        }
    }
}
/// Errors that can occur while requesting a completion from a language model provider.
///
/// Most variants record the `provider` the error is attributed to; the
/// `#[error(...)]` attribute on each variant defines its display message.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> },
#[error("missing {provider} API key")]
NoApiKey { provider: LanguageModelProviderName },
#[error("{provider}'s API rate limit exceeded")]
RateLimitExceeded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API servers are overloaded right now")]
ServerOverloaded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API server reported an internal server error: {message}")]
ApiInternalServerError {
provider: LanguageModelProviderName,
message: String,
},
#[error("{message}")]
UpstreamProviderError {
message: String,
status: StatusCode,
retry_after: Option<Duration>,
},
// Fallback used by `from_http_status` for status codes without a dedicated variant.
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
},
#[error("invalid request format to {provider}'s API: {message}")]
BadRequestFormat {
provider: LanguageModelProviderName,
message: String,
},
#[error("authentication error with {provider}'s API: {message}")]
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("Permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName },
#[error("I/O error reading response from {provider}'s API")]
ApiReadResponseError {
provider: LanguageModelProviderName,
#[source]
error: io::Error,
},
#[error("error serializing request to {provider} API")]
SerializeRequest {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("error building request body to {provider} API")]
BuildRequestBody {
provider: LanguageModelProviderName,
#[source]
error: http::Error,
},
#[error("error sending HTTP request to {provider} API")]
HttpSend {
provider: LanguageModelProviderName,
#[source]
error: anyhow::Error,
},
#[error("error deserializing {provider} API response")]
DeserializeResponse {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("stream from {provider} ended unexpectedly")]
StreamEndedUnexpectedly { provider: LanguageModelProviderName },
// Catch-all: any `anyhow::Error` converts into this variant via `#[from]`.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl LanguageModelCompletionError {
    /// Extracts `(upstream_status, message)` from a JSON error payload.
    ///
    /// Returns `None` when `message` is not valid JSON or has no `upstream_status`
    /// that is a valid HTTP status code. When the payload has no string `message`
    /// field, the whole input text is used as the message.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let payload: serde_json::Value = serde_json::from_str(message).ok()?;
        let raw_status = payload.get("upstream_status")?.as_u64()?;
        let status_u16 = u16::try_from(raw_status).ok()?;
        let upstream_status = StatusCode::from_u16(status_u16).ok()?;
        let inner_message = match payload.get("message").and_then(|v| v.as_str()) {
            Some(text) => text,
            None => message,
        };
        Some((upstream_status, inner_message.to_string()))
    }

    /// Builds an error from a failure reported by the cloud service.
    ///
    /// Resolution order:
    /// 1. a message recognized by `parse_prompt_too_long` becomes `PromptTooLarge`;
    /// 2. code `upstream_http_error` with a parsable JSON payload is mapped through
    ///    [`Self::from_http_status`] for `upstream_provider`;
    /// 3. code `upstream_http_<status>` is mapped likewise for `upstream_provider`;
    /// 4. code `http_<status>` is mapped likewise for `ZED_CLOUD_PROVIDER_NAME`;
    /// 5. anything else becomes a generic `Other` error quoting the code and message.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            return Self::PromptTooLarge {
                tokens: Some(tokens),
            };
        }

        if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|suffix| StatusCode::from_str(suffix).ok())
        {
            return Self::from_http_status(upstream_provider, status_code, message, retry_after);
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|suffix| StatusCode::from_str(suffix).ok())
        {
            return Self::from_http_status(
                ZED_CLOUD_PROVIDER_NAME,
                status_code,
                message,
                retry_after,
            );
        }

        anyhow!("completion request failed, code: {code}, message: {message}").into()
    }

    /// Maps an HTTP status code from `provider` onto the matching error variant.
    ///
    /// Status 503 and the non-standard 529 both map to `ServerOverloaded`;
    /// statuses without a dedicated variant become `HttpResponseError`.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        let is_overloaded =
            status_code == StatusCode::SERVICE_UNAVAILABLE || status_code.as_u16() == 529;
        if is_overloaded {
            return Self::ServerOverloaded {
                provider,
                retry_after,
            };
        }

        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
/// Why a completion stopped. Serialized with `snake_case` variant names.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
EndTurn,
MaxTokens,
ToolUse,
Refusal,
}
/// Token counters for a request.
///
/// Each field defaults to zero when absent on deserialization and is omitted
/// from serialized output when it equals zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
#[serde(default, skip_serializing_if = "is_default")]
pub input_tokens: u64,
#[serde(default, skip_serializing_if = "is_default")]
pub output_tokens: u64,
#[serde(default, skip_serializing_if = "is_default")]
pub cache_creation_input_tokens: u64,
#[serde(default, skip_serializing_if = "is_default")]
pub cache_read_input_tokens: u64,
}
impl TokenUsage {
    /// Sum of all four counters: input, output, cache-read and cache-creation tokens.
    pub fn total_tokens(&self) -> u64 {
        [
            self.input_tokens,
            self.output_tokens,
            self.cache_read_input_tokens,
            self.cache_creation_input_tokens,
        ]
        .into_iter()
        .sum()
    }
}
/// Field-wise addition of two usage records.
impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        let Self {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
        } = other;
        Self {
            input_tokens: self.input_tokens + input_tokens,
            output_tokens: self.output_tokens + output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + cache_read_input_tokens,
        }
    }
}
/// Field-wise subtraction of two usage records.
///
/// Each counter saturates at zero: plain `u64` subtraction would panic in debug
/// builds and wrap to a huge value in release builds whenever `other` holds a
/// larger count than `self`.
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens.saturating_sub(other.input_tokens),
            output_tokens: self.output_tokens.saturating_sub(other.output_tokens),
            cache_creation_input_tokens: self
                .cache_creation_input_tokens
                .saturating_sub(other.cache_creation_input_tokens),
            cache_read_input_tokens: self
                .cache_read_input_tokens
                .saturating_sub(other.cache_read_input_tokens),
        }
    }
}
/// Identifier of a tool use, stored as a shared string.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
impl fmt::Display for LanguageModelToolUseId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Builds an id from anything convertible into `Arc<str>` (e.g. `&str`, `String`).
impl<T: Into<Arc<str>>> From<T> for LanguageModelToolUseId {
    fn from(value: T) -> Self {
        let id: Arc<str> = value.into();
        Self(id)
    }
}
/// A tool invocation requested by a language model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
pub id: LanguageModelToolUseId,
pub name: Arc<str>,
/// The tool input as raw text.
pub raw_input: String,
/// The tool input as a JSON value.
pub input: serde_json::Value,
// NOTE(review): presumably false while input is still streaming in — confirm with producers.
pub is_input_complete: bool,
/// Thought signature the model sent us. Some models require that this
/// signature be preserved and sent back in conversation history for validation.
pub thought_signature: Option<String>,
}
/// An effort level option: a `name`, a `value`, and a flag marking the default.
// NOTE(review): presumably `name` is the display label and `value` the wire value — confirm.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
pub name: SharedString,
pub value: SharedString,
pub is_default: bool,
}
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
#[error("connection refused")]
ConnectionRefused,
#[error("credentials not found")]
CredentialsNotFound,
// Catch-all: any `anyhow::Error` converts into this variant via `#[from]`.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Newtype around the [`SharedString`] identifying a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
/// Newtype around the [`SharedString`] holding a language model's name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
/// Newtype around the [`SharedString`] identifying a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
/// Newtype around the [`SharedString`] holding a language model provider's name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
impl LanguageModelProviderId {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl LanguageModelProviderName {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl fmt::Display for LanguageModelProviderId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl fmt::Display for LanguageModelProviderName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<String> for LanguageModelId {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelName {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelProviderId {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelProviderName {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<Arc<str>> for LanguageModelProviderId {
fn from(value: Arc<str>) -> Self {
Self(SharedString::from(value))
}
}
impl From<Arc<str>> for LanguageModelProviderName {
fn from(value: Arc<str>) -> Self {
Self(SharedString::from(value))
}
}
/// Settings-layer-free model mode enum.
///
/// Mirrors the shape of `settings_content::ModelMode` but lives here so that
/// crates below the settings layer can reference it.
///
/// Serialized as an internally tagged object whose `type` field holds the
/// lowercase variant name.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
#[default]
Default,
Thinking {
budget_tokens: Option<u32>,
},
}
/// Settings-layer-free reasoning-effort enum.
///
/// Mirrors the shape of `settings_content::OpenAiReasoningEffort` but lives
/// here so that crates below the settings layer can reference it.
///
/// Variant names are lowercased both for serde and for `strum`'s `FromStr`.
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, strum::EnumString,
)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum ReasoningEffort {
Minimal,
Low,
Medium,
High,
XHigh,
}
// Unit tests for `LanguageModelCompletionError::from_cloud_failure` and for
// serde round-tripping of `LanguageModelToolUse`.
#[cfg(test)]
mod tests {
use super::*;
// `upstream_http_error` with a JSON payload: the embedded `upstream_status`
// decides the variant (503 -> ServerOverloaded, 500 -> ApiInternalServerError).
#[test]
fn test_from_cloud_failure_with_upstream_http_error() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
_ => panic!(
"Expected ServerOverloaded error for 503 status, got: {:?}",
error
),
}
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider.0, "anthropic");
assert_eq!(message, "Internal server error");
}
_ => panic!(
"Expected ApiInternalServerError for 500 status, got: {:?}",
error
),
}
}
// Status encoded directly in the code string (`upstream_http_503`).
#[test]
fn test_from_cloud_failure_with_standard_format() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_503".to_string(),
"Service unavailable".to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
_ => panic!("Expected ServerOverloaded error for upstream_http_503"),
}
}
// Connection-timeout payloads: the inner JSON `message` is what ends up in the error.
#[test]
fn test_upstream_http_error_connection_timeout() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
_ => panic!(
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
error
),
}
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider.0, "anthropic");
assert_eq!(
message,
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
);
}
_ => panic!(
"Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
error
),
}
}
// The thought signature is included in serialized output.
#[test]
fn test_language_model_tool_use_serializes_with_signature() {
use serde_json::json;
let tool_use = LanguageModelToolUse {
id: LanguageModelToolUseId::from("test_id"),
name: "test_tool".into(),
raw_input: json!({"arg": "value"}).to_string(),
input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: Some("test_signature".to_string()),
};
let serialized = serde_json::to_value(&tool_use).unwrap();
assert_eq!(serialized["id"], "test_id");
assert_eq!(serialized["name"], "test_tool");
assert_eq!(serialized["thought_signature"], "test_signature");
}
// A payload without `thought_signature` still deserializes, yielding `None`.
#[test]
fn test_language_model_tool_use_deserializes_with_missing_signature() {
use serde_json::json;
let json = json!({
"id": "test_id",
"name": "test_tool",
"raw_input": "{\"arg\":\"value\"}",
"input": {"arg": "value"},
"is_input_complete": true
});
let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
assert_eq!(tool_use.name.as_ref(), "test_tool");
assert_eq!(tool_use.thought_signature, None);
}
// Serialize then deserialize with a signature present.
#[test]
fn test_language_model_tool_use_round_trip_with_signature() {
use serde_json::json;
let original = LanguageModelToolUse {
id: LanguageModelToolUseId::from("round_trip_id"),
name: "round_trip_tool".into(),
raw_input: json!({"key": "value"}).to_string(),
input: json!({"key": "value"}),
is_input_complete: true,
thought_signature: Some("round_trip_sig".to_string()),
};
let serialized = serde_json::to_value(&original).unwrap();
let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
assert_eq!(deserialized.id, original.id);
assert_eq!(deserialized.name, original.name);
assert_eq!(deserialized.thought_signature, original.thought_signature);
}
// Serialize then deserialize with no signature.
#[test]
fn test_language_model_tool_use_round_trip_without_signature() {
use serde_json::json;
let original = LanguageModelToolUse {
id: LanguageModelToolUseId::from("no_sig_id"),
name: "no_sig_tool".into(),
raw_input: json!({"arg": "value"}).to_string(),
input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: None,
};
let serialized = serde_json::to_value(&original).unwrap();
let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
assert_eq!(deserialized.id, original.id);
assert_eq!(deserialized.name, original.name);
assert_eq!(deserialized.thought_signature, None);
}
}

View file

@ -0,0 +1,21 @@
use crate::{LanguageModelProviderId, LanguageModelProviderName};
// Provider id / provider name constant pairs, one pair per provider.

// Anthropic
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Anthropic");
// OpenAI
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("OpenAI");
// Google AI
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Google AI");
// xAI
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
// Zed cloud
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Zed");

View file

@ -0,0 +1,463 @@
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
/// Dimensions of a `LanguageModelImage`
// NOTE(review): presumably measured in pixels — confirm with the code that fills these in.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ImageSize {
pub width: i32,
pub height: i32,
}
/// An image to send to a language model, stored as base64 text plus optional dimensions.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
/// A base64-encoded PNG image.
pub source: SharedString,
/// Image dimensions; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub size: Option<ImageSize>,
}
impl LanguageModelImage {
    /// Length of the base64-encoded `source` string.
    pub fn len(&self) -> usize {
        self.source.len()
    }

    /// Whether the base64-encoded `source` string is empty.
    pub fn is_empty(&self) -> bool {
        self.source.is_empty()
    }

    /// An image with an empty source and no size.
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: None,
        }
    }

    /// Parse Self from a JSON object with case-insensitive field names
    ///
    /// Requires a string `source` and a `size` object with integer `width` and
    /// `height`; returns `None` if any of them is missing, has the wrong JSON
    /// type, or if a dimension does not fit in an `i32`.
    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
        let mut source = None;
        let mut size_obj = None;
        for (k, v) in obj.iter() {
            match k.to_lowercase().as_str() {
                "source" => source = v.as_str(),
                "size" => size_obj = v.as_object(),
                _ => {}
            }
        }
        let source = source?;
        let size_obj = size_obj?;

        let mut width = None;
        let mut height = None;
        for (k, v) in size_obj.iter() {
            match k.to_lowercase().as_str() {
                // A plain `as i32` cast would silently wrap out-of-range values into
                // bogus dimensions; reject them instead.
                "width" => width = v.as_i64().and_then(|w| i32::try_from(w).ok()),
                "height" => height = v.as_i64().and_then(|h| i32::try_from(h).ok()),
                _ => {}
            }
        }

        Some(Self {
            size: Some(ImageSize {
                width: width?,
                height: height?,
            }),
            source: SharedString::from(source.to_string()),
        })
    }

    /// Estimates the token cost of the image from its dimensions.
    ///
    /// Returns 0 when the size is unknown.
    pub fn estimate_tokens(&self) -> usize {
        let Some(size) = self.size.as_ref() else {
            return 0;
        };
        let width = size.width.unsigned_abs() as usize;
        let height = size.height.unsigned_abs() as usize;
        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        (width * height) / 750
    }

    /// Formats the image as a `data:image/png;base64,...` URL.
    pub fn to_base64_url(&self) -> String {
        format!("data:image/png;base64,{}", self.source)
    }
}
/// Debug output shows the length of `source` rather than the base64 payload itself.
impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let source_summary = format!("<{} bytes>", self.source.len());
        let mut debug = f.debug_struct("LanguageModelImage");
        debug.field("source", &source_summary);
        debug.field("size", &self.size);
        debug.finish()
    }
}
/// The result of running a tool, tied to the originating tool use by `tool_use_id`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub tool_name: Arc<str>,
/// Whether this result reports an error.
pub is_error: bool,
/// The tool output formatted for presenting to the model
pub content: LanguageModelToolResultContent,
/// The raw tool output, if available, often for debugging or extra state for replay
pub output: Option<serde_json::Value>,
}
/// Content of a tool result: either text or an image.
///
/// `Serialize` is derived; `Deserialize` is implemented by hand.
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
Text(Arc<str>),
Image(LanguageModelImage),
}
/// Lenient deserialization accepting several JSON shapes, with object keys
/// matched case-insensitively:
///
/// 1. a plain string -> `Text`;
/// 2. `{ "type": "text", "text": "..." }` (extra keys allowed) -> `Text`;
/// 3. a single-key object `{ "text": "..." }` -> `Text`;
/// 4. a single-key object `{ "image": { "source": ..., "size": ... } }` -> `Image`;
/// 5. an object that is itself an image (`source` + `size`) -> `Image`.
///
/// Anything else is an error that includes the offending JSON.
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        /// Returns the value of the first key in `obj` equal to `field`, ignoring case.
        fn get_field<'a>(
            obj: &'a serde_json::Map<String, serde_json::Value>,
            field: &str,
        ) -> Option<&'a serde_json::Value> {
            // Lowercase the needle once instead of once per key.
            let wanted = field.to_lowercase();
            obj.iter()
                .find(|(key, _)| key.to_lowercase() == wanted)
                .map(|(_, value)| value)
        }

        let value = serde_json::Value::deserialize(deserializer)?;

        // 1. Try as plain string. Borrowing via `as_str` avoids deep-cloning the whole
        // value (which may hold a large base64 image) just to test its type.
        if let Some(text) = value.as_str() {
            return Ok(Self::Text(Arc::from(text)));
        }

        // 2. Try as object
        if let Some(obj) = value.as_object() {
            // Accept wrapped text format: { "type": "text", "text": "..." }
            if let (Some(type_value), Some(text_value)) =
                (get_field(obj, "type"), get_field(obj, "text"))
                && let Some(type_str) = type_value.as_str()
                && type_str.to_lowercase() == "text"
                && let Some(text) = text_value.as_str()
            {
                return Ok(Self::Text(Arc::from(text)));
            }

            // Wrapped variants are only recognized when the object has exactly one key.
            if obj.len() == 1 {
                // Check for wrapped Text variant: { "text": "..." }
                if let Some(text) = get_field(obj, "text").and_then(|v| v.as_str()) {
                    return Ok(Self::Text(Arc::from(text)));
                }

                // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
                if let Some(image) = get_field(obj, "image")
                    .and_then(|v| v.as_object())
                    .and_then(LanguageModelImage::from_json)
                {
                    return Ok(Self::Image(image));
                }
            }

            // Try as direct Image
            if let Some(image) = LanguageModelImage::from_json(obj) {
                return Ok(Self::Image(image));
            }
        }

        Err(D::Error::custom(format!(
            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
            an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
        )))
    }
}
impl LanguageModelToolResultContent {
    /// Returns the text for `Text` content, or `None` for an image.
    pub fn to_str(&self) -> Option<&str> {
        match self {
            Self::Image(_) => None,
            Self::Text(text) => Some(text.as_ref()),
        }
    }

    /// `Text` content is empty when it holds only whitespace (or nothing);
    /// an image is never considered empty.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Image(_) => false,
            Self::Text(text) => text.chars().all(char::is_whitespace),
        }
    }
}
/// Wraps a string slice as `Text` content.
impl From<&str> for LanguageModelToolResultContent {
    fn from(value: &str) -> Self {
        let text: Arc<str> = value.into();
        Self::Text(text)
    }
}
impl From<String> for LanguageModelToolResultContent {
fn from(value: String) -> Self {
Self::Text(Arc::from(value))
}
}
impl From<LanguageModelImage> for LanguageModelToolResultContent {
fn from(image: LanguageModelImage) -> Self {
Self::Image(image)
}
}
/// One part of a request message's content.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
/// Plain text.
Text(String),
/// Model reasoning text, optionally accompanied by a signature.
Thinking {
text: String,
signature: Option<String>,
},
/// Reasoning carried as an opaque string; `to_str` never exposes it as text.
RedactedThinking(String),
/// An image.
Image(LanguageModelImage),
/// A tool invocation.
ToolUse(LanguageModelToolUse),
/// The result of a tool invocation.
ToolResult(LanguageModelToolResult),
}
impl MessageContent {
pub fn to_str(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(text.as_str()),
MessageContent::Thinking { text, .. } => Some(text.as_str()),
MessageContent::RedactedThinking(_) => None,
MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
}
}
pub fn is_empty(&self) -> bool {
match self {
MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
MessageContent::RedactedThinking(_)
| MessageContent::ToolUse(_)
| MessageContent::Image(_) => false,
}
}
}
impl From<String> for MessageContent {
fn from(value: String) -> Self {
MessageContent::Text(value)
}
}
impl From<&str> for MessageContent {
fn from(value: &str) -> Self {
MessageContent::Text(value.to_string())
}
}
/// A single message of a request: a role plus an ordered list of content parts.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
/// The role this message is attributed to.
pub role: Role,
/// The content parts of the message, in order.
pub content: Vec<MessageContent>,
/// NOTE(review): presumably requests a provider-side prompt-cache marker for this message — confirm against the provider conversions.
pub cache: bool,
/// Arbitrary JSON; defaults to `None` when absent on input and is omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning_details: Option<serde_json::Value>,
}
impl LanguageModelRequestMessage {
pub fn string_contents(&self) -> String {
let mut buffer = String::new();
for string in self.content.iter().filter_map(|content| content.to_str()) {
buffer.push_str(string);
}
buffer
}
pub fn contents_empty(&self) -> bool {
self.content.iter().all(|content| content.is_empty())
}
}
/// Description of a tool offered to the model as part of a request.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
/// Tool name.
pub name: String,
/// Tool description.
pub description: String,
/// JSON value describing the tool's input (presumably a JSON Schema — confirm with callers).
pub input_schema: serde_json::Value,
/// NOTE(review): looks like an opt-in for streaming the tool's input as it is generated — confirm against provider conversions.
pub use_input_streaming: bool,
}
/// Tool-selection policy attached to a request.
///
/// NOTE(review): the exact semantics of each variant are defined by the
/// provider conversions, which map them to the provider's own options.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
Auto,
Any,
None,
}
/// Label describing why a completion is being requested.
///
/// Variants are serialized using their `snake_case` names.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
UserPrompt,
Subagent,
ToolResults,
ThreadSummarization,
ThreadContextSummarization,
CreateFile,
EditFile,
InlineAssist,
TerminalInlineAssist,
GenerateGitCommitMessage,
}
/// A provider-independent completion request.
///
/// `Default` yields a request with no messages, no tools and every optional
/// field unset.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
/// Optional identifier of the thread this request belongs to.
pub thread_id: Option<String>,
/// Optional identifier of the prompt this request belongs to.
pub prompt_id: Option<String>,
/// Optional label describing why the completion is requested.
pub intent: Option<CompletionIntent>,
/// The conversation messages, in order.
pub messages: Vec<LanguageModelRequestMessage>,
/// Tools offered to the model.
pub tools: Vec<LanguageModelRequestTool>,
/// Optional tool-selection policy.
pub tool_choice: Option<LanguageModelToolChoice>,
/// Stop strings (presumably stop sequences — verify which providers honor them).
pub stop: Vec<String>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Whether the model may use thinking for this request.
pub thinking_allowed: bool,
/// Optional thinking effort, carried as a free-form string.
pub thinking_effort: Option<String>,
/// Optional speed setting.
pub speed: Option<Speed>,
}
/// Two-state speed setting for a request; `Standard` is the default.
///
/// Variants are serialized using their `snake_case` names.
#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Speed {
#[default]
Standard,
Fast,
}
impl Speed {
pub fn toggle(self) -> Self {
match self {
Speed::Standard => Speed::Fast,
Speed::Fast => Speed::Standard,
}
}
}
/// A response message whose role and text content may each be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
/// Role of the message, if known.
pub role: Option<Role>,
/// Text content of the message, if any.
pub content: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
/// Exercises every JSON shape the custom `Deserialize` impl of
/// `LanguageModelToolResultContent` is expected to accept: plain strings,
/// typed and single-field text objects, and direct or wrapped image objects,
/// including case-insensitive key matching.
#[test]
fn test_language_model_tool_result_content_deserialization() {
// Test plain string
let json = serde_json::json!("hello world");
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
assert_eq!(
content,
LanguageModelToolResultContent::Text(Arc::from("hello world"))
);
// Test wrapped text format: { "type": "text", "text": "..." }
let json = serde_json::json!({"type": "text", "text": "hello"});
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
assert_eq!(
content,
LanguageModelToolResultContent::Text(Arc::from("hello"))
);
// Test single-field text object: { "text": "..." }
let json = serde_json::json!({"text": "hello"});
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
assert_eq!(
content,
LanguageModelToolResultContent::Text(Arc::from("hello"))
);
// Test case-insensitive type field
let json = serde_json::json!({"Type": "Text", "Text": "hello"});
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
assert_eq!(
content,
LanguageModelToolResultContent::Text(Arc::from("hello"))
);
// Test image object
let json = serde_json::json!({
"source": "base64encodedimagedata",
"size": {"width": 100, "height": 200}
});
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
match content {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "base64encodedimagedata");
let size = image.size.expect("size");
assert_eq!(size.width, 100);
assert_eq!(size.height, 200);
}
_ => panic!("Expected Image variant"),
}
// Test wrapped image: { "image": { "source": "...", "size": ... } }
let json = serde_json::json!({
"image": {
"source": "wrappedimagedata",
"size": {"width": 50, "height": 75}
}
});
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
match content {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "wrappedimagedata");
let size = image.size.expect("size");
assert_eq!(size.width, 50);
assert_eq!(size.height, 75);
}
_ => panic!("Expected Image variant"),
}
// Test case insensitive
let json = serde_json::json!({
"Source": "caseinsensitive",
"Size": {"Width": 30, "Height": 40}
});
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
match content {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "caseinsensitive");
let size = image.size.expect("size");
assert_eq!(size.width, 30);
assert_eq!(size.height, 40);
}
_ => panic!("Expected Image variant"),
}
// Test direct image object
// (same shape as the "image object" case above, with different values)
let json = serde_json::json!({
"source": "directimage",
"size": {"width": 200, "height": 300}
});
let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
match content {
LanguageModelToolResultContent::Image(image) => {
assert_eq!(image.source.as_ref(), "directimage");
let size = image.size.expect("size");
assert_eq!(size.width, 200);
assert_eq!(size.height, 300);
}
_ => panic!("Expected Image variant"),
}
}
}

View file

@ -77,8 +77,6 @@ pub fn adapt_schema_to_format(
}
fn preprocess_json_schema(json: &mut Value) -> Result<()> {
// `additionalProperties` defaults to `false` unless explicitly specified.
// This prevents models from hallucinating tool parameters.
if let Value::Object(obj) = json
&& matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
{
@ -86,7 +84,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> {
obj.insert("additionalProperties".to_string(), Value::Bool(false));
}
// OpenAI API requires non-missing `properties`
if !obj.contains_key("properties") {
obj.insert("properties".to_string(), Value::Object(Default::default()));
}
@ -94,7 +91,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> {
Ok(())
}
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
if let Value::Object(obj) = json {
const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
@ -108,9 +104,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
("format", |value| value.is_string()),
// Gemini doesn't support `additionalProperties` in any form (boolean or schema object)
("additionalProperties", |_| true),
// Gemini doesn't support `propertyNames`
("propertyNames", |_| true),
("exclusiveMinimum", |value| value.is_number()),
("exclusiveMaximum", |value| value.is_number()),
@ -124,7 +118,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
}
}
// If a type is not specified for an input parameter, add a default type
if matches!(obj.get("description"), Some(Value::String(_)))
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
@ -134,7 +127,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
obj.insert("type".to_string(), Value::String("string".to_string()));
}
// Handle oneOf -> anyOf conversion
if let Some(subschemas) = obj.get_mut("oneOf")
&& subschemas.is_array()
{
@ -143,7 +135,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
obj.insert("anyOf".to_string(), subschemas_clone);
}
// Recursively process all nested objects and arrays
for (_, value) in obj.iter_mut() {
if let Value::Object(_) | Value::Array(_) = value {
adapt_to_json_schema_subset(value)?;
@ -178,7 +169,6 @@ mod tests {
})
);
// Ensure that we do not add a type if it is an object
let mut json = json!({
"description": {
"value": "abc",
@ -221,7 +211,6 @@ mod tests {
})
);
// Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
let mut json = json!({
"description": "A test field",
"type": "integer",
@ -239,7 +228,6 @@ mod tests {
})
);
// additionalProperties as an object schema is also unsupported by Gemini
let mut json = json!({
"type": "object",
"properties": {

View file

@ -38,13 +38,22 @@ fn strip_trailing_incomplete_escape(json: &str) -> &str {
}
}
/// Extracts the token count from a `"prompt is too long: N tokens ..."` message.
///
/// Returns `None` when the message does not start with the expected prefix,
/// when no `" tokens"` marker follows it, or when the text between the two is
/// not a valid `u64`.
pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
    let rest = message.strip_prefix("prompt is too long: ")?;
    // Only the text before the first " tokens" marker is the count.
    let (count, _) = rest.split_once(" tokens")?;
    count.parse::<u64>().ok()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fix_streamed_json_strips_incomplete_escape() {
// Trailing `\` inside a string — incomplete escape sequence
let fixed = fix_streamed_json(r#"{"text": "hello\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "hello");
@ -52,7 +61,6 @@ mod tests {
#[test]
fn test_fix_streamed_json_preserves_complete_escape() {
// `\\` is a complete escape (literal backslash)
let fixed = fix_streamed_json(r#"{"text": "hello\\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "hello\\");
@ -60,7 +68,6 @@ mod tests {
#[test]
fn test_fix_streamed_json_strips_escape_after_complete_escape() {
// `\\\` = complete `\\` (literal backslash) + incomplete `\`
let fixed = fix_streamed_json(r#"{"text": "hello\\\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "hello\\");
@ -75,12 +82,10 @@ mod tests {
#[test]
fn test_fix_streamed_json_newline_escape_boundary() {
// Simulates a stream boundary landing between `\` and `n`
let fixed = fix_streamed_json(r#"{"text": "line1\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "line1");
// Next chunk completes the escape
let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "line1\nline2");
@ -88,8 +93,6 @@ mod tests {
#[test]
fn test_fix_streamed_json_incremental_delta_correctness() {
// This is the actual scenario that causes the bug:
// chunk 1 ends mid-escape, chunk 2 completes it.
let chunk1 = r#"{"replacement_text": "fn foo() {\"#;
let fixed1 = fix_streamed_json(chunk1);
let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json");
@ -102,7 +105,6 @@ mod tests {
let text2 = parsed2["replacement_text"].as_str().expect("string");
assert_eq!(text2, "fn foo() {\n return bar;\n}");
// The delta should be the newline + rest, with no spurious backslash
let delta = &text2[text1.len()..];
assert_eq!(delta, "\n return bar;\n}");
}

View file

@ -21,8 +21,8 @@ aws_http_client.workspace = true
base64.workspace = true
bedrock = { workspace = true, features = ["schemars"] }
client.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
convert_case.workspace = true
@ -41,6 +41,7 @@ gpui_tokio.workspace = true
http_client.workspace = true
language.workspace = true
language_model.workspace = true
language_models_cloud.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
menu.workspace = true
@ -49,16 +50,13 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
opencode = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
release_channel.workspace = true
schemars.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strum.workspace = true
thiserror.workspace = true
tiktoken-rs.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
@ -70,4 +68,3 @@ x_ai = { workspace = true, features = ["schemars"] }
[dev-dependencies]
language_model = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true

View file

@ -11,7 +11,7 @@ pub mod open_ai;
pub mod open_ai_compatible;
pub mod open_router;
pub mod opencode;
mod util;
pub mod vercel;
pub mod vercel_ai_gateway;
pub mod x_ai;

View file

@ -1,13 +1,10 @@
pub mod telemetry;
use anthropic::{
ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event,
ResponseContent, ToolResultContent, ToolResultPart, Usage,
};
use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode};
use anyhow::Result;
use collections::{BTreeMap, HashMap};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
@ -16,20 +13,19 @@ use language_model::{
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, env_var,
LanguageModelToolChoice, RateLimiter, env_var,
};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
pub use anthropic::completion::{
AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
into_anthropic_count_tokens_request,
};
pub use settings::AnthropicAvailableModel as AvailableModel;
const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID;
@ -249,228 +245,6 @@ pub struct AnthropicModel {
request_limiter: RateLimiter,
}
/// Converts one [`MessageContent`] part into its Anthropic request counterpart.
///
/// Returns `None` for parts that leave nothing to send: text that is empty
/// once trailing whitespace is trimmed, thinking that is empty or has no
/// signature, and empty redacted-thinking payloads. No `cache_control` is set
/// on the produced content.
fn to_anthropic_content(content: MessageContent) -> Option<anthropic::RequestContent> {
    // Plain images and image tool results are both sent as base64 PNG sources.
    fn base64_png(data: String) -> anthropic::ImageSource {
        anthropic::ImageSource {
            source_type: "base64".to_string(),
            media_type: "image/png".to_string(),
            data,
        }
    }

    match content {
        MessageContent::Text(text) => {
            // Drop trailing whitespace, reallocating only when there is any.
            let text = if text.ends_with(char::is_whitespace) {
                text.trim_end().to_owned()
            } else {
                text
            };
            if text.is_empty() {
                return None;
            }
            Some(anthropic::RequestContent::Text {
                text,
                cache_control: None,
            })
        }
        MessageContent::Thinking {
            text: thinking,
            signature,
        } => match signature {
            Some(signature) if !thinking.is_empty() => {
                Some(anthropic::RequestContent::Thinking {
                    thinking,
                    signature,
                    cache_control: None,
                })
            }
            _ => None,
        },
        MessageContent::RedactedThinking(data) => {
            (!data.is_empty()).then(|| anthropic::RequestContent::RedactedThinking { data })
        }
        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
            source: base64_png(image.source.to_string()),
            cache_control: None,
        }),
        MessageContent::ToolUse(tool_use) => Some(anthropic::RequestContent::ToolUse {
            id: tool_use.id.to_string(),
            name: tool_use.name.to_string(),
            input: tool_use.input,
            cache_control: None,
        }),
        MessageContent::ToolResult(tool_result) => {
            let result_content = match tool_result.content {
                LanguageModelToolResultContent::Text(text) => {
                    ToolResultContent::Plain(text.to_string())
                }
                LanguageModelToolResultContent::Image(image) => {
                    ToolResultContent::Multipart(vec![ToolResultPart::Image {
                        source: base64_png(image.source.to_string()),
                    }])
                }
            };
            Some(anthropic::RequestContent::ToolResult {
                tool_use_id: tool_result.tool_use_id.to_string(),
                is_error: tool_result.is_error,
                content: result_content,
                cache_control: None,
            })
        }
    }
}
/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
///
/// Messages whose contents are all empty are dropped, as are user/assistant
/// messages that convert to no Anthropic content. Consecutive messages with
/// the same role are merged into one, and all system messages are joined with
/// blank lines into the single `system` string. Thinking is configured from
/// `mode` only when `request.thinking_allowed` is set.
pub fn into_anthropic_count_tokens_request(
request: LanguageModelRequest,
model: String,
mode: AnthropicModelMode,
) -> CountTokensRequest {
let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new();
for message in request.messages {
if message.contents_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
let anthropic_message_content: Vec<anthropic::RequestContent> = message
.content
.into_iter()
.filter_map(to_anthropic_content)
.collect();
let anthropic_role = match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("System role should never occur here"),
};
if anthropic_message_content.is_empty() {
continue;
}
// Fold this message into the previous one when both share a role.
if let Some(last_message) = new_messages.last_mut()
&& last_message.role == anthropic_role
{
last_message.content.extend(anthropic_message_content);
continue;
}
new_messages.push(anthropic::Message {
role: anthropic_role,
content: anthropic_message_content,
});
}
Role::System => {
// System messages are accumulated separately, separated by a blank line.
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.string_contents());
}
}
}
CountTokensRequest {
model,
messages: new_messages,
system: if system_message.is_empty() {
None
} else {
Some(anthropic::StringOrContents::String(system_message))
},
thinking: if request.thinking_allowed {
match mode {
AnthropicModelMode::Thinking { budget_tokens } => {
Some(anthropic::Thinking::Enabled { budget_tokens })
}
AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
AnthropicModelMode::Default => None,
}
} else {
None
},
tools: request
.tools
.into_iter()
.map(|tool| anthropic::Tool {
name: tool.name,
description: tool.description,
input_schema: tool.input_schema,
eager_input_streaming: tool.use_input_streaming,
})
.collect(),
tool_choice: request.tool_choice.map(|choice| match choice {
LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
}),
}
}
/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
///
/// Text and textual tool results are concatenated per message and counted with
/// the GPT-4 tokenizer; image estimates are added on top. Thinking blocks and
/// tool uses contribute nothing, and messages that end up with no text are not
/// passed to the tokenizer.
///
/// # Errors
///
/// Returns whatever error `tiktoken_rs::num_tokens_from_messages` reports.
pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
let messages = request.messages;
let mut tokens_from_images = 0;
let mut string_messages = Vec::with_capacity(messages.len());
for message in messages {
let mut string_contents = String::new();
for content in message.content {
match content {
MessageContent::Text(text) => {
string_contents.push_str(&text);
}
MessageContent::Thinking { .. } => {
// Thinking blocks are not included in the input token count.
}
MessageContent::RedactedThinking(_) => {
// Thinking blocks are not included in the input token count.
}
MessageContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
MessageContent::ToolUse(_tool_use) => {
// TODO: Estimate token usage from tool uses.
}
MessageContent::ToolResult(tool_result) => match &tool_result.content {
LanguageModelToolResultContent::Text(text) => {
string_contents.push_str(text);
}
LanguageModelToolResultContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
},
}
}
// Messages without any text are left out of the tokenizer input entirely.
if !string_contents.is_empty() {
string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(string_contents),
name: None,
function_call: None,
});
}
}
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
.map(|tokens| (tokens + tokens_from_images) as u64)
}
impl AnthropicModel {
fn stream_completion(
&self,
@ -617,10 +391,13 @@ impl LanguageModel for AnthropicModel {
)
});
let background = cx.background_executor().clone();
async move {
// If no API key, fall back to tiktoken estimation
let Some(api_key) = api_key else {
return count_anthropic_tokens_with_tiktoken(request);
return background
.spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
.await;
};
let count_request =
@ -634,7 +411,9 @@ impl LanguageModel for AnthropicModel {
log::error!(
"Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
);
count_anthropic_tokens_with_tiktoken(request)
background
.spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
.await
}
}
}
@ -678,345 +457,6 @@ impl LanguageModel for AnthropicModel {
}
}
/// Converts a [`LanguageModelRequest`] into an Anthropic completion request.
///
/// Messages whose contents are all empty are dropped, as are user/assistant
/// messages that convert to no Anthropic content. Consecutive messages with
/// the same role are merged, and system messages are joined with blank lines
/// into the single `system` string.
///
/// `default_temperature` is used when the request does not set a temperature,
/// and `max_output_tokens` becomes `max_tokens`. Thinking is configured from
/// `mode` only when `request.thinking_allowed` is set; an effort level is sent
/// only in adaptive-thinking mode and only for the values `"low"`, `"medium"`,
/// `"high"` and `"max"`.
///
/// Note that `request.stop` is not forwarded: `stop_sequences` is always empty.
pub fn into_anthropic(
request: LanguageModelRequest,
model: String,
default_temperature: f32,
max_output_tokens: u64,
mode: AnthropicModelMode,
) -> anthropic::Request {
let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new();
for message in request.messages {
if message.contents_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
.content
.into_iter()
.filter_map(to_anthropic_content)
.collect();
let anthropic_role = match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("System role should never occur here"),
};
if anthropic_message_content.is_empty() {
continue;
}
// Fold this message into the previous one when both share a role.
// A merged message skips the cache marking below.
if let Some(last_message) = new_messages.last_mut()
&& last_message.role == anthropic_role
{
last_message.content.extend(anthropic_message_content);
continue;
}
// Mark the last segment of the message as cached
if message.cache {
let cache_control_value = Some(anthropic::CacheControl {
cache_type: anthropic::CacheControlType::Ephemeral,
});
// Walk backwards so the marker lands on the last content that can carry one.
for message_content in anthropic_message_content.iter_mut().rev() {
match message_content {
anthropic::RequestContent::RedactedThinking { .. } => {
// Caching is not possible, fallback to next message
}
anthropic::RequestContent::Text { cache_control, .. }
| anthropic::RequestContent::Thinking { cache_control, .. }
| anthropic::RequestContent::Image { cache_control, .. }
| anthropic::RequestContent::ToolUse { cache_control, .. }
| anthropic::RequestContent::ToolResult { cache_control, .. } => {
*cache_control = cache_control_value;
break;
}
}
}
}
new_messages.push(anthropic::Message {
role: anthropic_role,
content: anthropic_message_content,
});
}
Role::System => {
// System messages are accumulated separately, separated by a blank line.
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.string_contents());
}
}
}
anthropic::Request {
model,
messages: new_messages,
max_tokens: max_output_tokens,
system: if system_message.is_empty() {
None
} else {
Some(anthropic::StringOrContents::String(system_message))
},
thinking: if request.thinking_allowed {
match mode {
AnthropicModelMode::Thinking { budget_tokens } => {
Some(anthropic::Thinking::Enabled { budget_tokens })
}
AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
AnthropicModelMode::Default => None,
}
} else {
None
},
tools: request
.tools
.into_iter()
.map(|tool| anthropic::Tool {
name: tool.name,
description: tool.description,
input_schema: tool.input_schema,
eager_input_streaming: tool.use_input_streaming,
})
.collect(),
tool_choice: request.tool_choice.map(|choice| match choice {
LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
}),
metadata: None,
// Effort is only sent for adaptive thinking; unrecognized effort strings are ignored.
output_config: if request.thinking_allowed
&& matches!(mode, AnthropicModelMode::AdaptiveThinking)
{
request.thinking_effort.as_deref().and_then(|effort| {
let effort = match effort {
"low" => Some(anthropic::Effort::Low),
"medium" => Some(anthropic::Effort::Medium),
"high" => Some(anthropic::Effort::High),
"max" => Some(anthropic::Effort::Max),
_ => None,
};
effort.map(|effort| anthropic::OutputConfig {
effort: Some(effort),
})
})
} else {
None
},
stop_sequences: Vec::new(),
speed: request.speed.map(From::from),
temperature: request.temperature.or(Some(default_temperature)),
top_k: None,
top_p: None,
}
}
/// Stateful translator from Anthropic stream events to
/// `LanguageModelCompletionEvent`s.
pub struct AnthropicEventMapper {
/// Tool uses whose input JSON is still being received, keyed by content-block index.
tool_uses_by_index: HashMap<usize, RawToolUse>,
/// Most recent usage counts seen so far in the stream.
usage: Usage,
/// Stop reason to report when the message stops; starts as `EndTurn`.
stop_reason: StopReason,
}
impl AnthropicEventMapper {
/// Creates a mapper with no pending tool uses, default usage, and an
/// `EndTurn` stop reason.
pub fn new() -> Self {
Self {
tool_uses_by_index: HashMap::default(),
usage: Usage::default(),
stop_reason: StopReason::EndTurn,
}
}
/// Consumes the mapper and adapts a stream of Anthropic events into a
/// stream of completion events. Each input event may expand to zero or
/// more output events; stream errors are converted and passed through.
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(error.into())],
})
})
}
/// Translates a single Anthropic event into zero or more completion
/// events, updating the mapper's pending tool uses, usage counts and stop
/// reason along the way. Event kinds not handled here produce nothing.
pub fn map_event(
&mut self,
event: Event,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
match event {
Event::ContentBlockStart {
index,
content_block,
} => match content_block {
ResponseContent::Text { text } => {
vec![Ok(LanguageModelCompletionEvent::Text(text))]
}
ResponseContent::Thinking { thinking } => {
vec![Ok(LanguageModelCompletionEvent::Thinking {
text: thinking,
signature: None,
})]
}
ResponseContent::RedactedThinking { data } => {
vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
}
ResponseContent::ToolUse { id, name, .. } => {
// Start tracking the tool use; nothing is emitted until input arrives.
self.tool_uses_by_index.insert(
index,
RawToolUse {
id,
name,
input_json: String::new(),
},
);
Vec::new()
}
},
Event::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => {
vec![Ok(LanguageModelCompletionEvent::Text(text))]
}
ContentDelta::ThinkingDelta { thinking } => {
vec![Ok(LanguageModelCompletionEvent::Thinking {
text: thinking,
signature: None,
})]
}
ContentDelta::SignatureDelta { signature } => {
// The signature is delivered as a thinking event with empty text.
vec![Ok(LanguageModelCompletionEvent::Thinking {
text: "".to_string(),
signature: Some(signature),
})]
}
ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
// Try to convert invalid (incomplete) JSON into
// valid JSON that serde can accept, e.g. by closing
// unclosed delimiters. This way, we can update the
// UI with whatever has been streamed back so far.
if let Ok(input) =
serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
{
return vec![Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.clone().into(),
name: tool_use.name.clone().into(),
is_input_complete: false,
raw_input: tool_use.input_json.clone(),
input,
thought_signature: None,
},
))];
}
}
// Unknown block index, or the partial input could not be repaired yet.
vec![]
}
},
Event::ContentBlockStop { index } => {
// A finished tool-use block is parsed once more and reported as complete,
// or as a JSON parse error event if its input is still invalid.
if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
let input_json = tool_use.input_json.trim();
let event_result = match parse_tool_arguments(input_json) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input,
raw_input: tool_use.input_json.clone(),
thought_signature: None,
},
)),
Err(json_parse_err) => {
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
})
}
};
vec![event_result]
} else {
Vec::new()
}
}
Event::MessageStart { message } => {
update_usage(&mut self.usage, &message.usage);
vec![
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
&self.usage,
))),
Ok(LanguageModelCompletionEvent::StartMessage {
message_id: message.id,
}),
]
}
Event::MessageDelta { delta, usage } => {
update_usage(&mut self.usage, &usage);
// Remember the stop reason; it is emitted later, on `MessageStop`.
if let Some(stop_reason) = delta.stop_reason.as_deref() {
self.stop_reason = match stop_reason {
"end_turn" => StopReason::EndTurn,
"max_tokens" => StopReason::MaxTokens,
"tool_use" => StopReason::ToolUse,
"refusal" => StopReason::Refusal,
_ => {
log::error!("Unexpected anthropic stop_reason: {stop_reason}");
StopReason::EndTurn
}
};
}
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
))]
}
Event::MessageStop => {
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
}
Event::Error { error } => {
vec![Err(error.into())]
}
_ => Vec::new(),
}
}
}
/// A tool use whose input is still being accumulated as raw JSON text.
struct RawToolUse {
/// Tool-use identifier.
id: String,
/// Tool name.
name: String,
/// Raw input JSON text received so far.
input_json: String,
}
/// Merges `new` into `usage`, field by field: a count present in `new`
/// replaces the stored one, while a count missing from `new` leaves the
/// stored value untouched.
fn update_usage(usage: &mut Usage, new: &Usage) {
    usage.input_tokens = new.input_tokens.or(usage.input_tokens);
    usage.output_tokens = new.output_tokens.or(usage.output_tokens);
    usage.cache_creation_input_tokens = new
        .cache_creation_input_tokens
        .or(usage.cache_creation_input_tokens);
    usage.cache_read_input_tokens = new
        .cache_read_input_tokens
        .or(usage.cache_read_input_tokens);
}
/// Converts Anthropic usage counts into a `language_model::TokenUsage`,
/// treating every missing count as zero.
fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
    let input_tokens = usage.input_tokens.unwrap_or_default();
    let output_tokens = usage.output_tokens.unwrap_or_default();
    let cache_creation_input_tokens = usage.cache_creation_input_tokens.unwrap_or_default();
    let cache_read_input_tokens = usage.cache_read_input_tokens.unwrap_or_default();
    language_model::TokenUsage {
        input_tokens,
        output_tokens,
        cache_creation_input_tokens,
        cache_read_input_tokens,
    }
}
struct ConfigurationView {
api_key_editor: Entity<InputField>,
state: Entity<State>,
@ -1157,192 +597,3 @@ impl Render for ConfigurationView {
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use anthropic::AnthropicModelMode;
    use language_model::{LanguageModelRequestMessage, MessageContent};

    /// A cached message with one text segment followed by four images must
    /// carry `cache_control` on the final segment only.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                ],
                cache: true,
                reasoning_details: None,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);

        assert!(matches!(
            message.content[0],
            anthropic::RequestContent::Text {
                cache_control: None,
                ..
            }
        ));

        // Indices 1, 2 and 3 are the non-final images. The range is half-open,
        // so it must end at 4 to include index 3 (the previous `1..3` left the
        // third image unchecked).
        for i in 1..4 {
            assert!(
                matches!(
                    message.content[i],
                    anthropic::RequestContent::Image {
                        cache_control: None,
                        ..
                    }
                ),
                "image at index {i} must not carry cache_control"
            );
        }

        assert!(matches!(
            message.content[4],
            anthropic::RequestContent::Image {
                cache_control: Some(anthropic::CacheControl {
                    cache_type: anthropic::CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }

    /// Builds a two-message conversation (user "Hello" followed by an
    /// assistant message holding `assistant_content`) and converts it with
    /// thinking mode enabled.
    fn request_with_assistant_content(
        assistant_content: Vec<MessageContent>,
    ) -> anthropic::Request {
        let mut request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("Hello".to_string())],
                cache: false,
                reasoning_details: None,
            }],
            thinking_effort: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            speed: None,
        };
        request.messages.push(LanguageModelRequestMessage {
            role: Role::Assistant,
            content: assistant_content,
            cache: false,
            reasoning_details: None,
        });

        into_anthropic(
            request,
            "claude-sonnet-4-5".to_string(),
            1.0,
            16000,
            AnthropicModelMode::Thinking {
                budget_tokens: Some(10000),
            },
        )
    }

    /// A thinking block without a signature is dropped while sibling content
    /// in the same assistant message is kept.
    #[test]
    fn test_unsigned_thinking_blocks_stripped() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Cancelled mid-think, no signature".to_string(),
                signature: None,
            },
            MessageContent::Text("Some response text".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == anthropic::Role::Assistant)
            .expect("assistant message should still exist");

        assert_eq!(
            assistant_message.content.len(),
            1,
            "Only the text content should remain; unsigned thinking block should be stripped"
        );
        assert!(matches!(
            &assistant_message.content[0],
            anthropic::RequestContent::Text { text, .. } if text == "Some response text"
        ));
    }

    /// A thinking block that carries a signature is forwarded unchanged.
    #[test]
    fn test_signed_thinking_blocks_preserved() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Completed thinking".to_string(),
                signature: Some("valid-signature".to_string()),
            },
            MessageContent::Text("Response".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == anthropic::Role::Assistant)
            .expect("assistant message should exist");

        assert_eq!(
            assistant_message.content.len(),
            2,
            "Both the signed thinking block and text should be preserved"
        );
        assert!(matches!(
            &assistant_message.content[0],
            anthropic::RequestContent::Thinking { thinking, signature, .. }
                if thinking == "Completed thinking" && signature == "valid-signature"
        ));
    }

    /// An assistant message consisting solely of an unsigned thinking block
    /// produces no assistant message at all.
    #[test]
    fn test_only_unsigned_thinking_block_omits_entire_message() {
        let result = request_with_assistant_content(vec![MessageContent::Thinking {
            text: "Cancelled before any text or signature".to_string(),
            signature: None,
        }]);

        let assistant_messages: Vec<_> = result
            .messages
            .iter()
            .filter(|m| m.role == anthropic::Role::Assistant)
            .collect();

        assert_eq!(
            assistant_messages.len(),
            0,
            "An assistant message whose only content was an unsigned thinking block \
             should be omitted entirely"
        );
    }
}

View file

@ -48,7 +48,7 @@ use ui_input::InputField;
use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
use language_model::util::{fix_streamed_json, parse_tool_arguments};
actions!(bedrock, [Tab, TabPrev]);

File diff suppressed because it is too large Load diff

View file

@ -32,7 +32,7 @@ use ui::prelude::*;
use util::debug_panic;
use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic};
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
@ -268,15 +268,15 @@ impl LanguageModel for CopilotChatLanguageModel {
levels
.iter()
.map(|level| {
let name: SharedString = match level.as_str() {
let name = match level.as_str() {
"low" => "Low".into(),
"medium" => "Medium".into(),
"high" => "High".into(),
_ => SharedString::from(level.clone()),
_ => language_model::SharedString::from(level.clone()),
};
LanguageModelEffortLevel {
name,
value: SharedString::from(level.clone()),
value: language_model::SharedString::from(level.clone()),
is_default: level == "high",
}
})

View file

@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");

View file

@ -1,32 +1,25 @@
use anyhow::{Context as _, Result};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
ThinkingConfig, UsageMetadata,
};
use futures::{FutureExt, StreamExt, future::BoxFuture};
pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google};
use google_ai::{GenerateContentResponse, GoogleModelMode};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub use settings::GoogleAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
Arc, LazyLock,
atomic::{self, AtomicU64},
};
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
@ -394,369 +387,6 @@ impl LanguageModel for GoogleLanguageModel {
}
}
pub fn into_google(
mut request: LanguageModelRequest,
model_id: String,
mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content
.into_iter()
.flat_map(|content| match content {
language_model::MessageContent::Text(text) => {
if !text.is_empty() {
vec![Part::TextPart(google_ai::TextPart { text })]
} else {
vec![]
}
}
language_model::MessageContent::Thinking {
text: _,
signature: Some(signature),
} => {
if !signature.is_empty() {
vec![Part::ThoughtPart(google_ai::ThoughtPart {
thought: true,
thought_signature: signature,
})]
} else {
vec![]
}
}
language_model::MessageContent::Thinking { .. } => {
vec![]
}
language_model::MessageContent::RedactedThinking(_) => vec![],
language_model::MessageContent::Image(image) => {
vec![Part::InlineDataPart(google_ai::InlineDataPart {
inline_data: google_ai::GenerativeContentBlob {
mime_type: "image/png".to_string(),
data: image.source.to_string(),
},
})]
}
language_model::MessageContent::ToolUse(tool_use) => {
// Normalize empty string signatures to None
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
function_call: google_ai::FunctionCall {
name: tool_use.name.to_string(),
args: tool_use.input,
},
thought_signature,
})]
}
language_model::MessageContent::ToolResult(tool_result) => {
match tool_result.content {
language_model::LanguageModelToolResultContent::Text(text) => {
vec![Part::FunctionResponsePart(
google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": text
}),
},
},
)]
}
language_model::LanguageModelToolResultContent::Image(image) => {
vec![
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": "Tool responded with an image"
}),
},
}),
Part::InlineDataPart(google_ai::InlineDataPart {
inline_data: google_ai::GenerativeContentBlob {
mime_type: "image/png".to_string(),
data: image.source.to_string(),
},
}),
]
}
}
}
})
.collect()
}
let system_instructions = if request
.messages
.first()
.is_some_and(|msg| matches!(msg.role, Role::System))
{
let message = request.messages.remove(0);
Some(SystemInstruction {
parts: map_content(message.content),
})
} else {
None
};
google_ai::GenerateContentRequest {
model: google_ai::ModelName { model_id },
system_instruction: system_instructions,
contents: request
.messages
.into_iter()
.filter_map(|message| {
let parts = map_content(message.content);
if parts.is_empty() {
None
} else {
Some(google_ai::Content {
parts,
role: match message.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
},
})
}
})
.collect(),
generation_config: Some(google_ai::GenerationConfig {
candidate_count: Some(1),
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
thinking_config: match (request.thinking_allowed, mode) {
(true, GoogleModelMode::Thinking { budget_tokens }) => {
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
}
_ => None,
},
top_p: None,
top_k: None,
}),
safety_settings: None,
tools: (!request.tools.is_empty()).then(|| {
vec![google_ai::Tool {
function_declarations: request
.tools
.into_iter()
.map(|tool| FunctionDeclaration {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
})
.collect(),
}]
}),
tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
function_calling_config: google_ai::FunctionCallingConfig {
mode: match choice {
LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
},
allowed_function_names: None,
},
}),
}
}
/// Stateful translator from Google `GenerateContentResponse` chunks to
/// `LanguageModelCompletionEvent`s.
pub struct GoogleEventMapper {
/// Usage counters accumulated across the responses seen so far; each
/// response's `usage_metadata` is merged in via `update_usage`.
usage: UsageMetadata,
/// Most recently determined stop reason; `map_stream` emits it in the final
/// `Stop` event once the underlying stream ends.
stop_reason: StopReason,
}
impl GoogleEventMapper {
pub fn new() -> Self {
Self {
usage: UsageMetadata::default(),
stop_reason: StopReason::EndTurn,
}
}
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events
.map(Some)
.chain(futures::stream::once(async { None }))
.flat_map(move |event| {
futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => {
vec![Err(LanguageModelCompletionError::from(error))]
}
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
})
})
}
pub fn map_event(
&mut self,
event: GenerateContentResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut self.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
)))
}
if let Some(prompt_feedback) = event.prompt_feedback
&& let Some(block_reason) = prompt_feedback.block_reason.as_deref()
{
self.stop_reason = match block_reason {
"SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
StopReason::Refusal
}
_ => {
log::error!("Unexpected Google block_reason: {block_reason}");
StopReason::Refusal
}
};
events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
return events;
}
if let Some(candidates) = event.candidates {
for candidate in candidates {
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
self.stop_reason = match finish_reason {
"STOP" => StopReason::EndTurn,
"MAX_TOKENS" => StopReason::MaxTokens,
_ => {
log::error!("Unexpected google finish_reason: {finish_reason}");
StopReason::EndTurn
}
};
}
candidate
.content
.parts
.into_iter()
.for_each(|part| match part {
Part::TextPart(text_part) => {
events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
}
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
let name: Arc<str> = function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
// Normalize empty string signatures to None
let thought_signature = function_call_part
.thought_signature
.filter(|s| !s.is_empty());
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id,
name,
is_input_complete: true,
raw_input: function_call_part.function_call.args.to_string(),
input: function_call_part.function_call.args,
thought_signature,
},
)));
}
Part::FunctionResponsePart(_) => {}
Part::ThoughtPart(part) => {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
signature: Some(part.thought_signature),
}));
}
});
}
}
// Even when Gemini wants to use a Tool, the API
// responds with `finish_reason: STOP`
if wants_to_use_tool {
self.stop_reason = StopReason::ToolUse;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
}
events
}
}
/// Estimates the token count of `request` on a background task.
///
/// The count is computed locally with `tiktoken_rs` rather than through the
/// Google API.
pub fn count_google_tokens(
    request: LanguageModelRequest,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    // We can't use the GoogleLanguageModelProvider to count tokens here because
    // GitHub Copilot doesn't have direct access to google_ai, so the tokenizer
    // from tiktoken_rs is used instead.
    let count_tokens = async move {
        let mut messages = Vec::with_capacity(request.messages.len());
        for message in request.messages {
            let role = match message.role {
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::System => "system",
            };
            messages.push(tiktoken_rs::ChatCompletionRequestMessage {
                role: role.into(),
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            });
        }

        // Tiktoken doesn't yet support these models, so we manually use the
        // same tokenizer as GPT-4.
        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
    };

    cx.background_spawn(count_tokens).boxed()
}
/// Folds the usage metadata of one response into the running totals.
///
/// Every counter that is present in `new` overwrites the stored counter;
/// counters that are `None` in `new` leave the existing value in `usage`
/// untouched.
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
    usage.prompt_token_count = new.prompt_token_count.or(usage.prompt_token_count);
    usage.cached_content_token_count = new
        .cached_content_token_count
        .or(usage.cached_content_token_count);
    usage.candidates_token_count = new
        .candidates_token_count
        .or(usage.candidates_token_count);
    usage.tool_use_prompt_token_count = new
        .tool_use_prompt_token_count
        .or(usage.tool_use_prompt_token_count);
    usage.thoughts_token_count = new.thoughts_token_count.or(usage.thoughts_token_count);
    usage.total_token_count = new.total_token_count.or(usage.total_token_count);
}
/// Converts Google usage metadata into a `language_model::TokenUsage`.
///
/// Absent counters are treated as zero. Cached tokens are reported as
/// `cache_read_input_tokens` and subtracted from the prompt count to obtain
/// `input_tokens`; `cache_creation_input_tokens` is always zero.
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
    // The two counters arrive as independent optional fields, so the cached
    // count is not guaranteed to be <= the prompt count (e.g. when only the
    // cached count has been reported so far). Saturate instead of letting the
    // unsigned subtraction overflow.
    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
    let output_tokens = usage.candidates_token_count.unwrap_or(0);

    language_model::TokenUsage {
        input_tokens,
        output_tokens,
        cache_read_input_tokens: cached_tokens,
        cache_creation_input_tokens: 0,
    }
}
struct ConfigurationView {
api_key_editor: Entity<InputField>,
state: Entity<State>,
@ -895,428 +525,3 @@ impl Render for ConfigurationView {
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use google_ai::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole, TextPart,
    };
    use language_model::{LanguageModelToolUseId, MessageContent, Role};
    use serde_json::json;

    /// Builds a function-call part with the given name, arguments and
    /// optional thought signature.
    fn function_call(
        name: &str,
        args: serde_json::Value,
        thought_signature: Option<&str>,
    ) -> Part {
        Part::FunctionCallPart(FunctionCallPart {
            function_call: FunctionCall {
                name: name.to_string(),
                args,
            },
            thought_signature: thought_signature.map(String::from),
        })
    }

    /// Builds a response holding a single model candidate made of `parts`,
    /// with no finish reason, prompt feedback or usage metadata.
    fn model_response(parts: Vec<Part>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts,
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    /// Returns the tool use carried by `event`, panicking for any other event.
    #[track_caller]
    fn expect_tool_use(
        event: &Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
    ) -> &language_model::LanguageModelToolUse {
        match event {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => tool_use,
            _ => panic!("Expected ToolUse event"),
        }
    }

    /// Builds a completed `test_function` tool use with the given signature.
    fn tool_use_with_signature(
        thought_signature: Option<&str>,
    ) -> language_model::LanguageModelToolUse {
        language_model::LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_function".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: thought_signature.map(String::from),
        }
    }

    /// Converts a request holding one assistant message with `tool_use` as
    /// its only content.
    fn google_request_for(
        tool_use: language_model::LanguageModelToolUse,
    ) -> google_ai::GenerateContentRequest {
        super::into_google(
            LanguageModelRequest {
                messages: vec![language_model::LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::ToolUse(tool_use)],
                    cache: false,
                    reasoning_details: None,
                }],
                ..Default::default()
            },
            "gemini-2.5-flash".to_string(),
            GoogleModelMode::Default,
        )
    }

    /// Returns the first part of the first content entry, which must be a
    /// function call.
    #[track_caller]
    fn expect_function_call_part(request: &google_ai::GenerateContentRequest) -> &FunctionCallPart {
        match &request.contents[0].parts[0] {
            Part::FunctionCallPart(fc_part) => fc_part,
            _ => panic!("Expected FunctionCallPart"),
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();
        let response = model_response(vec![function_call(
            "test_function",
            json!({"arg": "value"}),
            Some("test_signature_123"),
        )]);

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 2); // ToolUse event + Stop event
        let tool_use = expect_tool_use(&events[0]);
        assert_eq!(tool_use.name.as_ref(), "test_function");
        assert_eq!(
            tool_use.thought_signature.as_deref(),
            Some("test_signature_123")
        );
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();
        let response = model_response(vec![function_call(
            "test_function",
            json!({"arg": "value"}),
            None,
        )]);

        let events = mapper.map_event(response);

        assert_eq!(expect_tool_use(&events[0]).thought_signature, None);
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();
        let response = model_response(vec![function_call(
            "test_function",
            json!({"arg": "value"}),
            Some(""),
        )]);

        let events = mapper.map_event(response);

        assert_eq!(expect_tool_use(&events[0]).thought_signature, None);
    }

    #[test]
    fn test_parallel_function_calls_preserve_signatures() {
        let mut mapper = GoogleEventMapper::new();
        let response = model_response(vec![
            function_call("function_1", json!({"arg": "value1"}), Some("signature_1")),
            function_call("function_2", json!({"arg": "value2"}), None),
        ]);

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event

        let first = expect_tool_use(&events[0]);
        assert_eq!(first.name.as_ref(), "function_1");
        assert_eq!(first.thought_signature.as_deref(), Some("signature_1"));

        let second = expect_tool_use(&events[1]);
        assert_eq!(second.name.as_ref(), "function_2");
        assert_eq!(second.thought_signature, None);
    }

    #[test]
    fn test_tool_use_with_signature_converts_to_function_call_part() {
        let request = google_request_for(tool_use_with_signature(Some("test_signature_456")));

        assert_eq!(request.contents[0].parts.len(), 1);
        let fc_part = expect_function_call_part(&request);
        assert_eq!(fc_part.function_call.name, "test_function");
        assert_eq!(
            fc_part.thought_signature.as_deref(),
            Some("test_signature_456")
        );
    }

    #[test]
    fn test_tool_use_without_signature_omits_field() {
        let request = google_request_for(tool_use_with_signature(None));

        assert_eq!(request.contents[0].parts.len(), 1);
        assert_eq!(expect_function_call_part(&request).thought_signature, None);
    }

    #[test]
    fn test_empty_signature_in_tool_use_normalized_to_none() {
        let request = google_request_for(tool_use_with_signature(Some("")));

        assert_eq!(expect_function_call_part(&request).thought_signature, None);
    }

    #[test]
    fn test_round_trip_preserves_signature() {
        let mut mapper = GoogleEventMapper::new();

        // Simulate receiving a response from Google with a signature
        let response = model_response(vec![function_call(
            "test_function",
            json!({"arg": "value"}),
            Some("round_trip_sig"),
        )]);
        let events = mapper.map_event(response);
        let tool_use = expect_tool_use(&events[0]).clone();

        // Convert back to Google format
        let request = google_request_for(tool_use);

        // Verify signature is preserved
        assert_eq!(
            expect_function_call_part(&request)
                .thought_signature
                .as_deref(),
            Some("round_trip_sig")
        );
    }

    #[test]
    fn test_mixed_text_and_function_call_with_signature() {
        let mut mapper = GoogleEventMapper::new();
        let response = model_response(vec![
            Part::TextPart(TextPart {
                text: "I'll help with that.".to_string(),
            }),
            function_call("helper_function", json!({"query": "help"}), Some("mixed_sig")),
        ]);

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event

        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
            assert_eq!(text, "I'll help with that.");
        } else {
            panic!("Expected Text event");
        }

        let tool_use = expect_tool_use(&events[1]);
        assert_eq!(tool_use.name.as_ref(), "helper_function");
        assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
    }

    #[test]
    fn test_special_characters_in_signature_preserved() {
        let mut mapper = GoogleEventMapper::new();
        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
        let response = model_response(vec![function_call(
            "test_function",
            json!({"arg": "value"}),
            Some(signature_with_special_chars.as_str()),
        )]);

        let events = mapper.map_event(response);

        assert_eq!(
            expect_tool_use(&events[0]).thought_signature.as_deref(),
            Some(signature_with_special_chars.as_str())
        );
    }
}

View file

@ -28,7 +28,7 @@ use ui::{
use ui_input::InputField;
use crate::AllLanguageModelSettings;
use crate::provider::util::parse_tool_arguments;
use language_model::util::parse_tool_arguments;
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";

View file

@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

File diff suppressed because it is too large Load diff

View file

@ -402,7 +402,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
self.model.capabilities.parallel_tool_calls,
self.model.capabilities.prompt_cache_key,
self.max_output_tokens(),
self.model.reasoning_effort.clone(),
self.model.reasoning_effort,
);
let completions = self.stream_completion(request, cx);
async move {
@ -417,7 +417,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
self.model.capabilities.parallel_tool_calls,
self.model.capabilities.prompt_cache_key,
self.max_output_tokens(),
self.model.reasoning_effort.clone(),
self.model.reasoning_effort,
);
let completions = self.stream_response(request, cx);
async move {

View file

@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");

View file

@ -9,7 +9,7 @@ use language_model::{
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
Role, env_var,
env_var,
};
use open_ai::ResponseStreamEvent;
pub use settings::XaiAvailableModel as AvailableModel;
@ -19,7 +19,8 @@ use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use x_ai::{Model, XAI_API_URL};
use x_ai::XAI_API_URL;
pub use x_ai::completion::count_xai_tokens;
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
@ -320,7 +321,9 @@ impl LanguageModel for XAiLanguageModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
count_xai_tokens(request, self.model.clone(), cx)
let model = self.model.clone();
cx.background_spawn(async move { count_xai_tokens(request, model) })
.boxed()
}
fn stream_completion(
@ -354,37 +357,6 @@ impl LanguageModel for XAiLanguageModel {
}
}
/// Estimates the token count of `request` for an xAI `model` on a background
/// task, using a `tiktoken_rs` tokenizer chosen by the model's context size.
pub fn count_xai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    let count_tokens = async move {
        let mut messages = Vec::with_capacity(request.messages.len());
        for message in request.messages {
            let role = match message.role {
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::System => "system",
            };
            messages.push(tiktoken_rs::ChatCompletionRequestMessage {
                role: role.into(),
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            });
        }

        // Models with a context of at least 100k tokens are counted with the
        // "gpt-4o" tokenizer, smaller ones with "gpt-4".
        let uses_large_context = model.max_token_count() >= 100_000;
        let model_name = if uses_large_context { "gpt-4o" } else { "gpt-4" };

        tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
    };

    cx.background_spawn(count_tokens).boxed()
}
struct ConfigurationView {
api_key_editor: Entity<InputField>,
state: Entity<State>,

View file

@ -0,0 +1,33 @@
[package]
name = "language_models_cloud"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/language_models_cloud.rs"
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
cloud_llm_client.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
language_model.workspace = true
open_ai = { workspace = true, features = ["schemars"] }
schemars.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
thiserror.workspace = true
x_ai = { workspace = true, features = ["schemars"] }
[dev-dependencies]
language_model = { workspace = true, features = ["test-support"] }

View file

@ -0,0 +1 @@
../../LICENSE-GPL

File diff suppressed because it is too large Load diff

View file

@ -17,13 +17,18 @@ schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
collections.workspace = true
futures.workspace = true
http_client.workspace = true
language_model_core.workspace = true
rand.workspace = true
schemars = { workspace = true, optional = true }
log.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
strum.workspace = true
thiserror.workspace = true
tiktoken-rs.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,5 @@
pub mod batches;
pub mod completion;
pub mod responses;
use anyhow::{Context as _, Result, anyhow};
@ -7,9 +8,9 @@ use http_client::{
AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
http::{HeaderMap, HeaderValue},
};
pub use language_model_core::ReasoningEffort;
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use settings::OpenAiReasoningEffort as ReasoningEffort;
use std::{convert::TryFrom, future::Future};
use strum::EnumIter;
use thiserror::Error;
@ -717,3 +718,26 @@ pub fn embed<'a>(
Ok(response)
}
}
// -- Conversions to `language_model_core` types --
/// Converts a `RequestError` into a `LanguageModelCompletionError`.
///
/// HTTP response errors are routed through `from_http_status`; the
/// `Retry-After` header is forwarded as a duration only when it is valid text
/// that parses as a whole number of seconds. Any other error is wrapped
/// unchanged in the `Other` variant.
impl From<RequestError> for language_model_core::LanguageModelCompletionError {
    fn from(error: RequestError) -> Self {
        match error {
            RequestError::Other(source) => Self::Other(source),
            RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                let retry_after = headers
                    .get(http_client::http::header::RETRY_AFTER)
                    .and_then(|value| value.to_str().ok())
                    .and_then(|text| text.parse::<u64>().ok())
                    .map(std::time::Duration::from_secs);
                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
        }
    }
}

View file

@ -19,6 +19,7 @@ schemars = ["dep:schemars"]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
language_model_core.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true

View file

@ -744,3 +744,71 @@ impl ApiErrorCode {
}
}
}
// -- Conversions to `language_model_core` types --
/// Converts an `OpenRouterError` into a `LanguageModelCompletionError`.
///
/// Transport and (de)serialization failures map onto the variant of the same
/// meaning, tagged with the "OpenRouter" provider name. API-level errors are
/// delegated to the `From<ApiError>` conversion.
impl From<OpenRouterError> for language_model_core::LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = language_model_core::LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::ApiError(api_error) => Self::from(api_error),
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::SerializeRequest(source) => Self::SerializeRequest {
                provider,
                error: source,
            },
            OpenRouterError::BuildRequestBody(source) => Self::BuildRequestBody {
                provider,
                error: source,
            },
            OpenRouterError::HttpSend(source) => Self::HttpSend {
                provider,
                error: source,
            },
            OpenRouterError::DeserializeResponse(source) => Self::DeserializeResponse {
                provider,
                error: source,
            },
            OpenRouterError::ReadResponse(source) => Self::ApiReadResponseError {
                provider,
                error: source,
            },
        }
    }
}
impl From<ApiError> for language_model_core::LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
use ApiErrorCode::*;
let provider = language_model_core::LanguageModelProviderName::new("OpenRouter");
match error.code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PaymentRequiredError => Self::AuthenticationError {
provider,
message: format!("Payment required: {}", error.message),
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
RequestTimedOut => Self::HttpResponseError {
provider,
status_code: http_client::StatusCode::REQUEST_TIMEOUT,
message: error.message,
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
}
}
}

View file

@ -412,7 +412,7 @@ impl PrettierStore {
prettier_store
.update(cx, |prettier_store, cx| {
let name = if is_default {
LanguageServerName("prettier (default)".to_string().into())
LanguageServerName("prettier (default)".into())
} else {
let worktree_path = worktree_id
.and_then(|id| {

View file

@ -19,6 +19,7 @@ anyhow.workspace = true
collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
language_model_core.workspace = true
log.workspace = true
schemars.workspace = true
serde.workspace = true

View file

@ -1,8 +1,8 @@
use crate::merge_from::MergeFrom;
use collections::HashMap;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings_macros::{MergeFrom, with_fallible_options};
use strum::EnumString;
use std::sync::Arc;
@ -237,15 +237,12 @@ pub struct OpenAiAvailableModel {
pub capabilities: OpenAiModelCapabilities,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, EnumString, JsonSchema, MergeFrom)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum OpenAiReasoningEffort {
Minimal,
Low,
Medium,
High,
XHigh,
pub use language_model_core::ReasoningEffort as OpenAiReasoningEffort;
/// Manual `MergeFrom` implementation for the re-exported reasoning-effort type.
impl MergeFrom for OpenAiReasoningEffort {
/// Overwrites `self` with a copy of `other`; the previous value is discarded.
fn merge_from(&mut self, other: &Self) {
*self = *other;
}
}
#[with_fallible_options]
@ -479,15 +476,10 @@ pub struct LanguageModelCacheConfiguration {
pub min_total_token: u64,
}
#[derive(
Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, MergeFrom,
)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
#[default]
Default,
Thinking {
/// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
budget_tokens: Option<u32>,
},
pub use language_model_core::ModelMode;
/// Manual `MergeFrom` implementation for the re-exported `ModelMode` type.
impl MergeFrom for ModelMode {
/// Overwrites `self` with a copy of `other`; the previous value is discarded.
fn merge_from(&mut self, other: &Self) {
*self = *other;
}
}

View file

@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
futures.workspace = true

View file

@ -2,12 +2,12 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
use cloud_api_client::LlmApiToken;
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
use language_model::LlmApiToken;
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {

View file

@ -17,6 +17,8 @@ schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
language_model_core.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
strum.workspace = true
tiktoken-rs.workspace = true

View file

@ -0,0 +1,30 @@
use anyhow::Result;
use language_model_core::{LanguageModelRequest, Role};
use crate::Model;
/// Counts the tokens in `request` for the given xAI `model` via tiktoken.
///
/// The call is synchronous; callers that need to stay off the current thread
/// should run it on a background task.
///
/// The tokenizer is selected by name: `"gpt-4o"` when the model's max token
/// count is at least 100_000, `"gpt-4"` otherwise.
pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
    let tokenizer_name = if model.max_token_count() < 100_000 {
        "gpt-4"
    } else {
        "gpt-4o"
    };

    let mut chat_messages = Vec::new();
    for message in request.messages {
        let role = match message.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        chat_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
            role: role.into(),
            content: Some(message.string_contents()),
            name: None,
            function_call: None,
        });
    }

    let token_count = tiktoken_rs::num_tokens_from_messages(tokenizer_name, &chat_messages)?;
    Ok(token_count as u64)
}

View file

@ -1,3 +1,5 @@
pub mod completion;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use strum::EnumIter;