mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-31 19:05:00 +07:00
language_models: Refactor deps and extract cloud (#53270)
- `language_model` no longer depends on provider-specific crates such as `anthropic` and `open_ai` (inverted dependency) - `language_model_core` was extracted from `language_model` which contains the types for the provider-specific crates to convert to/from. - `gpui::SharedString` has been extracted into its own crate (still exposed by `gpui`), so `language_model_core` and provider API crates don't have to depend on `gpui`. - Removes some unnecessary `&'static str` | `SharedString` -> `String` -> `SharedString` conversions across the codebase. - Extracts the core logic of the cloud `LanguageModelProvider` into its own crate with simpler dependencies. Release Notes: - N/A --------- Co-authored-by: John Tur <john-tur@outlook.com>
This commit is contained in:
parent
a856093cca
commit
98c17ca160
95 changed files with 5895 additions and 5995 deletions
89
Cargo.lock
generated
89
Cargo.lock
generated
|
|
@ -629,13 +629,17 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"collections",
|
||||
"futures 0.3.32",
|
||||
"http_client",
|
||||
"language_model_core",
|
||||
"log",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum 0.27.2",
|
||||
"thiserror 2.0.17",
|
||||
"tiktoken-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2903,7 +2907,6 @@ dependencies = [
|
|||
"http_client",
|
||||
"http_client_tls",
|
||||
"httparse",
|
||||
"language_model",
|
||||
"log",
|
||||
"objc2-foundation",
|
||||
"parking_lot",
|
||||
|
|
@ -2959,6 +2962,7 @@ dependencies = [
|
|||
"http_client",
|
||||
"parking_lot",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"thiserror 2.0.17",
|
||||
"yawc",
|
||||
]
|
||||
|
|
@ -5162,6 +5166,7 @@ dependencies = [
|
|||
"buffer_diff",
|
||||
"client",
|
||||
"clock",
|
||||
"cloud_api_client",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
|
|
@ -5641,7 +5646,7 @@ dependencies = [
|
|||
name = "env_var"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"gpui",
|
||||
"gpui_shared_string",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -7468,11 +7473,13 @@ dependencies = [
|
|||
"anyhow",
|
||||
"futures 0.3.32",
|
||||
"http_client",
|
||||
"language_model_core",
|
||||
"log",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"strum 0.27.2",
|
||||
"tiktoken-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -7541,6 +7548,7 @@ dependencies = [
|
|||
"getrandom 0.3.4",
|
||||
"gpui_macros",
|
||||
"gpui_platform",
|
||||
"gpui_shared_string",
|
||||
"gpui_util",
|
||||
"gpui_web",
|
||||
"http_client",
|
||||
|
|
@ -7710,6 +7718,16 @@ dependencies = [
|
|||
"gpui_windows",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gpui_shared_string"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"derive_more",
|
||||
"gpui_util",
|
||||
"schemars",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gpui_tokio"
|
||||
version = "0.1.0"
|
||||
|
|
@ -9358,7 +9376,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"gpui",
|
||||
"gpui_shared_string",
|
||||
"log",
|
||||
"lsp",
|
||||
"parking_lot",
|
||||
|
|
@ -9397,12 +9415,8 @@ dependencies = [
|
|||
name = "language_model"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anthropic",
|
||||
"anyhow",
|
||||
"base64 0.22.1",
|
||||
"cloud_api_client",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
"env_var",
|
||||
|
|
@ -9411,16 +9425,31 @@ dependencies = [
|
|||
"http_client",
|
||||
"icons",
|
||||
"image",
|
||||
"language_model_core",
|
||||
"log",
|
||||
"open_ai",
|
||||
"open_router",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "language_model_core"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"cloud_llm_client",
|
||||
"futures 0.3.32",
|
||||
"gpui_shared_string",
|
||||
"http_client",
|
||||
"partial-json-fixer",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"strum 0.27.2",
|
||||
"thiserror 2.0.17",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -9436,8 +9465,8 @@ dependencies = [
|
|||
"base64 0.22.1",
|
||||
"bedrock",
|
||||
"client",
|
||||
"cloud_api_client",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"component",
|
||||
"convert_case 0.8.0",
|
||||
|
|
@ -9456,6 +9485,7 @@ dependencies = [
|
|||
"http_client",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models_cloud",
|
||||
"lmstudio",
|
||||
"log",
|
||||
"menu",
|
||||
|
|
@ -9464,17 +9494,14 @@ dependencies = [
|
|||
"open_ai",
|
||||
"open_router",
|
||||
"opencode",
|
||||
"partial-json-fixer",
|
||||
"pretty_assertions",
|
||||
"release_channel",
|
||||
"schemars",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"strum 0.27.2",
|
||||
"thiserror 2.0.17",
|
||||
"tiktoken-rs",
|
||||
"tokio",
|
||||
"ui",
|
||||
|
|
@ -9484,6 +9511,28 @@ dependencies = [
|
|||
"x_ai",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "language_models_cloud"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anthropic",
|
||||
"anyhow",
|
||||
"cloud_llm_client",
|
||||
"futures 0.3.32",
|
||||
"google_ai",
|
||||
"gpui",
|
||||
"http_client",
|
||||
"language_model",
|
||||
"open_ai",
|
||||
"schemars",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"thiserror 2.0.17",
|
||||
"x_ai",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "language_onboarding"
|
||||
version = "0.1.0"
|
||||
|
|
@ -11631,16 +11680,19 @@ name = "open_ai"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"futures 0.3.32",
|
||||
"http_client",
|
||||
"language_model_core",
|
||||
"log",
|
||||
"pretty_assertions",
|
||||
"rand 0.9.2",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"strum 0.27.2",
|
||||
"thiserror 2.0.17",
|
||||
"tiktoken-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -11672,6 +11724,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"futures 0.3.32",
|
||||
"http_client",
|
||||
"language_model_core",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
@ -15801,6 +15854,7 @@ dependencies = [
|
|||
"collections",
|
||||
"derive_more",
|
||||
"gpui",
|
||||
"language_model_core",
|
||||
"log",
|
||||
"schemars",
|
||||
"serde",
|
||||
|
|
@ -20180,6 +20234,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"cloud_api_client",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"futures 0.3.32",
|
||||
|
|
@ -21783,9 +21838,11 @@ name = "x_ai"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"language_model_core",
|
||||
"schemars",
|
||||
"serde",
|
||||
"strum 0.27.2",
|
||||
"tiktoken-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ members = [
|
|||
"crates/google_ai",
|
||||
"crates/grammars",
|
||||
"crates/gpui",
|
||||
"crates/gpui_shared_string",
|
||||
"crates/gpui_linux",
|
||||
"crates/gpui_macos",
|
||||
"crates/gpui_macros",
|
||||
|
|
@ -110,7 +111,9 @@ members = [
|
|||
"crates/language_core",
|
||||
"crates/language_extension",
|
||||
"crates/language_model",
|
||||
"crates/language_model_core",
|
||||
"crates/language_models",
|
||||
"crates/language_models_cloud",
|
||||
"crates/language_onboarding",
|
||||
"crates/language_selector",
|
||||
"crates/language_tools",
|
||||
|
|
@ -335,6 +338,7 @@ go_to_line = { path = "crates/go_to_line" }
|
|||
google_ai = { path = "crates/google_ai" }
|
||||
grammars = { path = "crates/grammars" }
|
||||
gpui = { path = "crates/gpui", default-features = false }
|
||||
gpui_shared_string = { path = "crates/gpui_shared_string" }
|
||||
gpui_linux = { path = "crates/gpui_linux", default-features = false }
|
||||
gpui_macos = { path = "crates/gpui_macos", default-features = false }
|
||||
gpui_macros = { path = "crates/gpui_macros" }
|
||||
|
|
@ -361,7 +365,9 @@ language = { path = "crates/language" }
|
|||
language_core = { path = "crates/language_core" }
|
||||
language_extension = { path = "crates/language_extension" }
|
||||
language_model = { path = "crates/language_model" }
|
||||
language_model_core = { path = "crates/language_model_core" }
|
||||
language_models = { path = "crates/language_models" }
|
||||
language_models_cloud = { path = "crates/language_models_cloud" }
|
||||
language_onboarding = { path = "crates/language_onboarding" }
|
||||
language_selector = { path = "crates/language_selector" }
|
||||
language_tools = { path = "crates/language_tools" }
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use futures::FutureExt as _;
|
|||
use gpui::{App, Entity, SharedString, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::Point;
|
||||
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
|
||||
use language_model::{LanguageModelImage, LanguageModelImageExt, LanguageModelToolResultContent};
|
||||
use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
|
|||
|
|
@ -325,7 +325,7 @@ impl AcpConnection {
|
|||
// Use the one the agent provides if we have one
|
||||
.map(|info| info.name.into())
|
||||
// Otherwise, just use the name
|
||||
.unwrap_or_else(|| agent_id.0.to_string().into());
|
||||
.unwrap_or_else(|| agent_id.0.clone());
|
||||
|
||||
let session_list = if response
|
||||
.agent_capabilities
|
||||
|
|
|
|||
|
|
@ -382,7 +382,7 @@ impl AgentRegistryPage {
|
|||
self.install_button(agent, install_status, supports_current_platform, cx);
|
||||
|
||||
let repository_button = agent.repository().map(|repository| {
|
||||
let repository_for_tooltip: SharedString = repository.to_string().into();
|
||||
let repository_for_tooltip = repository.clone();
|
||||
let repository_for_click = repository.to_string();
|
||||
|
||||
IconButton::new(
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ use gpui::{
|
|||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use itertools::Either;
|
||||
use language::Buffer;
|
||||
use language_model::LanguageModelImage;
|
||||
use language_model::{LanguageModelImage, LanguageModelImageExt};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use postage::stream::Stream as _;
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
|
|
|
|||
|
|
@ -18,12 +18,16 @@ path = "src/anthropic.rs"
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model_core.workspace = true
|
||||
log.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ use strum::{EnumIter, EnumString};
|
|||
use thiserror::Error;
|
||||
|
||||
pub mod batches;
|
||||
pub mod completion;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
|
|
@ -1026,6 +1027,89 @@ pub async fn count_tokens(
|
|||
}
|
||||
}
|
||||
|
||||
// -- Conversions from/to `language_model_core` types --
|
||||
|
||||
impl From<language_model_core::Speed> for Speed {
|
||||
fn from(speed: language_model_core::Speed) -> Self {
|
||||
match speed {
|
||||
language_model_core::Speed::Standard => Speed::Standard,
|
||||
language_model_core::Speed::Fast => Speed::Fast,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AnthropicError> for language_model_core::LanguageModelCompletionError {
|
||||
fn from(error: AnthropicError) -> Self {
|
||||
let provider = language_model_core::ANTHROPIC_PROVIDER_NAME;
|
||||
match error {
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
AnthropicError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
AnthropicError::HttpResponseError {
|
||||
status_code,
|
||||
message,
|
||||
} => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ApiError> for language_model_core::LanguageModelCompletionError {
|
||||
fn from(error: ApiError) -> Self {
|
||||
use ApiErrorCode::*;
|
||||
let provider = language_model_core::ANTHROPIC_PROVIDER_NAME;
|
||||
match error.code() {
|
||||
Some(code) => match code {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
NotFoundError => Self::ApiEndpointNotFound { provider },
|
||||
RequestTooLarge => Self::PromptTooLarge {
|
||||
tokens: language_model_core::parse_prompt_too_long(&error.message),
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
},
|
||||
None => Self::Other(error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_match_window_exceeded() {
|
||||
let error = ApiError {
|
||||
|
|
|
|||
765
crates/anthropic/src/completion.rs
Normal file
765
crates/anthropic/src/completion.rs
Normal file
|
|
@ -0,0 +1,765 @@
|
|||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use futures::{Stream, StreamExt};
|
||||
use language_model_core::{
|
||||
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||
Role, StopReason, TokenUsage,
|
||||
util::{fix_streamed_json, parse_tool_arguments},
|
||||
};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::{
|
||||
AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta,
|
||||
CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent,
|
||||
StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage,
|
||||
};
|
||||
|
||||
fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
|
||||
text.trim_end().to_string()
|
||||
} else {
|
||||
text
|
||||
};
|
||||
if !text.is_empty() {
|
||||
Some(RequestContent::Text {
|
||||
text,
|
||||
cache_control: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Thinking {
|
||||
text: thinking,
|
||||
signature,
|
||||
} => {
|
||||
if let Some(signature) = signature
|
||||
&& !thinking.is_empty()
|
||||
{
|
||||
Some(RequestContent::Thinking {
|
||||
thinking,
|
||||
signature,
|
||||
cache_control: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::RedactedThinking(data) => {
|
||||
if !data.is_empty() {
|
||||
Some(RequestContent::RedactedThinking { data })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Image(image) => Some(RequestContent::Image {
|
||||
source: ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
cache_control: None,
|
||||
}),
|
||||
MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse {
|
||||
id: tool_use.id.to_string(),
|
||||
name: tool_use.name.to_string(),
|
||||
input: tool_use.input,
|
||||
cache_control: None,
|
||||
}),
|
||||
MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult {
|
||||
tool_use_id: tool_result.tool_use_id.to_string(),
|
||||
is_error: tool_result.is_error,
|
||||
content: match tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
ToolResultContent::Plain(text.to_string())
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
ToolResultContent::Multipart(vec![ToolResultPart::Image {
|
||||
source: ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
}])
|
||||
}
|
||||
},
|
||||
cache_control: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
|
||||
pub fn into_anthropic_count_tokens_request(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
mode: AnthropicModelMode,
|
||||
) -> CountTokensRequest {
|
||||
let mut new_messages: Vec<Message> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in request.messages {
|
||||
if message.contents_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let anthropic_message_content: Vec<RequestContent> = message
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(to_anthropic_content)
|
||||
.collect();
|
||||
let anthropic_role = match message.role {
|
||||
Role::User => crate::Role::User,
|
||||
Role::Assistant => crate::Role::Assistant,
|
||||
Role::System => unreachable!("System role should never occur here"),
|
||||
};
|
||||
if anthropic_message_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(last_message) = new_messages.last_mut()
|
||||
&& last_message.role == anthropic_role
|
||||
{
|
||||
last_message.content.extend(anthropic_message_content);
|
||||
continue;
|
||||
}
|
||||
|
||||
new_messages.push(Message {
|
||||
role: anthropic_role,
|
||||
content: anthropic_message_content,
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.string_contents());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CountTokensRequest {
|
||||
model,
|
||||
messages: new_messages,
|
||||
system: if system_message.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(StringOrContents::String(system_message))
|
||||
},
|
||||
thinking: if request.thinking_allowed {
|
||||
match mode {
|
||||
AnthropicModelMode::Thinking { budget_tokens } => {
|
||||
Some(Thinking::Enabled { budget_tokens })
|
||||
}
|
||||
AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
|
||||
AnthropicModelMode::Default => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| Tool {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: tool.input_schema,
|
||||
eager_input_streaming: tool.use_input_streaming,
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: request.tool_choice.map(|choice| match choice {
|
||||
LanguageModelToolChoice::Auto => ToolChoice::Auto,
|
||||
LanguageModelToolChoice::Any => ToolChoice::Any,
|
||||
LanguageModelToolChoice::None => ToolChoice::None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
|
||||
/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
|
||||
pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
|
||||
let messages = request.messages;
|
||||
let mut tokens_from_images = 0;
|
||||
let mut string_messages = Vec::with_capacity(messages.len());
|
||||
|
||||
for message in messages {
|
||||
let mut string_contents = String::new();
|
||||
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
string_contents.push_str(&text);
|
||||
}
|
||||
MessageContent::Thinking { .. } => {
|
||||
// Thinking blocks are not included in the input token count.
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {
|
||||
// Thinking blocks are not included in the input token count.
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
tokens_from_images += image.estimate_tokens();
|
||||
}
|
||||
MessageContent::ToolUse(_tool_use) => {
|
||||
// TODO: Estimate token usage from tool uses.
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
string_contents.push_str(text);
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
tokens_from_images += image.estimate_tokens();
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if !string_contents.is_empty() {
|
||||
string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(string_contents),
|
||||
name: None,
|
||||
function_call: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
|
||||
.map(|tokens| (tokens + tokens_from_images) as u64)
|
||||
}
|
||||
|
||||
pub fn into_anthropic(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
default_temperature: f32,
|
||||
max_output_tokens: u64,
|
||||
mode: AnthropicModelMode,
|
||||
) -> crate::Request {
|
||||
let mut new_messages: Vec<Message> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in request.messages {
|
||||
if message.contents_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let mut anthropic_message_content: Vec<RequestContent> = message
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(to_anthropic_content)
|
||||
.collect();
|
||||
let anthropic_role = match message.role {
|
||||
Role::User => crate::Role::User,
|
||||
Role::Assistant => crate::Role::Assistant,
|
||||
Role::System => unreachable!("System role should never occur here"),
|
||||
};
|
||||
if anthropic_message_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(last_message) = new_messages.last_mut()
|
||||
&& last_message.role == anthropic_role
|
||||
{
|
||||
last_message.content.extend(anthropic_message_content);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Mark the last segment of the message as cached
|
||||
if message.cache {
|
||||
let cache_control_value = Some(CacheControl {
|
||||
cache_type: CacheControlType::Ephemeral,
|
||||
});
|
||||
for message_content in anthropic_message_content.iter_mut().rev() {
|
||||
match message_content {
|
||||
RequestContent::RedactedThinking { .. } => {
|
||||
// Caching is not possible, fallback to next message
|
||||
}
|
||||
RequestContent::Text { cache_control, .. }
|
||||
| RequestContent::Thinking { cache_control, .. }
|
||||
| RequestContent::Image { cache_control, .. }
|
||||
| RequestContent::ToolUse { cache_control, .. }
|
||||
| RequestContent::ToolResult { cache_control, .. } => {
|
||||
*cache_control = cache_control_value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_messages.push(Message {
|
||||
role: anthropic_role,
|
||||
content: anthropic_message_content,
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.string_contents());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crate::Request {
|
||||
model,
|
||||
messages: new_messages,
|
||||
max_tokens: max_output_tokens,
|
||||
system: if system_message.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(StringOrContents::String(system_message))
|
||||
},
|
||||
thinking: if request.thinking_allowed {
|
||||
match mode {
|
||||
AnthropicModelMode::Thinking { budget_tokens } => {
|
||||
Some(Thinking::Enabled { budget_tokens })
|
||||
}
|
||||
AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
|
||||
AnthropicModelMode::Default => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| Tool {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: tool.input_schema,
|
||||
eager_input_streaming: tool.use_input_streaming,
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: request.tool_choice.map(|choice| match choice {
|
||||
LanguageModelToolChoice::Auto => ToolChoice::Auto,
|
||||
LanguageModelToolChoice::Any => ToolChoice::Any,
|
||||
LanguageModelToolChoice::None => ToolChoice::None,
|
||||
}),
|
||||
metadata: None,
|
||||
output_config: if request.thinking_allowed
|
||||
&& matches!(mode, AnthropicModelMode::AdaptiveThinking)
|
||||
{
|
||||
request.thinking_effort.as_deref().and_then(|effort| {
|
||||
let effort = match effort {
|
||||
"low" => Some(crate::Effort::Low),
|
||||
"medium" => Some(crate::Effort::Medium),
|
||||
"high" => Some(crate::Effort::High),
|
||||
"max" => Some(crate::Effort::Max),
|
||||
_ => None,
|
||||
};
|
||||
effort.map(|effort| crate::OutputConfig {
|
||||
effort: Some(effort),
|
||||
})
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
stop_sequences: Vec::new(),
|
||||
speed: request.speed.map(Into::into),
|
||||
temperature: request.temperature.or(Some(default_temperature)),
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stateful translator from Anthropic streaming `Event`s to
/// `LanguageModelCompletionEvent`s.
pub struct AnthropicEventMapper {
|
||||
/// Tool-use blocks whose JSON input is still streaming, keyed by content-block index.
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
||||
/// Most recent usage counts, merged from `MessageStart` and `MessageDelta` events.
usage: Usage,
|
||||
/// Stop reason emitted on `MessageStop`; stays `EndTurn` until a `MessageDelta` sets one.
stop_reason: StopReason,
|
||||
}
|
||||
|
||||
impl AnthropicEventMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
usage: Usage::default(),
|
||||
stop_reason: StopReason::EndTurn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(error.into())],
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: Event,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
match event {
|
||||
Event::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => match content_block {
|
||||
ResponseContent::Text { text } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Text(text))]
|
||||
}
|
||||
ResponseContent::Thinking { thinking } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thinking,
|
||||
signature: None,
|
||||
})]
|
||||
}
|
||||
ResponseContent::RedactedThinking { data } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
|
||||
}
|
||||
ResponseContent::ToolUse { id, name, .. } => {
|
||||
self.tool_uses_by_index.insert(
|
||||
index,
|
||||
RawToolUse {
|
||||
id,
|
||||
name,
|
||||
input_json: String::new(),
|
||||
},
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
},
|
||||
Event::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Text(text))]
|
||||
}
|
||||
ContentDelta::ThinkingDelta { thinking } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thinking,
|
||||
signature: None,
|
||||
})]
|
||||
}
|
||||
ContentDelta::SignatureDelta { signature } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".to_string(),
|
||||
signature: Some(signature),
|
||||
})]
|
||||
}
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
|
||||
tool_use.input_json.push_str(&partial_json);
|
||||
|
||||
// Try to convert invalid (incomplete) JSON into
|
||||
// valid JSON that serde can accept, e.g. by closing
|
||||
// unclosed delimiters. This way, we can update the
|
||||
// UI with whatever has been streamed back so far.
|
||||
if let Ok(input) =
|
||||
serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
|
||||
{
|
||||
return vec![Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.clone().into(),
|
||||
name: tool_use.name.clone().into(),
|
||||
is_input_complete: false,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
input,
|
||||
thought_signature: None,
|
||||
},
|
||||
))];
|
||||
}
|
||||
}
|
||||
vec![]
|
||||
}
|
||||
},
|
||||
Event::ContentBlockStop { index } => {
|
||||
if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
|
||||
let input_json = tool_use.input_json.trim();
|
||||
let event_result = match parse_tool_arguments(input_json) {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(json_parse_err) => {
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_use.id.into(),
|
||||
tool_name: tool_use.name.into(),
|
||||
raw_input: input_json.into(),
|
||||
json_parse_error: json_parse_err.to_string(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
vec![event_result]
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
Event::MessageStart { message } => {
|
||||
update_usage(&mut self.usage, &message.usage);
|
||||
vec![
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
|
||||
&self.usage,
|
||||
))),
|
||||
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||
message_id: message.id,
|
||||
}),
|
||||
]
|
||||
}
|
||||
Event::MessageDelta { delta, usage } => {
|
||||
update_usage(&mut self.usage, &usage);
|
||||
if let Some(stop_reason) = delta.stop_reason.as_deref() {
|
||||
self.stop_reason = match stop_reason {
|
||||
"end_turn" => StopReason::EndTurn,
|
||||
"max_tokens" => StopReason::MaxTokens,
|
||||
"tool_use" => StopReason::ToolUse,
|
||||
"refusal" => StopReason::Refusal,
|
||||
_ => {
|
||||
log::error!("Unexpected anthropic stop_reason: {stop_reason}");
|
||||
StopReason::EndTurn
|
||||
}
|
||||
};
|
||||
}
|
||||
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||
convert_usage(&self.usage),
|
||||
))]
|
||||
}
|
||||
Event::MessageStop => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
|
||||
}
|
||||
Event::Error { error } => {
|
||||
vec![Err(error.into())]
|
||||
}
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool call whose arguments are still being collected.
///
/// Entries are kept in `tool_uses_by_index` and turned into a completion
/// event once they are removed from that map.
struct RawToolUse {
    /// Identifier of the tool call.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Raw JSON text of the tool arguments; trimmed and parsed with
    /// `parse_tool_arguments` when the entry is converted to an event.
    input_json: String,
}
|
||||
|
||||
/// Updates usage data by preferring counts from `new`.
|
||||
fn update_usage(usage: &mut Usage, new: &Usage) {
|
||||
if let Some(input_tokens) = new.input_tokens {
|
||||
usage.input_tokens = Some(input_tokens);
|
||||
}
|
||||
if let Some(output_tokens) = new.output_tokens {
|
||||
usage.output_tokens = Some(output_tokens);
|
||||
}
|
||||
if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
|
||||
usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
|
||||
}
|
||||
if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
|
||||
usage.cache_read_input_tokens = Some(cache_read_input_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_usage(usage: &Usage) -> TokenUsage {
|
||||
TokenUsage {
|
||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||
output_tokens: usage.output_tokens.unwrap_or(0),
|
||||
cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
|
||||
cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AnthropicModelMode;
    use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// A cacheable message must carry `cache_control` on its final content
    /// segment only; every earlier segment must be sent without it.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                ],
                cache: true,
                reasoning_details: None,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);
        let last_index = message.content.len() - 1;

        assert!(matches!(
            message.content[0],
            RequestContent::Text {
                cache_control: None,
                ..
            }
        ));
        // Check every image before the last one. The previous `1..3` range
        // stopped one short and never looked at `content[3]`.
        for i in 1..last_index {
            assert!(matches!(
                message.content[i],
                RequestContent::Image {
                    cache_control: None,
                    ..
                }
            ));
        }

        assert!(matches!(
            message.content[last_index],
            RequestContent::Image {
                cache_control: Some(CacheControl {
                    cache_type: CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }

    /// Builds a two-message conversation (a user "Hello" followed by an
    /// assistant message with `assistant_content`) and converts it with
    /// thinking mode enabled.
    fn request_with_assistant_content(assistant_content: Vec<MessageContent>) -> crate::Request {
        let mut request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("Hello".to_string())],
                cache: false,
                reasoning_details: None,
            }],
            thinking_effort: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            speed: None,
        };
        request.messages.push(LanguageModelRequestMessage {
            role: Role::Assistant,
            content: assistant_content,
            cache: false,
            reasoning_details: None,
        });
        into_anthropic(
            request,
            "claude-sonnet-4-5".to_string(),
            1.0,
            16000,
            AnthropicModelMode::Thinking {
                budget_tokens: Some(10000),
            },
        )
    }

    /// A thinking block without a signature is dropped, while the text that
    /// follows it is kept.
    #[test]
    fn test_unsigned_thinking_blocks_stripped() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Cancelled mid-think, no signature".to_string(),
                signature: None,
            },
            MessageContent::Text("Some response text".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should still exist");

        assert_eq!(
            assistant_message.content.len(),
            1,
            "Only the text content should remain; unsigned thinking block should be stripped"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Text { text, .. } if text == "Some response text"
        ));
    }

    /// A thinking block that carries a signature is forwarded unchanged.
    #[test]
    fn test_signed_thinking_blocks_preserved() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Completed thinking".to_string(),
                signature: Some("valid-signature".to_string()),
            },
            MessageContent::Text("Response".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should exist");

        assert_eq!(
            assistant_message.content.len(),
            2,
            "Both the signed thinking block and text should be preserved"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Thinking { thinking, signature, .. }
                if thinking == "Completed thinking" && signature == "valid-signature"
        ));
    }

    /// An assistant message left with no content after stripping is omitted.
    #[test]
    fn test_only_unsigned_thinking_block_omits_entire_message() {
        let result = request_with_assistant_content(vec![MessageContent::Thinking {
            text: "Cancelled before any text or signature".to_string(),
            signature: None,
        }]);

        let assistant_messages: Vec<_> = result
            .messages
            .iter()
            .filter(|m| m.role == crate::Role::Assistant)
            .collect();

        assert_eq!(
            assistant_messages.len(),
            0,
            "An assistant message whose only content was an unsigned thinking block \
             should be omitted entirely"
        );
    }
}
|
||||
|
|
@ -36,7 +36,6 @@ gpui_tokio.workspace = true
|
|||
http_client.workspace = true
|
||||
http_client_tls.workspace = true
|
||||
httparse = "1.10"
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ use async_tungstenite::tungstenite::{
|
|||
http::{HeaderValue, Request, StatusCode},
|
||||
};
|
||||
use clock::SystemClock;
|
||||
use cloud_api_client::LlmApiToken;
|
||||
use cloud_api_client::websocket_protocol::MessageToClient;
|
||||
use cloud_api_client::{ClientApiError, CloudApiClient};
|
||||
use cloud_api_types::OrganizationId;
|
||||
|
|
@ -26,7 +27,6 @@ use futures::{
|
|||
};
|
||||
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
|
||||
use language_model::LlmApiToken;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postage::watch;
|
||||
use proxy::connect_proxy_stream;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
use super::{Client, UserStore};
|
||||
use cloud_api_client::LlmApiToken;
|
||||
use cloud_api_types::websocket_protocol::MessageToClient;
|
||||
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
|
||||
use gpui::{
|
||||
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
|
||||
};
|
||||
use language_model::LlmApiToken;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait NeedsLlmTokenRefresh {
|
||||
|
|
|
|||
|
|
@ -20,5 +20,6 @@ gpui_tokio.workspace = true
|
|||
http_client.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
yawc.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
mod llm_token;
|
||||
mod websocket;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
|
@ -18,6 +19,8 @@ use yawc::WebSocket;
|
|||
|
||||
use crate::websocket::Connection;
|
||||
|
||||
pub use llm_token::LlmApiToken;
|
||||
|
||||
struct Credentials {
|
||||
user_id: u32,
|
||||
access_token: String,
|
||||
|
|
|
|||
74
crates/cloud_api_client/src/llm_token.rs
Normal file
74
crates/cloud_api_client/src/llm_token.rs
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use cloud_api_types::OrganizationId;
|
||||
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
||||
|
||||
use crate::{ClientApiError, CloudApiClient};
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
impl LlmApiToken {
|
||||
pub async fn acquire(
|
||||
&self,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
let lock = self.0.upgradable_read().await;
|
||||
if let Some(token) = lock.as_ref() {
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
Self::fetch(
|
||||
RwLockUpgradableReadGuard::upgrade(lock).await,
|
||||
client,
|
||||
system_id,
|
||||
organization_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh(
|
||||
&self,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
Self::fetch(self.0.write().await, client, system_id, organization_id).await
|
||||
}
|
||||
|
||||
/// Clears the existing token before attempting to fetch a new one.
|
||||
///
|
||||
/// Used when switching organizations so that a failed refresh doesn't
|
||||
/// leave a token for the wrong organization.
|
||||
pub async fn clear_and_refresh(
|
||||
&self,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
let mut lock = self.0.write().await;
|
||||
*lock = None;
|
||||
Self::fetch(lock, client, system_id, organization_id).await
|
||||
}
|
||||
|
||||
async fn fetch(
|
||||
mut lock: RwLockWriteGuard<'_, Option<String>>,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
let result = client.create_llm_token(system_id, organization_id).await;
|
||||
match result {
|
||||
Ok(response) => {
|
||||
*lock = Some(response.token.0.clone());
|
||||
Ok(response.token.0)
|
||||
}
|
||||
Err(err) => {
|
||||
*lock = None;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ license = "Apache-2.0"
|
|||
|
||||
[features]
|
||||
test-support = []
|
||||
predict-edits = ["dep:zeta_prompt"]
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
|
@ -20,6 +21,6 @@ serde = { workspace = true, features = ["derive", "rc"] }
|
|||
serde_json.workspace = true
|
||||
strum = { workspace = true, features = ["derive"] }
|
||||
uuid = { workspace = true, features = ["serde"] }
|
||||
zeta_prompt.workspace = true
|
||||
zeta_prompt = { workspace = true, optional = true }
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#[cfg(feature = "predict-edits")]
|
||||
pub mod predict_edits_v3;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
|
|
|||
|
|
@ -2846,11 +2846,11 @@ impl CollabPanel {
|
|||
}
|
||||
};
|
||||
|
||||
Some(channel.name.as_ref())
|
||||
Some(channel.name.clone())
|
||||
});
|
||||
|
||||
if let Some(name) = channel_name {
|
||||
SharedString::from(name.to_string())
|
||||
name
|
||||
} else {
|
||||
SharedString::from("Current Call")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,8 +21,9 @@ heapless.workspace = true
|
|||
buffer_diff.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_llm_client = { workspace = true, features = ["predict-edits"] }
|
||||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
copilot_ui.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
|
||||
use cloud_api_client::LlmApiToken;
|
||||
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
|
||||
|
|
@ -31,7 +32,6 @@ use heapless::Vec as ArrayVec;
|
|||
use language::language_settings::all_language_settings;
|
||||
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
|
||||
use language::{BufferSnapshot, OffsetRangeExt};
|
||||
use language_model::LlmApiToken;
|
||||
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
|
||||
use release_channel::AppVersion;
|
||||
use semver::Version;
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
|
|||
let mut models: Vec<SharedString> = provider
|
||||
.provided_models(cx)
|
||||
.into_iter()
|
||||
.map(|model| SharedString::from(model.id().0.to_string()))
|
||||
.map(|model| model.id().0)
|
||||
.collect();
|
||||
models.sort();
|
||||
models
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
|||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
id: Some(prediction.id.0.clone()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
|
|
@ -228,7 +228,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
|||
}
|
||||
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
id: Some(prediction.id.0.clone()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
cursor_position: prediction.cursor_position,
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ http_client.workspace = true
|
|||
chrono.workspace = true
|
||||
clap = "4"
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace= true
|
||||
cloud_llm_client = { workspace = true, features = ["predict-edits"] }
|
||||
collections.workspace = true
|
||||
db.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
|
|
|
|||
|
|
@ -12,4 +12,4 @@ workspace = true
|
|||
path = "src/env_var.rs"
|
||||
|
||||
[dependencies]
|
||||
gpui.workspace = true
|
||||
gpui_shared_string.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EnvVar {
|
||||
|
|
|
|||
|
|
@ -1906,7 +1906,7 @@ mod tests {
|
|||
assert_eq!(
|
||||
remotes,
|
||||
vec![Remote {
|
||||
name: SharedString::from("my_new_remote".to_string())
|
||||
name: SharedString::from("my_new_remote")
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@ schemars = ["dep:schemars"]
|
|||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model_core.workspace = true
|
||||
log.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
strum.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
|
|
|
|||
492
crates/google_ai/src/completion.rs
Normal file
492
crates/google_ai/src/completion.rs
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
use anyhow::Result;
|
||||
use futures::{Stream, StreamExt};
|
||||
use language_model_core::{
|
||||
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
|
||||
StopReason, TokenUsage,
|
||||
};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{self, AtomicU64};
|
||||
|
||||
use crate::{
|
||||
Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
|
||||
GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
|
||||
InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
|
||||
UsageMetadata,
|
||||
};
|
||||
|
||||
pub fn into_google(
|
||||
mut request: LanguageModelRequest,
|
||||
model_id: String,
|
||||
mode: GoogleModelMode,
|
||||
) -> crate::GenerateContentRequest {
|
||||
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
|
||||
content
|
||||
.into_iter()
|
||||
.flat_map(|content| match content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
vec![Part::TextPart(TextPart { text })]
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
MessageContent::Thinking {
|
||||
text: _,
|
||||
signature: Some(signature),
|
||||
} => {
|
||||
if !signature.is_empty() {
|
||||
vec![Part::ThoughtPart(crate::ThoughtPart {
|
||||
thought: true,
|
||||
thought_signature: signature,
|
||||
})]
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
MessageContent::Thinking { .. } => {
|
||||
vec![]
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => vec![],
|
||||
MessageContent::Image(image) => {
|
||||
vec![Part::InlineDataPart(InlineDataPart {
|
||||
inline_data: GenerativeContentBlob {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
})]
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
// Normalize empty string signatures to None
|
||||
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
|
||||
|
||||
vec![Part::FunctionCallPart(crate::FunctionCallPart {
|
||||
function_call: crate::FunctionCall {
|
||||
name: tool_use.name.to_string(),
|
||||
args: tool_use.input,
|
||||
},
|
||||
thought_signature,
|
||||
})]
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
match tool_result.content {
|
||||
language_model_core::LanguageModelToolResultContent::Text(text) => {
|
||||
vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
|
||||
function_response: crate::FunctionResponse {
|
||||
name: tool_result.tool_name.to_string(),
|
||||
// The API expects a valid JSON object
|
||||
response: serde_json::json!({
|
||||
"output": text
|
||||
}),
|
||||
},
|
||||
})]
|
||||
}
|
||||
language_model_core::LanguageModelToolResultContent::Image(image) => {
|
||||
vec![
|
||||
Part::FunctionResponsePart(crate::FunctionResponsePart {
|
||||
function_response: crate::FunctionResponse {
|
||||
name: tool_result.tool_name.to_string(),
|
||||
// The API expects a valid JSON object
|
||||
response: serde_json::json!({
|
||||
"output": "Tool responded with an image"
|
||||
}),
|
||||
},
|
||||
}),
|
||||
Part::InlineDataPart(InlineDataPart {
|
||||
inline_data: GenerativeContentBlob {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
}),
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
let system_instructions = if request
|
||||
.messages
|
||||
.first()
|
||||
.is_some_and(|msg| matches!(msg.role, Role::System))
|
||||
{
|
||||
let message = request.messages.remove(0);
|
||||
Some(SystemInstruction {
|
||||
parts: map_content(message.content),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
crate::GenerateContentRequest {
|
||||
model: ModelName { model_id },
|
||||
system_instruction: system_instructions,
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.filter_map(|message| {
|
||||
let parts = map_content(message.content);
|
||||
if parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Content {
|
||||
parts,
|
||||
role: match message.role {
|
||||
Role::User => crate::Role::User,
|
||||
Role::Assistant => crate::Role::Model,
|
||||
Role::System => crate::Role::User, // Google AI doesn't have a system role
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
generation_config: Some(GenerationConfig {
|
||||
candidate_count: Some(1),
|
||||
stop_sequences: Some(request.stop),
|
||||
max_output_tokens: None,
|
||||
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
|
||||
thinking_config: match (request.thinking_allowed, mode) {
|
||||
(true, GoogleModelMode::Thinking { budget_tokens }) => {
|
||||
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
}),
|
||||
safety_settings: None,
|
||||
tools: (!request.tools.is_empty()).then(|| {
|
||||
vec![crate::Tool {
|
||||
function_declarations: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| FunctionDeclaration {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.input_schema,
|
||||
})
|
||||
.collect(),
|
||||
}]
|
||||
}),
|
||||
tool_config: request.tool_choice.map(|choice| ToolConfig {
|
||||
function_calling_config: FunctionCallingConfig {
|
||||
mode: match choice {
|
||||
LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
|
||||
LanguageModelToolChoice::Any => FunctionCallingMode::Any,
|
||||
LanguageModelToolChoice::None => FunctionCallingMode::None,
|
||||
},
|
||||
allowed_function_names: None,
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GoogleEventMapper {
|
||||
usage: UsageMetadata,
|
||||
stop_reason: StopReason,
|
||||
}
|
||||
|
||||
impl GoogleEventMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
usage: UsageMetadata::default(),
|
||||
stop_reason: StopReason::EndTurn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events
|
||||
.map(Some)
|
||||
.chain(futures::stream::once(async { None }))
|
||||
.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Some(Ok(event)) => self.map_event(event),
|
||||
Some(Err(error)) => {
|
||||
vec![Err(LanguageModelCompletionError::from(error))]
|
||||
}
|
||||
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: GenerateContentResponse,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
let mut events: Vec<_> = Vec::new();
|
||||
let mut wants_to_use_tool = false;
|
||||
if let Some(usage_metadata) = event.usage_metadata {
|
||||
update_usage(&mut self.usage, &usage_metadata);
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||
convert_usage(&self.usage),
|
||||
)))
|
||||
}
|
||||
|
||||
if let Some(prompt_feedback) = event.prompt_feedback
|
||||
&& let Some(block_reason) = prompt_feedback.block_reason.as_deref()
|
||||
{
|
||||
self.stop_reason = match block_reason {
|
||||
"SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
|
||||
StopReason::Refusal
|
||||
}
|
||||
_ => {
|
||||
log::error!("Unexpected Google block_reason: {block_reason}");
|
||||
StopReason::Refusal
|
||||
}
|
||||
};
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
if let Some(candidates) = event.candidates {
|
||||
for candidate in candidates {
|
||||
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
|
||||
self.stop_reason = match finish_reason {
|
||||
"STOP" => StopReason::EndTurn,
|
||||
"MAX_TOKENS" => StopReason::MaxTokens,
|
||||
_ => {
|
||||
log::error!("Unexpected google finish_reason: {finish_reason}");
|
||||
StopReason::EndTurn
|
||||
}
|
||||
};
|
||||
}
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.for_each(|part| match part {
|
||||
Part::TextPart(text_part) => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
|
||||
}
|
||||
Part::InlineDataPart(_) => {}
|
||||
Part::FunctionCallPart(function_call_part) => {
|
||||
wants_to_use_tool = true;
|
||||
let name: Arc<str> = function_call_part.function_call.name.into();
|
||||
let next_tool_id =
|
||||
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
let id: LanguageModelToolUseId =
|
||||
format!("{}-{}", name, next_tool_id).into();
|
||||
|
||||
// Normalize empty string signatures to None
|
||||
let thought_signature = function_call_part
|
||||
.thought_signature
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id,
|
||||
name,
|
||||
is_input_complete: true,
|
||||
raw_input: function_call_part.function_call.args.to_string(),
|
||||
input: function_call_part.function_call.args,
|
||||
thought_signature,
|
||||
},
|
||||
)));
|
||||
}
|
||||
Part::FunctionResponsePart(_) => {}
|
||||
Part::ThoughtPart(part) => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
|
||||
signature: Some(part.thought_signature),
|
||||
}));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Even when Gemini wants to use a Tool, the API
|
||||
// responds with `finish_reason: STOP`
|
||||
if wants_to_use_tool {
|
||||
self.stop_reason = StopReason::ToolUse;
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
|
||||
}
|
||||
events
|
||||
}
|
||||
}
|
||||
|
||||
/// Count tokens for a Google AI model using tiktoken. This is synchronous;
|
||||
/// callers should spawn it on a background thread if needed.
|
||||
pub fn count_google_tokens(request: LanguageModelRequest) -> Result<u64> {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
|
||||
}
|
||||
|
||||
/// Merges `new` into `usage`: each counter present in `new` overwrites the
/// stored one, and counters missing from `new` keep their previous value.
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
    usage.prompt_token_count = new.prompt_token_count.or(usage.prompt_token_count);
    usage.cached_content_token_count = new
        .cached_content_token_count
        .or(usage.cached_content_token_count);
    usage.candidates_token_count = new
        .candidates_token_count
        .or(usage.candidates_token_count);
    usage.tool_use_prompt_token_count = new
        .tool_use_prompt_token_count
        .or(usage.tool_use_prompt_token_count);
    usage.thoughts_token_count = new.thoughts_token_count.or(usage.thoughts_token_count);
    usage.total_token_count = new.total_token_count.or(usage.total_token_count);
}
|
||||
|
||||
fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
|
||||
let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
|
||||
let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
|
||||
let input_tokens = prompt_tokens - cached_tokens;
|
||||
let output_tokens = usage.candidates_token_count.unwrap_or(0);
|
||||
|
||||
TokenUsage {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_input_tokens: cached_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a single-candidate response whose only part is a call to
    /// `test_function`, carrying the given thought signature.
    fn function_call_response(thought_signature: Option<&str>) -> GenerateContentResponse {
        let call_part = Part::FunctionCallPart(FunctionCallPart {
            function_call: FunctionCall {
                name: "test_function".to_string(),
                args: json!({"arg": "value"}),
            },
            thought_signature: thought_signature.map(|signature| signature.to_string()),
        });

        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![call_part],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    /// Returns the tool use carried by `event`, panicking for any other event.
    fn expect_tool_use(
        event: &Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
    ) -> &LanguageModelToolUse {
        match event {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => tool_use,
            _ => panic!("Expected ToolUse event"),
        }
    }

    /// A non-empty thought signature is forwarded on the tool use.
    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(function_call_response(Some("test_signature_123")));
        assert_eq!(events.len(), 2);

        let tool_use = expect_tool_use(&events[0]);
        assert_eq!(tool_use.name.as_ref(), "test_function");
        assert_eq!(
            tool_use.thought_signature.as_deref(),
            Some("test_signature_123")
        );
    }

    /// A missing thought signature stays `None`.
    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(function_call_response(None));
        assert_eq!(events.len(), 2);

        let tool_use = expect_tool_use(&events[0]);
        assert!(tool_use.thought_signature.is_none());
    }

    /// An empty thought signature is treated the same as a missing one.
    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(function_call_response(Some("")));

        let tool_use = expect_tool_use(&events[0]);
        assert!(tool_use.thought_signature.is_none());
    }
}
|
||||
|
|
@ -3,8 +3,9 @@ use std::mem;
|
|||
use anyhow::{Result, anyhow, bail};
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
pub use language_model_core::ModelMode as GoogleModelMode;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
pub use settings::ModelMode as GoogleModelMode;
|
||||
pub mod completion;
|
||||
|
||||
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ etagere = "0.2"
|
|||
futures.workspace = true
|
||||
futures-concurrency.workspace = true
|
||||
gpui_macros.workspace = true
|
||||
gpui_shared_string.workspace = true
|
||||
http_client.workspace = true
|
||||
image.workspace = true
|
||||
inventory.workspace = true
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ pub mod profiler;
|
|||
#[expect(missing_docs)]
|
||||
pub mod queue;
|
||||
mod scene;
|
||||
mod shared_string;
|
||||
mod shared_uri;
|
||||
mod style;
|
||||
mod styled;
|
||||
|
|
@ -92,6 +91,7 @@ pub use global::*;
|
|||
pub use gpui_macros::{
|
||||
AppContext, IntoElement, Render, VisualContext, property_test, register_action, test,
|
||||
};
|
||||
pub use gpui_shared_string::*;
|
||||
pub use gpui_util::arc_cow::ArcCow;
|
||||
pub use http_client;
|
||||
pub use input::*;
|
||||
|
|
@ -106,7 +106,6 @@ pub use profiler::*;
|
|||
pub use queue::{PriorityQueueReceiver, PriorityQueueSender};
|
||||
pub use refineable::*;
|
||||
pub use scene::*;
|
||||
pub use shared_string::*;
|
||||
pub use shared_uri::*;
|
||||
use std::{any::Any, future::Future};
|
||||
pub use style::*;
|
||||
|
|
|
|||
|
|
@ -882,7 +882,7 @@ mod tests {
|
|||
],
|
||||
len: 6,
|
||||
}),
|
||||
text: SharedString::new("abcdef".to_string()),
|
||||
text: "abcdef".into(),
|
||||
decoration_runs: SmallVec::new(),
|
||||
};
|
||||
|
||||
|
|
|
|||
17
crates/gpui_shared_string/Cargo.toml
Normal file
17
crates/gpui_shared_string/Cargo.toml
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "gpui_shared_string"
|
||||
version = "0.1.0"
|
||||
publish.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "gpui_shared_string.rs"
|
||||
|
||||
[dependencies]
|
||||
derive_more.workspace = true
|
||||
gpui_util.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
1
crates/gpui_shared_string/LICENSE-APACHE
Symbolic link
1
crates/gpui_shared_string/LICENSE-APACHE
Symbolic link
|
|
@ -0,0 +1 @@
|
|||
../../LICENSE-APACHE
|
||||
|
|
@ -10,7 +10,7 @@ path = "src/language_core.rs"
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
gpui_shared_string.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
parking_lot.workspace = true
|
||||
|
|
@ -22,8 +22,6 @@ toml.workspace = true
|
|||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
use lsp::{DiagnosticSeverity, NumberOrString};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
use lsp::LanguageServerName;
|
||||
use parking_lot::Mutex;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use crate::LanguageName;
|
||||
use collections::{HashMap, HashSet, IndexSet};
|
||||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
use lsp::LanguageServerName;
|
||||
use regex::Regex;
|
||||
use schemars::{JsonSchema, SchemaGenerator, json_schema};
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Converts a value into an LSP position.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use std::borrow::Borrow;
|
||||
|
||||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub struct ManifestName(SharedString);
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use gpui::SharedString;
|
||||
use gpui_shared_string::SharedString;
|
||||
use util::rel_path::RelPath;
|
||||
|
||||
use crate::{LanguageName, ManifestName};
|
||||
|
|
|
|||
|
|
@ -16,13 +16,9 @@ doctest = false
|
|||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
anthropic = { workspace = true, features = ["schemars"] }
|
||||
anyhow.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
base64.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
env_var.workspace = true
|
||||
futures.workspace = true
|
||||
|
|
@ -30,14 +26,11 @@ gpui.workspace = true
|
|||
http_client.workspace = true
|
||||
icons.workspace = true
|
||||
image.workspace = true
|
||||
language_model_core.workspace = true
|
||||
log.workspace = true
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
open_router.workspace = true
|
||||
parking_lot.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,10 @@ use crate::{
|
|||
LanguageModelRequest, LanguageModelToolChoice,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||
use http_client::Result;
|
||||
use parking_lot::Mutex;
|
||||
use smol::stream::StreamExt;
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
|
|
|
|||
|
|
@ -1,380 +1,31 @@
|
|||
mod api_key;
|
||||
mod model;
|
||||
mod provider;
|
||||
mod rate_limiter;
|
||||
mod registry;
|
||||
mod request;
|
||||
mod role;
|
||||
pub mod tool_schema;
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod fake_provider;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::CompletionRequestStatus;
|
||||
pub use language_model_core::*;
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
use http_client::{StatusCode, http};
|
||||
use gpui::{AnyView, App, AsyncApp, Task, Window};
|
||||
use icons::IconName;
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::{Add, Sub};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
use thiserror::Error;
|
||||
use util::serde::is_default;
|
||||
|
||||
pub use crate::api_key::{ApiKey, ApiKeyState};
|
||||
pub use crate::model::*;
|
||||
pub use crate::rate_limiter::*;
|
||||
pub use crate::registry::*;
|
||||
pub use crate::request::*;
|
||||
pub use crate::role::*;
|
||||
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
|
||||
pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
|
||||
pub use env_var::{EnvVar, env_var};
|
||||
pub use provider::*;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
registry::init(cx);
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LanguageModelCacheConfiguration {
|
||||
pub max_cache_anchors: usize,
|
||||
pub should_speculate: bool,
|
||||
pub min_total_token: u64,
|
||||
}
|
||||
|
||||
/// A completion event from a language model.
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub enum LanguageModelCompletionEvent {
|
||||
Queued {
|
||||
position: usize,
|
||||
},
|
||||
Started,
|
||||
Stop(StopReason),
|
||||
Text(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking {
|
||||
data: String,
|
||||
},
|
||||
ToolUse(LanguageModelToolUse),
|
||||
ToolUseJsonParseError {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
},
|
||||
ReasoningDetails(serde_json::Value),
|
||||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionEvent {
|
||||
pub fn from_completion_request_status(
|
||||
status: CompletionRequestStatus,
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
) -> Result<Option<Self>, LanguageModelCompletionError> {
|
||||
match status {
|
||||
CompletionRequestStatus::Queued { position } => {
|
||||
Ok(Some(LanguageModelCompletionEvent::Queued { position }))
|
||||
}
|
||||
CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
|
||||
CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
|
||||
CompletionRequestStatus::Failed {
|
||||
code,
|
||||
message,
|
||||
request_id: _,
|
||||
retry_after,
|
||||
} => Err(LanguageModelCompletionError::from_cloud_failure(
|
||||
upstream_provider,
|
||||
code,
|
||||
message,
|
||||
retry_after.map(Duration::from_secs_f64),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("prompt too large for context window")]
|
||||
PromptTooLarge { tokens: Option<u64> },
|
||||
#[error("missing {provider} API key")]
|
||||
NoApiKey { provider: LanguageModelProviderName },
|
||||
#[error("{provider}'s API rate limit exceeded")]
|
||||
RateLimitExceeded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API servers are overloaded right now")]
|
||||
ServerOverloaded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API server reported an internal server error: {message}")]
|
||||
ApiInternalServerError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("{message}")]
|
||||
UpstreamProviderError {
|
||||
message: String,
|
||||
status: StatusCode,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
||||
HttpResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
},
|
||||
|
||||
// Client errors
|
||||
#[error("invalid request format to {provider}'s API: {message}")]
|
||||
BadRequestFormat {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("authentication error with {provider}'s API: {message}")]
|
||||
AuthenticationError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("Permission error with {provider}'s API: {message}")]
|
||||
PermissionError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound { provider: LanguageModelProviderName },
|
||||
#[error("I/O error reading response from {provider}'s API")]
|
||||
ApiReadResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: io::Error,
|
||||
},
|
||||
#[error("error serializing request to {provider} API")]
|
||||
SerializeRequest {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
#[error("error building request body to {provider} API")]
|
||||
BuildRequestBody {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: http::Error,
|
||||
},
|
||||
#[error("error sending HTTP request to {provider} API")]
|
||||
HttpSend {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: anyhow::Error,
|
||||
},
|
||||
#[error("error deserializing {provider} API response")]
|
||||
DeserializeResponse {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
|
||||
#[error("stream from {provider} ended unexpectedly")]
|
||||
StreamEndedUnexpectedly { provider: LanguageModelProviderName },
|
||||
|
||||
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionError {
|
||||
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
|
||||
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
|
||||
let upstream_status = error_json
|
||||
.get("upstream_status")
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(|status| u16::try_from(status).ok())
|
||||
.and_then(|status| StatusCode::from_u16(status).ok())?;
|
||||
let inner_message = error_json
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(message)
|
||||
.to_string();
|
||||
Some((upstream_status, inner_message))
|
||||
}
|
||||
|
||||
pub fn from_cloud_failure(
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
code: String,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
if let Some(tokens) = parse_prompt_too_long(&message) {
|
||||
// TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
|
||||
// to be reported. This is a temporary workaround to handle this in the case where the
|
||||
// token limit has been exceeded.
|
||||
Self::PromptTooLarge {
|
||||
tokens: Some(tokens),
|
||||
}
|
||||
} else if code == "upstream_http_error" {
|
||||
if let Some((upstream_status, inner_message)) =
|
||||
Self::parse_upstream_error_json(&message)
|
||||
{
|
||||
return Self::from_http_status(
|
||||
upstream_provider,
|
||||
upstream_status,
|
||||
inner_message,
|
||||
retry_after,
|
||||
);
|
||||
}
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(upstream_provider, status_code, message, retry_after)
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
|
||||
} else {
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_http_status(
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
match status_code {
|
||||
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
|
||||
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
|
||||
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
|
||||
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
|
||||
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&message),
|
||||
},
|
||||
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
|
||||
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StopReason {
|
||||
EndTurn,
|
||||
MaxTokens,
|
||||
ToolUse,
|
||||
Refusal,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
|
||||
pub struct TokenUsage {
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub input_tokens: u64,
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub output_tokens: u64,
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub cache_creation_input_tokens: u64,
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub cache_read_input_tokens: u64,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
pub fn total_tokens(&self) -> u64 {
|
||||
self.input_tokens
|
||||
+ self.output_tokens
|
||||
+ self.cache_read_input_tokens
|
||||
+ self.cache_creation_input_tokens
|
||||
}
|
||||
}
|
||||
|
||||
/// Field-wise addition of two usage records.
impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        let input_tokens = self.input_tokens + other.input_tokens;
        let output_tokens = self.output_tokens + other.output_tokens;
        let cache_creation_input_tokens =
            self.cache_creation_input_tokens + other.cache_creation_input_tokens;
        let cache_read_input_tokens = self.cache_read_input_tokens + other.cache_read_input_tokens;

        Self {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
        }
    }
}
|
||||
|
||||
/// Field-wise difference of two usage records.
///
/// Each counter saturates at zero. With plain `u64` subtraction, a counter in
/// `other` that exceeds the one in `self` would panic in builds with overflow
/// checks and silently wrap to a value near `u64::MAX` in builds without them.
/// Clamping to zero keeps the result a meaningful token count in both.
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens.saturating_sub(other.input_tokens),
            output_tokens: self.output_tokens.saturating_sub(other.output_tokens),
            cache_creation_input_tokens: self
                .cache_creation_input_tokens
                .saturating_sub(other.cache_creation_input_tokens),
            cache_read_input_tokens: self
                .cache_read_input_tokens
                .saturating_sub(other.cache_read_input_tokens),
        }
    }
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelToolUseId(Arc<str>);
|
||||
|
||||
impl fmt::Display for LanguageModelToolUseId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for LanguageModelToolUseId
|
||||
where
|
||||
T: Into<Arc<str>>,
|
||||
{
|
||||
fn from(value: T) -> Self {
|
||||
Self(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelToolUse {
|
||||
pub id: LanguageModelToolUseId,
|
||||
pub name: Arc<str>,
|
||||
pub raw_input: String,
|
||||
pub input: serde_json::Value,
|
||||
pub is_input_complete: bool,
|
||||
/// Thought signature the model sent us. Some models require that this
|
||||
/// signature be preserved and sent back in conversation history for validation.
|
||||
pub thought_signature: Option<String>,
|
||||
}
|
||||
|
||||
pub struct LanguageModelTextStream {
|
||||
pub message_id: Option<String>,
|
||||
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
|
||||
|
|
@ -392,13 +43,6 @@ impl Default for LanguageModelTextStream {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LanguageModelEffortLevel {
|
||||
pub name: SharedString,
|
||||
pub value: SharedString,
|
||||
pub is_default: bool,
|
||||
}
|
||||
|
||||
pub trait LanguageModel: Send + Sync {
|
||||
fn id(&self) -> LanguageModelId;
|
||||
fn name(&self) -> LanguageModelName;
|
||||
|
|
@ -605,7 +249,7 @@ pub trait LanguageModel: Send + Sync {
|
|||
}
|
||||
|
||||
impl std::fmt::Debug for dyn LanguageModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("<dyn LanguageModel>")
|
||||
.field("id", &self.id())
|
||||
.field("name", &self.name())
|
||||
|
|
@ -619,17 +263,6 @@ impl std::fmt::Debug for dyn LanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
/// An error that occurred when trying to authenticate the language model provider.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthenticateError {
|
||||
#[error("connection refused")]
|
||||
ConnectionRefused,
|
||||
#[error("credentials not found")]
|
||||
CredentialsNotFound,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Either a built-in icon name or a path to an external SVG.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum IconOrSvg {
|
||||
|
|
@ -692,18 +325,6 @@ pub trait LanguageModelProviderState: 'static {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
pub struct LanguageModelId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelName(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelProviderId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelProviderName(pub SharedString);
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum LanguageModelCostInfo {
|
||||
/// Cost per 1,000 input and output tokens
|
||||
|
|
@ -741,245 +362,3 @@ impl LanguageModelCostInfo {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderId {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderName {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderName {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelName {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelProviderId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelProviderName {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<str>> for LanguageModelProviderId {
|
||||
fn from(value: Arc<str>) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<str>> for LanguageModelProviderName {
|
||||
fn from(value: Arc<str>) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Upstream message used by the connection-timeout cases.
    const CONNECTION_TIMEOUT_MESSAGE: &str = "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout";

    /// Renders the JSON payload the cloud sends for `upstream_http_error`.
    fn upstream_error_payload(message: &str, upstream_status: u16) -> String {
        format!(
            r#"{{"code":"upstream_http_error","message":"{message}","upstream_status":{upstream_status}}}"#
        )
    }

    /// Classifies a cloud failure attributed to the "anthropic" provider, with no retry hint.
    fn anthropic_failure(code: &str, message: String) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            message,
            None,
        )
    }

    /// Builds a complete tool use whose `raw_input` is the serialized `input`.
    fn tool_use_fixture(
        id: &str,
        name: &str,
        input: serde_json::Value,
        thought_signature: Option<&str>,
    ) -> LanguageModelToolUse {
        LanguageModelToolUse {
            id: LanguageModelToolUseId::from(id),
            name: name.into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: thought_signature.map(str::to_string),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = anthropic_failure(
            "upstream_http_error",
            upstream_error_payload(CONNECTION_TIMEOUT_MESSAGE, 503),
        );
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = anthropic_failure(
            "upstream_http_error",
            upstream_error_payload("Internal server error", 500),
        );
        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = anthropic_failure("upstream_http_503", "Service unavailable".to_string());
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = anthropic_failure(
            "upstream_http_error",
            upstream_error_payload(CONNECTION_TIMEOUT_MESSAGE, 503),
        );
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = anthropic_failure(
            "upstream_http_error",
            upstream_error_payload(CONNECTION_TIMEOUT_MESSAGE, 500),
        );
        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        let tool_use = tool_use_fixture(
            "test_id",
            "test_tool",
            json!({"arg": "value"}),
            Some("test_signature"),
        );

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        // No `thought_signature` key at all: it must deserialize as `None`.
        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        let original = tool_use_fixture(
            "round_trip_id",
            "round_trip_tool",
            json!({"key": "value"}),
            Some("round_trip_sig"),
        );

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        let original = tool_use_fixture("no_sig_id", "no_sig_tool", json!({"arg": "value"}), None);

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,5 @@
|
|||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cloud_api_client::ClientApiError;
|
||||
use cloud_api_client::CloudApiClient;
|
||||
use cloud_api_types::OrganizationId;
|
||||
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
|
@ -18,71 +13,3 @@ impl fmt::Display for PaymentRequiredError {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
impl LlmApiToken {
|
||||
pub async fn acquire(
|
||||
&self,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
let lock = self.0.upgradable_read().await;
|
||||
if let Some(token) = lock.as_ref() {
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
Self::fetch(
|
||||
RwLockUpgradableReadGuard::upgrade(lock).await,
|
||||
client,
|
||||
system_id,
|
||||
organization_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh(
|
||||
&self,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
Self::fetch(self.0.write().await, client, system_id, organization_id).await
|
||||
}
|
||||
|
||||
/// Clears the existing token before attempting to fetch a new one.
|
||||
///
|
||||
/// Used when switching organizations so that a failed refresh doesn't
|
||||
/// leave a token for the wrong organization.
|
||||
pub async fn clear_and_refresh(
|
||||
&self,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
let mut lock = self.0.write().await;
|
||||
*lock = None;
|
||||
Self::fetch(lock, client, system_id, organization_id).await
|
||||
}
|
||||
|
||||
async fn fetch(
|
||||
mut lock: RwLockWriteGuard<'_, Option<String>>,
|
||||
client: &CloudApiClient,
|
||||
system_id: Option<String>,
|
||||
organization_id: Option<OrganizationId>,
|
||||
) -> Result<String, ClientApiError> {
|
||||
let result = client.create_llm_token(system_id, organization_id).await;
|
||||
match result {
|
||||
Ok(response) => {
|
||||
*lock = Some(response.token.0.clone());
|
||||
Ok(response.token.0)
|
||||
}
|
||||
Err(err) => {
|
||||
*lock = None;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +0,0 @@
|
|||
pub mod anthropic;
|
||||
pub mod google;
|
||||
pub mod open_ai;
|
||||
pub mod open_router;
|
||||
pub mod x_ai;
|
||||
pub mod zed;
|
||||
|
||||
pub use anthropic::*;
|
||||
pub use google::*;
|
||||
pub use open_ai::*;
|
||||
pub use x_ai::*;
|
||||
pub use zed::*;
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
|
||||
use anthropic::AnthropicError;
|
||||
pub use anthropic::parse_prompt_too_long;
|
||||
|
||||
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
|
||||
LanguageModelProviderId::new("anthropic");
|
||||
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Anthropic");
|
||||
|
||||
impl From<AnthropicError> for LanguageModelCompletionError {
|
||||
fn from(error: AnthropicError) -> Self {
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error {
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
AnthropicError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
AnthropicError::HttpResponseError {
|
||||
status_code,
|
||||
message,
|
||||
} => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: anthropic::ApiError) -> Self {
|
||||
use anthropic::ApiErrorCode::*;
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error.code() {
|
||||
Some(code) => match code {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
NotFoundError => Self::ApiEndpointNotFound { provider },
|
||||
RequestTooLarge => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&error.message),
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
},
|
||||
None => Self::Other(error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
use crate::{LanguageModelProviderId, LanguageModelProviderName};
|
||||
|
||||
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
|
||||
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Google AI");
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
|
||||
use http_client::http;
|
||||
use std::time::Duration;
|
||||
|
||||
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
|
||||
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("OpenAI");
|
||||
|
||||
impl From<open_ai::RequestError> for LanguageModelCompletionError {
|
||||
fn from(error: open_ai::RequestError) -> Self {
|
||||
match error {
|
||||
open_ai::RequestError::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
body,
|
||||
headers,
|
||||
} => {
|
||||
let retry_after = headers
|
||||
.get(http::header::RETRY_AFTER)
|
||||
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
|
||||
.map(Duration::from_secs);
|
||||
|
||||
Self::from_http_status(provider.into(), status_code, body, retry_after)
|
||||
}
|
||||
open_ai::RequestError::Other(e) => Self::Other(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
use crate::{LanguageModelCompletionError, LanguageModelProviderName};
|
||||
use http_client::StatusCode;
|
||||
use open_router::OpenRouterError;
|
||||
|
||||
impl From<OpenRouterError> for LanguageModelCompletionError {
|
||||
fn from(error: OpenRouterError) -> Self {
|
||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||
match error {
|
||||
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
OpenRouterError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
OpenRouterError::ApiError(api_error) => api_error.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<open_router::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: open_router::ApiError) -> Self {
|
||||
use open_router::ApiErrorCode::*;
|
||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||
match error.code {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PaymentRequiredError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: format!("Payment required: {}", error.message),
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
RequestTimedOut => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code: StatusCode::REQUEST_TIMEOUT,
|
||||
message: error.message,
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
use crate::{LanguageModelProviderId, LanguageModelProviderName};
|
||||
|
||||
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
|
||||
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
use crate::{LanguageModelProviderId, LanguageModelProviderName};
|
||||
|
||||
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
|
||||
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Zed");
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderState,
|
||||
LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use collections::{BTreeMap, HashSet};
|
||||
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
|
||||
|
|
@ -101,7 +101,7 @@ impl ConfiguredModel {
|
|||
}
|
||||
|
||||
pub fn is_provided_by_zed(&self) -> bool {
|
||||
self.provider.id() == crate::provider::ZED_CLOUD_PROVIDER_ID
|
||||
self.provider.id() == ZED_CLOUD_PROVIDER_ID
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,78 +4,13 @@ use std::sync::Arc;
|
|||
use anyhow::Result;
|
||||
use base64::write::EncoderWriter;
|
||||
use gpui::{
|
||||
App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
|
||||
point, px, size,
|
||||
App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, Size, Task, point, px, size,
|
||||
};
|
||||
use image::GenericImageView as _;
|
||||
use image::codecs::png::PngEncoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::role::Role;
|
||||
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
pub struct LanguageModelImage {
|
||||
/// A base64-encoded PNG image.
|
||||
pub source: SharedString,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<Size<DevicePixels>>,
|
||||
}
|
||||
|
||||
impl LanguageModelImage {
|
||||
pub fn len(&self) -> usize {
|
||||
self.source.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.source.is_empty()
|
||||
}
|
||||
|
||||
// Parse Self from a JSON object with case-insensitive field names
|
||||
pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
|
||||
let mut source = None;
|
||||
let mut size_obj = None;
|
||||
|
||||
// Find source and size fields (case-insensitive)
|
||||
for (k, v) in obj.iter() {
|
||||
match k.to_lowercase().as_str() {
|
||||
"source" => source = v.as_str(),
|
||||
"size" => size_obj = v.as_object(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let source = source?;
|
||||
let size_obj = size_obj?;
|
||||
|
||||
let mut width = None;
|
||||
let mut height = None;
|
||||
|
||||
// Find width and height in size object (case-insensitive)
|
||||
for (k, v) in size_obj.iter() {
|
||||
match k.to_lowercase().as_str() {
|
||||
"width" => width = v.as_i64().map(|w| w as i32),
|
||||
"height" => height = v.as_i64().map(|h| h as i32),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
size: Some(size(DevicePixels(width?), DevicePixels(height?))),
|
||||
source: SharedString::from(source.to_string()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for LanguageModelImage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("LanguageModelImage")
|
||||
.field("source", &format!("<{} bytes>", self.source.len()))
|
||||
.field("size", &self.size)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
use language_model_core::{ImageSize, LanguageModelImage};
|
||||
|
||||
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
|
||||
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;
|
||||
|
|
@ -90,18 +25,16 @@ const DEFAULT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
|
|||
/// `DEFAULT_IMAGE_MAX_BYTES`.
|
||||
const MAX_IMAGE_DOWNSCALE_PASSES: usize = 8;
|
||||
|
||||
impl LanguageModelImage {
|
||||
// All language model images are encoded as PNGs.
|
||||
pub const FORMAT: ImageFormat = ImageFormat::Png;
|
||||
/// Extension trait for `LanguageModelImage` that provides GPUI-dependent functionality.
|
||||
pub trait LanguageModelImageExt {
|
||||
const FORMAT: ImageFormat;
|
||||
fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>>;
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
source: "".into(),
|
||||
size: None,
|
||||
}
|
||||
}
|
||||
impl LanguageModelImageExt for LanguageModelImage {
|
||||
const FORMAT: ImageFormat = ImageFormat::Png;
|
||||
|
||||
pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
|
||||
fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>> {
|
||||
cx.background_spawn(async move {
|
||||
let image_bytes = Cursor::new(data.bytes());
|
||||
let dynamic_image = match data.format() {
|
||||
|
|
@ -186,28 +119,14 @@ impl LanguageModelImage {
|
|||
let source = unsafe { String::from_utf8_unchecked(base64_image) };
|
||||
|
||||
Some(LanguageModelImage {
|
||||
size: Some(image_size),
|
||||
size: Some(ImageSize {
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
}),
|
||||
source: source.into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn estimate_tokens(&self) -> usize {
|
||||
let Some(size) = self.size.as_ref() else {
|
||||
return 0;
|
||||
};
|
||||
let width = size.width.0.unsigned_abs() as usize;
|
||||
let height = size.height.0.unsigned_abs() as usize;
|
||||
|
||||
// From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
|
||||
// Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
|
||||
// so this method is more of a rough guess.
|
||||
(width * height) / 750
|
||||
}
|
||||
|
||||
pub fn to_base64_url(&self) -> String {
|
||||
format!("data:image/png;base64,{}", self.source)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_png_bytes(image: &image::DynamicImage) -> Result<Vec<u8>> {
|
||||
|
|
@ -228,280 +147,20 @@ fn encode_bytes_as_base64(bytes: &[u8]) -> Result<Vec<u8>> {
|
|||
Ok(base64_image)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub struct LanguageModelToolResult {
|
||||
pub tool_use_id: LanguageModelToolUseId,
|
||||
pub tool_name: Arc<str>,
|
||||
pub is_error: bool,
|
||||
/// The tool output formatted for presenting to the model
|
||||
pub content: LanguageModelToolResultContent,
|
||||
/// The raw tool output, if available, often for debugging or extra state for replay
|
||||
pub output: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
|
||||
pub enum LanguageModelToolResultContent {
|
||||
Text(Arc<str>),
|
||||
Image(LanguageModelImage),
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
let value = serde_json::Value::deserialize(deserializer)?;
|
||||
|
||||
// Models can provide these responses in several styles. Try each in order.
|
||||
|
||||
// 1. Try as plain string
|
||||
if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
|
||||
// 2. Try as object
|
||||
if let Some(obj) = value.as_object() {
|
||||
// get a JSON field case-insensitively
|
||||
fn get_field<'a>(
|
||||
obj: &'a serde_json::Map<String, serde_json::Value>,
|
||||
field: &str,
|
||||
) -> Option<&'a serde_json::Value> {
|
||||
obj.iter()
|
||||
.find(|(k, _)| k.to_lowercase() == field.to_lowercase())
|
||||
.map(|(_, v)| v)
|
||||
}
|
||||
|
||||
// Accept wrapped text format: { "type": "text", "text": "..." }
|
||||
if let (Some(type_value), Some(text_value)) =
|
||||
(get_field(obj, "type"), get_field(obj, "text"))
|
||||
&& let Some(type_str) = type_value.as_str()
|
||||
&& type_str.to_lowercase() == "text"
|
||||
&& let Some(text) = text_value.as_str()
|
||||
{
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
|
||||
// Check for wrapped Text variant: { "text": "..." }
|
||||
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
|
||||
&& obj.len() == 1
|
||||
{
|
||||
// Only one field, and it's "text" (case-insensitive)
|
||||
if let Some(text) = value.as_str() {
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
|
||||
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
|
||||
&& obj.len() == 1
|
||||
{
|
||||
// Only one field, and it's "image" (case-insensitive)
|
||||
// Try to parse the nested image object
|
||||
if let Some(image_obj) = value.as_object()
|
||||
&& let Some(image) = LanguageModelImage::from_json(image_obj)
|
||||
{
|
||||
return Ok(Self::Image(image));
|
||||
}
|
||||
}
|
||||
|
||||
// Try as direct Image (object with "source" and "size" fields)
|
||||
if let Some(image) = LanguageModelImage::from_json(obj) {
|
||||
return Ok(Self::Image(image));
|
||||
}
|
||||
}
|
||||
|
||||
// If none of the variants match, return an error with the problematic JSON
|
||||
Err(D::Error::custom(format!(
|
||||
"data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
|
||||
an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
|
||||
serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
|
||||
)))
|
||||
/// Convert a core `ImageSize` to a gpui `Size<DevicePixels>`.
|
||||
pub fn image_size_to_gpui(size: ImageSize) -> Size<DevicePixels> {
|
||||
Size {
|
||||
width: DevicePixels(size.width),
|
||||
height: DevicePixels(size.height),
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelToolResultContent {
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Text(text) => Some(text),
|
||||
Self::Image(_) => None,
|
||||
}
|
||||
/// Convert a gpui `Size<DevicePixels>` to a core `ImageSize`.
|
||||
pub fn gpui_size_to_image_size(size: Size<DevicePixels>) -> ImageSize {
|
||||
ImageSize {
|
||||
width: size.width.0,
|
||||
height: size.height.0,
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
|
||||
Self::Image(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for LanguageModelToolResultContent {
|
||||
fn from(value: &str) -> Self {
|
||||
Self::Text(Arc::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelToolResultContent {
|
||||
fn from(value: String) -> Self {
|
||||
Self::Text(Arc::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelImage> for LanguageModelToolResultContent {
|
||||
fn from(image: LanguageModelImage) -> Self {
|
||||
Self::Image(image)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub enum MessageContent {
|
||||
Text(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking(String),
|
||||
Image(LanguageModelImage),
|
||||
ToolUse(LanguageModelToolUse),
|
||||
ToolResult(LanguageModelToolResult),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessageContent::Text(text) => Some(text.as_str()),
|
||||
MessageContent::Thinking { text, .. } => Some(text.as_str()),
|
||||
MessageContent::RedactedThinking(_) => None,
|
||||
MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
|
||||
MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
|
||||
MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
|
||||
MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
|
||||
MessageContent::RedactedThinking(_)
|
||||
| MessageContent::ToolUse(_)
|
||||
| MessageContent::Image(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for MessageContent {
|
||||
fn from(value: String) -> Self {
|
||||
MessageContent::Text(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for MessageContent {
|
||||
fn from(value: &str) -> Self {
|
||||
MessageContent::Text(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
pub cache: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl LanguageModelRequestMessage {
|
||||
pub fn string_contents(&self) -> String {
|
||||
let mut buffer = String::new();
|
||||
for string in self.content.iter().filter_map(|content| content.to_str()) {
|
||||
buffer.push_str(string);
|
||||
}
|
||||
|
||||
buffer
|
||||
}
|
||||
|
||||
pub fn contents_empty(&self) -> bool {
|
||||
self.content.iter().all(|content| content.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelRequestTool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: serde_json::Value,
|
||||
pub use_input_streaming: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub enum LanguageModelToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionIntent {
|
||||
UserPrompt,
|
||||
Subagent,
|
||||
ToolResults,
|
||||
ThreadSummarization,
|
||||
ThreadContextSummarization,
|
||||
CreateFile,
|
||||
EditFile,
|
||||
InlineAssist,
|
||||
TerminalInlineAssist,
|
||||
GenerateGitCommitMessage,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct LanguageModelRequest {
|
||||
pub thread_id: Option<String>,
|
||||
pub prompt_id: Option<String>,
|
||||
pub intent: Option<CompletionIntent>,
|
||||
pub messages: Vec<LanguageModelRequestMessage>,
|
||||
pub tools: Vec<LanguageModelRequestTool>,
|
||||
pub tool_choice: Option<LanguageModelToolChoice>,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: Option<f32>,
|
||||
pub thinking_allowed: bool,
|
||||
pub thinking_effort: Option<String>,
|
||||
pub speed: Option<Speed>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Speed {
|
||||
#[default]
|
||||
Standard,
|
||||
Fast,
|
||||
}
|
||||
|
||||
impl Speed {
|
||||
pub fn toggle(self) -> Self {
|
||||
match self {
|
||||
Speed::Standard => Speed::Fast,
|
||||
Speed::Fast => Speed::Standard,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Speed> for anthropic::Speed {
|
||||
fn from(speed: Speed) -> Self {
|
||||
match speed {
|
||||
Speed::Standard => anthropic::Speed::Standard,
|
||||
Speed::Fast => anthropic::Speed::Fast,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct LanguageModelResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -509,231 +168,64 @@ mod tests {
|
|||
use super::*;
|
||||
use base64::Engine as _;
|
||||
use gpui::TestAppContext;
|
||||
use image::ImageDecoder as _;
|
||||
|
||||
fn base64_to_png_bytes(base64_png: &str) -> Vec<u8> {
|
||||
fn base64_to_png_bytes(base64: &str) -> Vec<u8> {
|
||||
base64::engine::general_purpose::STANDARD
|
||||
.decode(base64_png.as_bytes())
|
||||
.expect("base64 should decode")
|
||||
.decode(base64)
|
||||
.expect("valid base64")
|
||||
}
|
||||
|
||||
fn png_dimensions(png_bytes: &[u8]) -> (u32, u32) {
|
||||
let decoder =
|
||||
image::codecs::png::PngDecoder::new(Cursor::new(png_bytes)).expect("png should decode");
|
||||
decoder.dimensions()
|
||||
let img = image::load_from_memory(png_bytes).expect("valid png");
|
||||
(img.width(), img.height())
|
||||
}
|
||||
|
||||
fn make_noisy_png_bytes(width: u32, height: u32) -> Vec<u8> {
|
||||
// Create an RGBA image with per-pixel variance to avoid PNG compressing too well.
|
||||
let mut img = image::RgbaImage::new(width, height);
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let r = ((x ^ y) & 0xFF) as u8;
|
||||
let g = ((x.wrapping_mul(31) ^ y.wrapping_mul(17)) & 0xFF) as u8;
|
||||
let b = ((x.wrapping_mul(131) ^ y.wrapping_mul(7)) & 0xFF) as u8;
|
||||
img.put_pixel(x, y, image::Rgba([r, g, b, 0xFF]));
|
||||
}
|
||||
}
|
||||
use image::{ImageBuffer, Rgba};
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
let mut out = Vec::new();
|
||||
image::DynamicImage::ImageRgba8(img)
|
||||
.write_with_encoder(PngEncoder::new(&mut out))
|
||||
.expect("png encoding should succeed");
|
||||
out
|
||||
let img = ImageBuffer::from_fn(width, height, |x, y| {
|
||||
let mut hasher = std::hash::DefaultHasher::new();
|
||||
(x, y, width, height).hash(&mut hasher);
|
||||
let h = hasher.finish();
|
||||
Rgba([h as u8, (h >> 8) as u8, (h >> 16) as u8, 255])
|
||||
});
|
||||
|
||||
let mut buf = Cursor::new(Vec::new());
|
||||
img.write_with_encoder(PngEncoder::new(&mut buf))
|
||||
.expect("encode");
|
||||
buf.into_inner()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_from_image_downscales_to_default_5mb_limit(cx: &mut TestAppContext) {
|
||||
// Pick a size that reliably produces a PNG > 5MB when filled with noise.
|
||||
// If this fails (image is too small), bump dimensions.
|
||||
let original_png = make_noisy_png_bytes(4096, 4096);
|
||||
let raw_png = make_noisy_png_bytes(4096, 4096);
|
||||
assert!(
|
||||
original_png.len() > DEFAULT_IMAGE_MAX_BYTES,
|
||||
"precondition failed: noisy PNG must exceed DEFAULT_IMAGE_MAX_BYTES"
|
||||
raw_png.len() > DEFAULT_IMAGE_MAX_BYTES,
|
||||
"Test image should exceed the 5 MB limit (actual: {} bytes)",
|
||||
raw_png.len()
|
||||
);
|
||||
|
||||
let image = gpui::Image::from_bytes(ImageFormat::Png, original_png);
|
||||
let image = Arc::new(gpui::Image::from_bytes(ImageFormat::Png, raw_png));
|
||||
let lm_image = cx
|
||||
.update(|cx| LanguageModelImage::from_image(Arc::new(image), cx))
|
||||
.update(|cx| LanguageModelImage::from_image(Arc::clone(&image), cx))
|
||||
.await
|
||||
.expect("image conversion should succeed");
|
||||
.expect("from_image should succeed");
|
||||
|
||||
let encoded_png = base64_to_png_bytes(lm_image.source.as_ref());
|
||||
let decoded_png = base64_to_png_bytes(lm_image.source.as_ref());
|
||||
assert!(
|
||||
encoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES,
|
||||
"expected encoded PNG <= DEFAULT_IMAGE_MAX_BYTES, got {} bytes",
|
||||
encoded_png.len()
|
||||
decoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES,
|
||||
"Encoded PNG should be ≤ {} bytes after downscale, but was {} bytes",
|
||||
DEFAULT_IMAGE_MAX_BYTES,
|
||||
decoded_png.len()
|
||||
);
|
||||
|
||||
// Ensure we actually downscaled in pixels (not just re-encoded).
|
||||
let (w, h) = png_dimensions(&encoded_png);
|
||||
let (w, h) = png_dimensions(&decoded_png);
|
||||
assert!(
|
||||
w < 4096 || h < 4096,
|
||||
"expected image to be downscaled in at least one dimension; got {w}x{h}"
|
||||
w < 4096 && h < 4096,
|
||||
"Dimensions should have shrunk: got {}×{}",
|
||||
w,
|
||||
h
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_language_model_tool_result_content_deserialization() {
|
||||
let json = r#""This is plain text""#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("This is plain text".into())
|
||||
);
|
||||
|
||||
let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("This is wrapped text".into())
|
||||
);
|
||||
|
||||
let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Case insensitive".into())
|
||||
);
|
||||
|
||||
let json = r#"{"Text": "Wrapped variant"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Wrapped variant".into())
|
||||
);
|
||||
|
||||
let json = r#"{"text": "Lowercase wrapped"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Lowercase wrapped".into())
|
||||
);
|
||||
|
||||
// Test image deserialization
|
||||
let json = r#"{
|
||||
"source": "base64encodedimagedata",
|
||||
"size": {
|
||||
"width": 100,
|
||||
"height": 200
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "base64encodedimagedata");
|
||||
let size = image.size.expect("size");
|
||||
assert_eq!(size.width.0, 100);
|
||||
assert_eq!(size.height.0, 200);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test wrapped Image variant
|
||||
let json = r#"{
|
||||
"Image": {
|
||||
"source": "wrappedimagedata",
|
||||
"size": {
|
||||
"width": 50,
|
||||
"height": 75
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "wrappedimagedata");
|
||||
let size = image.size.expect("size");
|
||||
assert_eq!(size.width.0, 50);
|
||||
assert_eq!(size.height.0, 75);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test wrapped Image variant with case insensitive
|
||||
let json = r#"{
|
||||
"image": {
|
||||
"Source": "caseinsensitive",
|
||||
"SIZE": {
|
||||
"width": 30,
|
||||
"height": 40
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "caseinsensitive");
|
||||
let size = image.size.expect("size");
|
||||
assert_eq!(size.width.0, 30);
|
||||
assert_eq!(size.height.0, 40);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test that wrapped text with wrong type fails
|
||||
let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test that malformed JSON fails
|
||||
let json = r#"{"invalid": "structure"}"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test edge cases
|
||||
let json = r#""""#; // Empty string
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result, LanguageModelToolResultContent::Text("".into()));
|
||||
|
||||
// Test with extra fields in wrapped text (should be ignored)
|
||||
let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));
|
||||
|
||||
// Test direct image with case-insensitive fields
|
||||
let json = r#"{
|
||||
"SOURCE": "directimage",
|
||||
"Size": {
|
||||
"width": 200,
|
||||
"height": 300
|
||||
}
|
||||
}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
match result {
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
assert_eq!(image.source.as_ref(), "directimage");
|
||||
let size = image.size.expect("size");
|
||||
assert_eq!(size.width.0, 200);
|
||||
assert_eq!(size.height.0, 300);
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
|
||||
// Test that multiple fields prevent wrapped variant interpretation
|
||||
let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test wrapped text with uppercase TEXT variant
|
||||
let json = r#"{"TEXT": "Uppercase variant"}"#;
|
||||
let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
LanguageModelToolResultContent::Text("Uppercase variant".into())
|
||||
);
|
||||
|
||||
// Test that numbers and other JSON values fail gracefully
|
||||
let json = r#"123"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
let json = r#"null"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
|
||||
let json = r#"[1, 2, 3]"#;
|
||||
let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
27
crates/language_model_core/Cargo.toml
Normal file
27
crates/language_model_core/Cargo.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "language_model_core"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/language_model_core.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
futures.workspace = true
|
||||
gpui_shared_string.workspace = true
|
||||
http_client.workspace = true
|
||||
partial-json-fixer.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
1
crates/language_model_core/LICENSE-GPL
Symbolic link
1
crates/language_model_core/LICENSE-GPL
Symbolic link
|
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
||||
658
crates/language_model_core/src/language_model_core.rs
Normal file
658
crates/language_model_core/src/language_model_core.rs
Normal file
|
|
@ -0,0 +1,658 @@
|
|||
mod provider;
|
||||
mod rate_limiter;
|
||||
mod request;
|
||||
mod role;
|
||||
pub mod tool_schema;
|
||||
pub mod util;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::CompletionRequestStatus;
|
||||
use http_client::{StatusCode, http};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::{Add, Sub};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
use thiserror::Error;
|
||||
/// Reports whether `value` equals its type's `Default` value.
///
/// Referenced by name from `skip_serializing_if` attributes in this file.
fn is_default<T>(value: &T) -> bool
where
    T: Default + PartialEq,
{
    let baseline = T::default();
    *value == baseline
}
|
||||
|
||||
pub use crate::provider::*;
|
||||
pub use crate::rate_limiter::*;
|
||||
pub use crate::request::*;
|
||||
pub use crate::role::*;
|
||||
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
|
||||
pub use crate::util::{fix_streamed_json, parse_prompt_too_long, parse_tool_arguments};
|
||||
pub use gpui_shared_string::SharedString;
|
||||
|
||||
/// Cache-related parameters for a language model.
///
/// NOTE(review): nothing in this file reads these fields, so the per-field
/// descriptions below are inferred from the names only — confirm against the
/// providers that consume this configuration.
#[derive(Debug, Clone)]
pub struct LanguageModelCacheConfiguration {
    /// Presumably the maximum number of cache anchors per request.
    pub max_cache_anchors: usize,
    /// Presumably whether speculative caching is enabled.
    pub should_speculate: bool,
    /// Presumably the minimum total token count before caching applies.
    pub min_total_token: u64,
}
|
||||
|
||||
/// A completion event from a language model.
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub enum LanguageModelCompletionEvent {
|
||||
Queued {
|
||||
position: usize,
|
||||
},
|
||||
Started,
|
||||
Stop(StopReason),
|
||||
Text(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking {
|
||||
data: String,
|
||||
},
|
||||
ToolUse(LanguageModelToolUse),
|
||||
ToolUseJsonParseError {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
},
|
||||
ReasoningDetails(serde_json::Value),
|
||||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionEvent {
|
||||
pub fn from_completion_request_status(
|
||||
status: CompletionRequestStatus,
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
) -> Result<Option<Self>, LanguageModelCompletionError> {
|
||||
match status {
|
||||
CompletionRequestStatus::Queued { position } => {
|
||||
Ok(Some(LanguageModelCompletionEvent::Queued { position }))
|
||||
}
|
||||
CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
|
||||
CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
|
||||
CompletionRequestStatus::Failed {
|
||||
code,
|
||||
message,
|
||||
request_id: _,
|
||||
retry_after,
|
||||
} => Err(LanguageModelCompletionError::from_cloud_failure(
|
||||
upstream_provider,
|
||||
code,
|
||||
message,
|
||||
retry_after.map(Duration::from_secs_f64),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("prompt too large for context window")]
|
||||
PromptTooLarge { tokens: Option<u64> },
|
||||
#[error("missing {provider} API key")]
|
||||
NoApiKey { provider: LanguageModelProviderName },
|
||||
#[error("{provider}'s API rate limit exceeded")]
|
||||
RateLimitExceeded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API servers are overloaded right now")]
|
||||
ServerOverloaded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API server reported an internal server error: {message}")]
|
||||
ApiInternalServerError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("{message}")]
|
||||
UpstreamProviderError {
|
||||
message: String,
|
||||
status: StatusCode,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
||||
HttpResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
},
|
||||
#[error("invalid request format to {provider}'s API: {message}")]
|
||||
BadRequestFormat {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("authentication error with {provider}'s API: {message}")]
|
||||
AuthenticationError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("Permission error with {provider}'s API: {message}")]
|
||||
PermissionError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound { provider: LanguageModelProviderName },
|
||||
#[error("I/O error reading response from {provider}'s API")]
|
||||
ApiReadResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: io::Error,
|
||||
},
|
||||
#[error("error serializing request to {provider} API")]
|
||||
SerializeRequest {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
#[error("error building request body to {provider} API")]
|
||||
BuildRequestBody {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: http::Error,
|
||||
},
|
||||
#[error("error sending HTTP request to {provider} API")]
|
||||
HttpSend {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: anyhow::Error,
|
||||
},
|
||||
#[error("error deserializing {provider} API response")]
|
||||
DeserializeResponse {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
#[error("stream from {provider} ended unexpectedly")]
|
||||
StreamEndedUnexpectedly { provider: LanguageModelProviderName },
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionError {
|
||||
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
|
||||
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
|
||||
let upstream_status = error_json
|
||||
.get("upstream_status")
|
||||
.and_then(|v| v.as_u64())
|
||||
.and_then(|status| u16::try_from(status).ok())
|
||||
.and_then(|status| StatusCode::from_u16(status).ok())?;
|
||||
let inner_message = error_json
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(message)
|
||||
.to_string();
|
||||
Some((upstream_status, inner_message))
|
||||
}
|
||||
|
||||
pub fn from_cloud_failure(
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
code: String,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
if let Some(tokens) = parse_prompt_too_long(&message) {
|
||||
Self::PromptTooLarge {
|
||||
tokens: Some(tokens),
|
||||
}
|
||||
} else if code == "upstream_http_error" {
|
||||
if let Some((upstream_status, inner_message)) =
|
||||
Self::parse_upstream_error_json(&message)
|
||||
{
|
||||
return Self::from_http_status(
|
||||
upstream_provider,
|
||||
upstream_status,
|
||||
inner_message,
|
||||
retry_after,
|
||||
);
|
||||
}
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(upstream_provider, status_code, message, retry_after)
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
|
||||
} else {
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_http_status(
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
match status_code {
|
||||
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
|
||||
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
|
||||
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
|
||||
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
|
||||
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&message),
|
||||
},
|
||||
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
|
||||
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StopReason {
|
||||
EndTurn,
|
||||
MaxTokens,
|
||||
ToolUse,
|
||||
Refusal,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
|
||||
pub struct TokenUsage {
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub input_tokens: u64,
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub output_tokens: u64,
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub cache_creation_input_tokens: u64,
|
||||
#[serde(default, skip_serializing_if = "is_default")]
|
||||
pub cache_read_input_tokens: u64,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
pub fn total_tokens(&self) -> u64 {
|
||||
self.input_tokens
|
||||
+ self.output_tokens
|
||||
+ self.cache_read_input_tokens
|
||||
+ self.cache_creation_input_tokens
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    /// Adds the two usages counter by counter.
    fn add(self, other: Self) -> Self {
        let input_tokens = self.input_tokens + other.input_tokens;
        let output_tokens = self.output_tokens + other.output_tokens;
        let cache_creation_input_tokens =
            self.cache_creation_input_tokens + other.cache_creation_input_tokens;
        let cache_read_input_tokens = self.cache_read_input_tokens + other.cache_read_input_tokens;

        Self {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
        }
    }
}
|
||||
|
||||
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    /// Subtracts `other` from `self` counter by counter.
    ///
    /// Each counter saturates at zero: when `other` holds a larger value than
    /// `self`, the result for that counter is `0`. Plain `u64` subtraction
    /// would panic in debug builds and wrap around to an enormous count in
    /// release builds.
    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens.saturating_sub(other.input_tokens),
            output_tokens: self.output_tokens.saturating_sub(other.output_tokens),
            cache_creation_input_tokens: self
                .cache_creation_input_tokens
                .saturating_sub(other.cache_creation_input_tokens),
            cache_read_input_tokens: self
                .cache_read_input_tokens
                .saturating_sub(other.cache_read_input_tokens),
        }
    }
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelToolUseId(Arc<str>);
|
||||
|
||||
impl fmt::Display for LanguageModelToolUseId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for LanguageModelToolUseId
|
||||
where
|
||||
T: Into<Arc<str>>,
|
||||
{
|
||||
fn from(value: T) -> Self {
|
||||
Self(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelToolUse {
|
||||
pub id: LanguageModelToolUseId,
|
||||
pub name: Arc<str>,
|
||||
pub raw_input: String,
|
||||
pub input: serde_json::Value,
|
||||
pub is_input_complete: bool,
|
||||
/// Thought signature the model sent us. Some models require that this
|
||||
/// signature be preserved and sent back in conversation history for validation.
|
||||
pub thought_signature: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LanguageModelEffortLevel {
|
||||
pub name: SharedString,
|
||||
pub value: SharedString,
|
||||
pub is_default: bool,
|
||||
}
|
||||
|
||||
/// An error that occurred when trying to authenticate the language model provider.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthenticateError {
|
||||
#[error("connection refused")]
|
||||
ConnectionRefused,
|
||||
#[error("credentials not found")]
|
||||
CredentialsNotFound,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
pub struct LanguageModelId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelName(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelProviderId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelProviderName(pub SharedString);
|
||||
|
||||
impl LanguageModelProviderId {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderName {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderName {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelName {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelProviderId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelProviderName {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<str>> for LanguageModelProviderId {
|
||||
fn from(value: Arc<str>) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<str>> for LanguageModelProviderName {
|
||||
fn from(value: Arc<str>) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings-layer–free model mode enum.
|
||||
///
|
||||
/// Mirrors the shape of `settings_content::ModelMode` but lives here so that
|
||||
/// crates below the settings layer can reference it.
|
||||
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ModelMode {
|
||||
#[default]
|
||||
Default,
|
||||
Thinking {
|
||||
budget_tokens: Option<u32>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Settings-layer–free reasoning-effort enum.
|
||||
///
|
||||
/// Mirrors the shape of `settings_content::OpenAiReasoningEffort` but lives
|
||||
/// here so that crates below the settings layer can reference it.
|
||||
#[derive(
|
||||
Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, strum::EnumString,
|
||||
)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
pub enum ReasoningEffort {
|
||||
Minimal,
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
XHigh,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Upstream message shared by the connection-timeout fixtures.
    const CONNECTION_TIMEOUT_MESSAGE: &str = "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout";

    /// Builds the JSON payload that accompanies an `upstream_http_error` code.
    fn upstream_error_payload(message: &str, upstream_status: u16) -> String {
        format!(
            r#"{{"code":"upstream_http_error","message":"{message}","upstream_status":{upstream_status}}}"#
        )
    }

    /// Runs `from_cloud_failure` for the "anthropic" provider with no retry delay.
    fn anthropic_failure(code: &str, message: &str) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            message.to_string(),
            None,
        )
    }

    /// Builds a complete tool use whose raw input is the JSON text of `payload`.
    fn sample_tool_use(
        id: &str,
        name: &str,
        payload: serde_json::Value,
        thought_signature: Option<&str>,
    ) -> LanguageModelToolUse {
        LanguageModelToolUse {
            id: LanguageModelToolUseId::from(id),
            name: name.into(),
            raw_input: payload.to_string(),
            input: payload,
            is_input_complete: true,
            thought_signature: thought_signature.map(str::to_string),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let overloaded = anthropic_failure(
            "upstream_http_error",
            &upstream_error_payload(CONNECTION_TIMEOUT_MESSAGE, 503),
        );
        match overloaded {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            other => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                other
            ),
        }

        let internal = anthropic_failure(
            "upstream_http_error",
            &upstream_error_payload("Internal server error", 500),
        );
        match internal {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            other => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let overloaded = anthropic_failure("upstream_http_503", "Service unavailable");
        match overloaded {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let overloaded = anthropic_failure(
            "upstream_http_error",
            &upstream_error_payload(CONNECTION_TIMEOUT_MESSAGE, 503),
        );
        match overloaded {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            other => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                other
            ),
        }

        let internal = anthropic_failure(
            "upstream_http_error",
            &upstream_error_payload(CONNECTION_TIMEOUT_MESSAGE, 500),
        );
        match internal {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            other => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        let tool_use = sample_tool_use(
            "test_id",
            "test_tool",
            json!({"arg": "value"}),
            Some("test_signature"),
        );

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        let without_signature = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(without_signature).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        let original = sample_tool_use(
            "round_trip_id",
            "round_trip_tool",
            json!({"key": "value"}),
            Some("round_trip_sig"),
        );

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        let original = sample_tool_use("no_sig_id", "no_sig_tool", json!({"arg": "value"}), None);

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}
|
||||
21
crates/language_model_core/src/provider.rs
Normal file
21
crates/language_model_core/src/provider.rs
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
use crate::{LanguageModelProviderId, LanguageModelProviderName};
|
||||
|
||||
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
|
||||
LanguageModelProviderId::new("anthropic");
|
||||
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Anthropic");
|
||||
|
||||
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
|
||||
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("OpenAI");
|
||||
|
||||
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
|
||||
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Google AI");
|
||||
|
||||
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
|
||||
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
|
||||
|
||||
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
|
||||
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Zed");
|
||||
463
crates/language_model_core/src/request.rs
Normal file
463
crates/language_model_core/src/request.rs
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::role::Role;
|
||||
use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
|
||||
|
||||
/// Dimensions of a `LanguageModelImage`
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ImageSize {
|
||||
pub width: i32,
|
||||
pub height: i32,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
pub struct LanguageModelImage {
|
||||
/// A base64-encoded PNG image.
|
||||
pub source: SharedString,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<ImageSize>,
|
||||
}
|
||||
|
||||
impl LanguageModelImage {
|
||||
pub fn len(&self) -> usize {
|
||||
self.source.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.source.is_empty()
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
source: "".into(),
|
||||
size: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse Self from a JSON object with case-insensitive field names
|
||||
pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
|
||||
let mut source = None;
|
||||
let mut size_obj = None;
|
||||
|
||||
for (k, v) in obj.iter() {
|
||||
match k.to_lowercase().as_str() {
|
||||
"source" => source = v.as_str(),
|
||||
"size" => size_obj = v.as_object(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let source = source?;
|
||||
let size_obj = size_obj?;
|
||||
|
||||
let mut width = None;
|
||||
let mut height = None;
|
||||
|
||||
for (k, v) in size_obj.iter() {
|
||||
match k.to_lowercase().as_str() {
|
||||
"width" => width = v.as_i64().map(|w| w as i32),
|
||||
"height" => height = v.as_i64().map(|h| h as i32),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
size: Some(ImageSize {
|
||||
width: width?,
|
||||
height: height?,
|
||||
}),
|
||||
source: SharedString::from(source.to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn estimate_tokens(&self) -> usize {
|
||||
let Some(size) = self.size.as_ref() else {
|
||||
return 0;
|
||||
};
|
||||
let width = size.width.unsigned_abs() as usize;
|
||||
let height = size.height.unsigned_abs() as usize;
|
||||
|
||||
// From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
|
||||
(width * height) / 750
|
||||
}
|
||||
|
||||
pub fn to_base64_url(&self) -> String {
|
||||
format!("data:image/png;base64,{}", self.source)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for LanguageModelImage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("LanguageModelImage")
|
||||
.field("source", &format!("<{} bytes>", self.source.len()))
|
||||
.field("size", &self.size)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub struct LanguageModelToolResult {
|
||||
pub tool_use_id: LanguageModelToolUseId,
|
||||
pub tool_name: Arc<str>,
|
||||
pub is_error: bool,
|
||||
/// The tool output formatted for presenting to the model
|
||||
pub content: LanguageModelToolResultContent,
|
||||
/// The raw tool output, if available, often for debugging or extra state for replay
|
||||
pub output: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
|
||||
pub enum LanguageModelToolResultContent {
|
||||
Text(Arc<str>),
|
||||
Image(LanguageModelImage),
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
let value = serde_json::Value::deserialize(deserializer)?;
|
||||
|
||||
// 1. Try as plain string
|
||||
if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
|
||||
// 2. Try as object
|
||||
if let Some(obj) = value.as_object() {
|
||||
fn get_field<'a>(
|
||||
obj: &'a serde_json::Map<String, serde_json::Value>,
|
||||
field: &str,
|
||||
) -> Option<&'a serde_json::Value> {
|
||||
obj.iter()
|
||||
.find(|(k, _)| k.to_lowercase() == field.to_lowercase())
|
||||
.map(|(_, v)| v)
|
||||
}
|
||||
|
||||
// Accept wrapped text format: { "type": "text", "text": "..." }
|
||||
if let (Some(type_value), Some(text_value)) =
|
||||
(get_field(obj, "type"), get_field(obj, "text"))
|
||||
&& let Some(type_str) = type_value.as_str()
|
||||
&& type_str.to_lowercase() == "text"
|
||||
&& let Some(text) = text_value.as_str()
|
||||
{
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
|
||||
// Check for wrapped Text variant: { "text": "..." }
|
||||
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
|
||||
&& obj.len() == 1
|
||||
{
|
||||
if let Some(text) = value.as_str() {
|
||||
return Ok(Self::Text(Arc::from(text)));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
|
||||
if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
|
||||
&& obj.len() == 1
|
||||
{
|
||||
if let Some(image_obj) = value.as_object()
|
||||
&& let Some(image) = LanguageModelImage::from_json(image_obj)
|
||||
{
|
||||
return Ok(Self::Image(image));
|
||||
}
|
||||
}
|
||||
|
||||
// Try as direct Image
|
||||
if let Some(image) = LanguageModelImage::from_json(obj) {
|
||||
return Ok(Self::Image(image));
|
||||
}
|
||||
}
|
||||
|
||||
Err(D::Error::custom(format!(
|
||||
"data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
|
||||
an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
|
||||
serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelToolResultContent {
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Text(text) => Some(text),
|
||||
Self::Image(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
|
||||
Self::Image(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for LanguageModelToolResultContent {
|
||||
fn from(value: &str) -> Self {
|
||||
Self::Text(Arc::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelToolResultContent {
|
||||
fn from(value: String) -> Self {
|
||||
Self::Text(Arc::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelImage> for LanguageModelToolResultContent {
|
||||
fn from(image: LanguageModelImage) -> Self {
|
||||
Self::Image(image)
|
||||
}
|
||||
}
|
||||
|
||||
/// One piece of content within a `LanguageModelRequestMessage`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub enum MessageContent {
|
||||
/// Plain text.
Text(String),
|
||||
/// A thinking block: its text plus an optional signature.
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
/// Redacted thinking data; `to_str` never exposes it as text.
RedactedThinking(String),
|
||||
/// An image.
Image(LanguageModelImage),
|
||||
/// A tool use.
ToolUse(LanguageModelToolUse),
|
||||
/// The result of a tool use.
ToolResult(LanguageModelToolResult),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessageContent::Text(text) => Some(text.as_str()),
|
||||
MessageContent::Thinking { text, .. } => Some(text.as_str()),
|
||||
MessageContent::RedactedThinking(_) => None,
|
||||
MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
|
||||
MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
|
||||
MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
|
||||
MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
|
||||
MessageContent::RedactedThinking(_)
|
||||
| MessageContent::ToolUse(_)
|
||||
| MessageContent::Image(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for MessageContent {
|
||||
fn from(value: String) -> Self {
|
||||
MessageContent::Text(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for MessageContent {
|
||||
fn from(value: &str) -> Self {
|
||||
MessageContent::Text(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// A single message of a `LanguageModelRequest`.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
/// Role of the message author.
pub role: Role,
|
||||
/// Ordered content pieces of the message.
pub content: Vec<MessageContent>,
|
||||
/// Cache flag. NOTE(review): presumably asks the provider to cache the
/// prompt up to this message — confirm against the provider conversions.
pub cache: bool,
|
||||
/// Optional reasoning details as raw JSON. Defaults to `None` when absent
/// on deserialization and is omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl LanguageModelRequestMessage {
|
||||
pub fn string_contents(&self) -> String {
|
||||
let mut buffer = String::new();
|
||||
for string in self.content.iter().filter_map(|content| content.to_str()) {
|
||||
buffer.push_str(string);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
|
||||
pub fn contents_empty(&self) -> bool {
|
||||
self.content.iter().all(|content| content.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool definition sent along with a `LanguageModelRequest`.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelRequestTool {
|
||||
/// Tool name.
pub name: String,
|
||||
/// Tool description.
pub description: String,
|
||||
/// Schema of the tool input, as raw JSON.
pub input_schema: serde_json::Value,
|
||||
/// Input-streaming flag. NOTE(review): appears to request streamed tool
/// input from providers that support it — confirm.
pub use_input_streaming: bool,
|
||||
}
|
||||
|
||||
/// Tool-choice setting of a request.
///
/// NOTE(review): the variant names suggest "model decides", "must use some
/// tool" and "use no tool" respectively — confirm against provider docs.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub enum LanguageModelToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
None,
|
||||
}
|
||||
|
||||
/// The purpose a completion request is issued for.
///
/// Variants are (de)serialized in `snake_case` (e.g. `user_prompt`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionIntent {
|
||||
UserPrompt,
|
||||
Subagent,
|
||||
ToolResults,
|
||||
ThreadSummarization,
|
||||
ThreadContextSummarization,
|
||||
CreateFile,
|
||||
EditFile,
|
||||
InlineAssist,
|
||||
TerminalInlineAssist,
|
||||
GenerateGitCommitMessage,
|
||||
}
|
||||
|
||||
/// A provider-independent completion request.
///
/// Derives `Default`, so a request can be built with struct-update syntax.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct LanguageModelRequest {
|
||||
/// Optional thread identifier.
pub thread_id: Option<String>,
|
||||
/// Optional prompt identifier.
pub prompt_id: Option<String>,
|
||||
/// Optional purpose of the request.
pub intent: Option<CompletionIntent>,
|
||||
/// Messages of the conversation, in order.
pub messages: Vec<LanguageModelRequestMessage>,
|
||||
/// Tools made available to the model.
pub tools: Vec<LanguageModelRequestTool>,
|
||||
/// Optional tool-choice setting.
pub tool_choice: Option<LanguageModelToolChoice>,
|
||||
/// Stop strings (presumably stop sequences for generation — confirm).
pub stop: Vec<String>,
|
||||
/// Optional sampling temperature.
pub temperature: Option<f32>,
|
||||
/// Whether thinking is allowed for this request.
pub thinking_allowed: bool,
|
||||
/// Optional thinking effort, as a free-form string.
pub thinking_effort: Option<String>,
|
||||
/// Optional speed setting.
pub speed: Option<Speed>,
|
||||
}
|
||||
|
||||
/// Speed setting of a request; defaults to `Standard`.
///
/// Variants are (de)serialized in `snake_case` (`standard`, `fast`).
#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Speed {
|
||||
#[default]
|
||||
Standard,
|
||||
Fast,
|
||||
}
|
||||
|
||||
impl Speed {
|
||||
pub fn toggle(self) -> Self {
|
||||
match self {
|
||||
Speed::Standard => Speed::Fast,
|
||||
Speed::Fast => Speed::Standard,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A response message in which both the role and the content are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct LanguageModelResponseMessage {
|
||||
/// Role of the message author, if known.
pub role: Option<Role>,
|
||||
/// Message text, if any.
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into tool result content, panicking on failure.
    fn parse(json: serde_json::Value) -> LanguageModelToolResultContent {
        serde_json::from_value(json).unwrap()
    }

    /// Deserializes `json` and unwraps the image variant.
    fn parse_image(json: serde_json::Value) -> LanguageModelImage {
        match parse(json) {
            LanguageModelToolResultContent::Image(image) => image,
            _ => panic!("Expected Image variant"),
        }
    }

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        // Test plain string
        assert_eq!(
            parse(serde_json::json!("hello world")),
            LanguageModelToolResultContent::Text(Arc::from("hello world"))
        );

        // Test wrapped text format: { "type": "text", "text": "..." }
        assert_eq!(
            parse(serde_json::json!({"type": "text", "text": "hello"})),
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test single-field text object: { "text": "..." }
        assert_eq!(
            parse(serde_json::json!({"text": "hello"})),
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test case-insensitive type field
        assert_eq!(
            parse(serde_json::json!({"Type": "Text", "Text": "hello"})),
            LanguageModelToolResultContent::Text(Arc::from("hello"))
        );

        // Test image object
        let image = parse_image(serde_json::json!({
            "source": "base64encodedimagedata",
            "size": {"width": 100, "height": 200}
        }));
        assert_eq!(image.source.as_ref(), "base64encodedimagedata");
        let size = image.size.expect("size");
        assert_eq!(size.width, 100);
        assert_eq!(size.height, 200);

        // Test wrapped image: { "image": { "source": "...", "size": ... } }
        let image = parse_image(serde_json::json!({
            "image": {
                "source": "wrappedimagedata",
                "size": {"width": 50, "height": 75}
            }
        }));
        assert_eq!(image.source.as_ref(), "wrappedimagedata");
        let size = image.size.expect("size");
        assert_eq!(size.width, 50);
        assert_eq!(size.height, 75);

        // Test case insensitive
        let image = parse_image(serde_json::json!({
            "Source": "caseinsensitive",
            "Size": {"Width": 30, "Height": 40}
        }));
        assert_eq!(image.source.as_ref(), "caseinsensitive");
        let size = image.size.expect("size");
        assert_eq!(size.width, 30);
        assert_eq!(size.height, 40);

        // Test direct image object
        let image = parse_image(serde_json::json!({
            "source": "directimage",
            "size": {"width": 200, "height": 300}
        }));
        assert_eq!(image.source.as_ref(), "directimage");
        let size = image.size.expect("size");
        assert_eq!(size.width, 200);
        assert_eq!(size.height, 300);
    }
}
|
||||
|
|
@ -77,8 +77,6 @@ pub fn adapt_schema_to_format(
|
|||
}
|
||||
|
||||
fn preprocess_json_schema(json: &mut Value) -> Result<()> {
|
||||
// `additionalProperties` defaults to `false` unless explicitly specified.
|
||||
// This prevents models from hallucinating tool parameters.
|
||||
if let Value::Object(obj) = json
|
||||
&& matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
|
||||
{
|
||||
|
|
@ -86,7 +84,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> {
|
|||
obj.insert("additionalProperties".to_string(), Value::Bool(false));
|
||||
}
|
||||
|
||||
// OpenAI API requires non-missing `properties`
|
||||
if !obj.contains_key("properties") {
|
||||
obj.insert("properties".to_string(), Value::Object(Default::default()));
|
||||
}
|
||||
|
|
@ -94,7 +91,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
|
||||
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
if let Value::Object(obj) = json {
|
||||
const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
|
||||
|
|
@ -108,9 +104,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
|||
|
||||
const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
|
||||
("format", |value| value.is_string()),
|
||||
// Gemini doesn't support `additionalProperties` in any form (boolean or schema object)
|
||||
("additionalProperties", |_| true),
|
||||
// Gemini doesn't support `propertyNames`
|
||||
("propertyNames", |_| true),
|
||||
("exclusiveMinimum", |value| value.is_number()),
|
||||
("exclusiveMaximum", |value| value.is_number()),
|
||||
|
|
@ -124,7 +118,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
// If a type is not specified for an input parameter, add a default type
|
||||
if matches!(obj.get("description"), Some(Value::String(_)))
|
||||
&& !obj.contains_key("type")
|
||||
&& !(obj.contains_key("anyOf")
|
||||
|
|
@ -134,7 +127,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
|||
obj.insert("type".to_string(), Value::String("string".to_string()));
|
||||
}
|
||||
|
||||
// Handle oneOf -> anyOf conversion
|
||||
if let Some(subschemas) = obj.get_mut("oneOf")
|
||||
&& subschemas.is_array()
|
||||
{
|
||||
|
|
@ -143,7 +135,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
|||
obj.insert("anyOf".to_string(), subschemas_clone);
|
||||
}
|
||||
|
||||
// Recursively process all nested objects and arrays
|
||||
for (_, value) in obj.iter_mut() {
|
||||
if let Value::Object(_) | Value::Array(_) = value {
|
||||
adapt_to_json_schema_subset(value)?;
|
||||
|
|
@ -178,7 +169,6 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
// Ensure that we do not add a type if it is an object
|
||||
let mut json = json!({
|
||||
"description": {
|
||||
"value": "abc",
|
||||
|
|
@ -221,7 +211,6 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
// Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
|
||||
let mut json = json!({
|
||||
"description": "A test field",
|
||||
"type": "integer",
|
||||
|
|
@ -239,7 +228,6 @@ mod tests {
|
|||
})
|
||||
);
|
||||
|
||||
// additionalProperties as an object schema is also unsupported by Gemini
|
||||
let mut json = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
@ -38,13 +38,22 @@ fn strip_trailing_incomplete_escape(json: &str) -> &str {
|
|||
}
|
||||
}
|
||||
|
||||
/// Extracts the token count from a message of the form
/// `"prompt is too long: N tokens ..."`.
///
/// Returns `None` when the message does not start with that prefix, contains
/// no `" tokens"` marker after it, or the text in between is not a valid `u64`.
pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
    const PREFIX: &str = "prompt is too long: ";
    const COUNT_TERMINATOR: &str = " tokens";

    let after_prefix = message.strip_prefix(PREFIX)?;
    let (count, _rest) = after_prefix.split_once(COUNT_TERMINATOR)?;
    count.parse::<u64>().ok()
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fix_streamed_json_strips_incomplete_escape() {
|
||||
// Trailing `\` inside a string — incomplete escape sequence
|
||||
let fixed = fix_streamed_json(r#"{"text": "hello\"#);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
|
||||
assert_eq!(parsed["text"], "hello");
|
||||
|
|
@ -52,7 +61,6 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_streamed_json_preserves_complete_escape() {
|
||||
// `\\` is a complete escape (literal backslash)
|
||||
let fixed = fix_streamed_json(r#"{"text": "hello\\"#);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
|
||||
assert_eq!(parsed["text"], "hello\\");
|
||||
|
|
@ -60,7 +68,6 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_streamed_json_strips_escape_after_complete_escape() {
|
||||
// `\\\` = complete `\\` (literal backslash) + incomplete `\`
|
||||
let fixed = fix_streamed_json(r#"{"text": "hello\\\"#);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
|
||||
assert_eq!(parsed["text"], "hello\\");
|
||||
|
|
@ -75,12 +82,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_streamed_json_newline_escape_boundary() {
|
||||
// Simulates a stream boundary landing between `\` and `n`
|
||||
let fixed = fix_streamed_json(r#"{"text": "line1\"#);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
|
||||
assert_eq!(parsed["text"], "line1");
|
||||
|
||||
// Next chunk completes the escape
|
||||
let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
|
||||
assert_eq!(parsed["text"], "line1\nline2");
|
||||
|
|
@ -88,8 +93,6 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_streamed_json_incremental_delta_correctness() {
|
||||
// This is the actual scenario that causes the bug:
|
||||
// chunk 1 ends mid-escape, chunk 2 completes it.
|
||||
let chunk1 = r#"{"replacement_text": "fn foo() {\"#;
|
||||
let fixed1 = fix_streamed_json(chunk1);
|
||||
let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json");
|
||||
|
|
@ -102,7 +105,6 @@ mod tests {
|
|||
let text2 = parsed2["replacement_text"].as_str().expect("string");
|
||||
assert_eq!(text2, "fn foo() {\n return bar;\n}");
|
||||
|
||||
// The delta should be the newline + rest, with no spurious backslash
|
||||
let delta = &text2[text1.len()..];
|
||||
assert_eq!(delta, "\n return bar;\n}");
|
||||
}
|
||||
|
|
@ -21,8 +21,8 @@ aws_http_client.workspace = true
|
|||
base64.workspace = true
|
||||
bedrock = { workspace = true, features = ["schemars"] }
|
||||
client.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
convert_case.workspace = true
|
||||
|
|
@ -41,6 +41,7 @@ gpui_tokio.workspace = true
|
|||
http_client.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models_cloud.workspace = true
|
||||
lmstudio = { workspace = true, features = ["schemars"] }
|
||||
log.workspace = true
|
||||
menu.workspace = true
|
||||
|
|
@ -49,16 +50,13 @@ ollama = { workspace = true, features = ["schemars"] }
|
|||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
opencode = { workspace = true, features = ["schemars"] }
|
||||
open_router = { workspace = true, features = ["schemars"] }
|
||||
partial-json-fixer.workspace = true
|
||||
release_channel.workspace = true
|
||||
schemars.workspace = true
|
||||
semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
ui.workspace = true
|
||||
|
|
@ -70,4 +68,3 @@ x_ai = { workspace = true, features = ["schemars"] }
|
|||
[dev-dependencies]
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ pub mod open_ai;
|
|||
pub mod open_ai_compatible;
|
||||
pub mod open_router;
|
||||
pub mod opencode;
|
||||
mod util;
|
||||
|
||||
pub mod vercel;
|
||||
pub mod vercel_ai_gateway;
|
||||
pub mod x_ai;
|
||||
|
|
|
|||
|
|
@ -1,13 +1,10 @@
|
|||
pub mod telemetry;
|
||||
|
||||
use anthropic::{
|
||||
ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event,
|
||||
ResponseContent, ToolResultContent, ToolResultPart, Usage,
|
||||
};
|
||||
use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode};
|
||||
use anyhow::Result;
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
|
|
@ -16,20 +13,19 @@ use language_model::{
|
|||
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason, env_var,
|
||||
LanguageModelToolChoice, RateLimiter, env_var,
|
||||
};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
||||
use ui_input::InputField;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
|
||||
|
||||
pub use anthropic::completion::{
|
||||
AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
|
||||
into_anthropic_count_tokens_request,
|
||||
};
|
||||
pub use settings::AnthropicAvailableModel as AvailableModel;
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID;
|
||||
|
|
@ -249,228 +245,6 @@ pub struct AnthropicModel {
|
|||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
fn to_anthropic_content(content: MessageContent) -> Option<anthropic::RequestContent> {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
|
||||
text.trim_end().to_string()
|
||||
} else {
|
||||
text
|
||||
};
|
||||
if !text.is_empty() {
|
||||
Some(anthropic::RequestContent::Text {
|
||||
text,
|
||||
cache_control: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Thinking {
|
||||
text: thinking,
|
||||
signature,
|
||||
} => {
|
||||
if let Some(signature) = signature
|
||||
&& !thinking.is_empty()
|
||||
{
|
||||
Some(anthropic::RequestContent::Thinking {
|
||||
thinking,
|
||||
signature,
|
||||
cache_control: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::RedactedThinking(data) => {
|
||||
if !data.is_empty() {
|
||||
Some(anthropic::RequestContent::RedactedThinking { data })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
|
||||
source: anthropic::ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
cache_control: None,
|
||||
}),
|
||||
MessageContent::ToolUse(tool_use) => Some(anthropic::RequestContent::ToolUse {
|
||||
id: tool_use.id.to_string(),
|
||||
name: tool_use.name.to_string(),
|
||||
input: tool_use.input,
|
||||
cache_control: None,
|
||||
}),
|
||||
MessageContent::ToolResult(tool_result) => Some(anthropic::RequestContent::ToolResult {
|
||||
tool_use_id: tool_result.tool_use_id.to_string(),
|
||||
is_error: tool_result.is_error,
|
||||
content: match tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
ToolResultContent::Plain(text.to_string())
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
ToolResultContent::Multipart(vec![ToolResultPart::Image {
|
||||
source: anthropic::ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
}])
|
||||
}
|
||||
},
|
||||
cache_control: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
|
||||
pub fn into_anthropic_count_tokens_request(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
mode: AnthropicModelMode,
|
||||
) -> CountTokensRequest {
|
||||
let mut new_messages: Vec<anthropic::Message> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in request.messages {
|
||||
if message.contents_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let anthropic_message_content: Vec<anthropic::RequestContent> = message
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(to_anthropic_content)
|
||||
.collect();
|
||||
let anthropic_role = match message.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("System role should never occur here"),
|
||||
};
|
||||
if anthropic_message_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(last_message) = new_messages.last_mut()
|
||||
&& last_message.role == anthropic_role
|
||||
{
|
||||
last_message.content.extend(anthropic_message_content);
|
||||
continue;
|
||||
}
|
||||
|
||||
new_messages.push(anthropic::Message {
|
||||
role: anthropic_role,
|
||||
content: anthropic_message_content,
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.string_contents());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CountTokensRequest {
|
||||
model,
|
||||
messages: new_messages,
|
||||
system: if system_message.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(anthropic::StringOrContents::String(system_message))
|
||||
},
|
||||
thinking: if request.thinking_allowed {
|
||||
match mode {
|
||||
AnthropicModelMode::Thinking { budget_tokens } => {
|
||||
Some(anthropic::Thinking::Enabled { budget_tokens })
|
||||
}
|
||||
AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
|
||||
AnthropicModelMode::Default => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| anthropic::Tool {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: tool.input_schema,
|
||||
eager_input_streaming: tool.use_input_streaming,
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: request.tool_choice.map(|choice| match choice {
|
||||
LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
|
||||
LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
|
||||
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
|
||||
/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
|
||||
pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
|
||||
let messages = request.messages;
|
||||
let mut tokens_from_images = 0;
|
||||
let mut string_messages = Vec::with_capacity(messages.len());
|
||||
|
||||
for message in messages {
|
||||
let mut string_contents = String::new();
|
||||
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
string_contents.push_str(&text);
|
||||
}
|
||||
MessageContent::Thinking { .. } => {
|
||||
// Thinking blocks are not included in the input token count.
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {
|
||||
// Thinking blocks are not included in the input token count.
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
tokens_from_images += image.estimate_tokens();
|
||||
}
|
||||
MessageContent::ToolUse(_tool_use) => {
|
||||
// TODO: Estimate token usage from tool uses.
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
string_contents.push_str(text);
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
tokens_from_images += image.estimate_tokens();
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if !string_contents.is_empty() {
|
||||
string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(string_contents),
|
||||
name: None,
|
||||
function_call: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
|
||||
.map(|tokens| (tokens + tokens_from_images) as u64)
|
||||
}
|
||||
|
||||
impl AnthropicModel {
|
||||
fn stream_completion(
|
||||
&self,
|
||||
|
|
@ -617,10 +391,13 @@ impl LanguageModel for AnthropicModel {
|
|||
)
|
||||
});
|
||||
|
||||
let background = cx.background_executor().clone();
|
||||
async move {
|
||||
// If no API key, fall back to tiktoken estimation
|
||||
let Some(api_key) = api_key else {
|
||||
return count_anthropic_tokens_with_tiktoken(request);
|
||||
return background
|
||||
.spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
|
||||
.await;
|
||||
};
|
||||
|
||||
let count_request =
|
||||
|
|
@ -634,7 +411,9 @@ impl LanguageModel for AnthropicModel {
|
|||
log::error!(
|
||||
"Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
|
||||
);
|
||||
count_anthropic_tokens_with_tiktoken(request)
|
||||
background
|
||||
.spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -678,345 +457,6 @@ impl LanguageModel for AnthropicModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn into_anthropic(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
default_temperature: f32,
|
||||
max_output_tokens: u64,
|
||||
mode: AnthropicModelMode,
|
||||
) -> anthropic::Request {
|
||||
let mut new_messages: Vec<anthropic::Message> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in request.messages {
|
||||
if message.contents_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(to_anthropic_content)
|
||||
.collect();
|
||||
let anthropic_role = match message.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("System role should never occur here"),
|
||||
};
|
||||
if anthropic_message_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(last_message) = new_messages.last_mut()
|
||||
&& last_message.role == anthropic_role
|
||||
{
|
||||
last_message.content.extend(anthropic_message_content);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Mark the last segment of the message as cached
|
||||
if message.cache {
|
||||
let cache_control_value = Some(anthropic::CacheControl {
|
||||
cache_type: anthropic::CacheControlType::Ephemeral,
|
||||
});
|
||||
for message_content in anthropic_message_content.iter_mut().rev() {
|
||||
match message_content {
|
||||
anthropic::RequestContent::RedactedThinking { .. } => {
|
||||
// Caching is not possible, fallback to next message
|
||||
}
|
||||
anthropic::RequestContent::Text { cache_control, .. }
|
||||
| anthropic::RequestContent::Thinking { cache_control, .. }
|
||||
| anthropic::RequestContent::Image { cache_control, .. }
|
||||
| anthropic::RequestContent::ToolUse { cache_control, .. }
|
||||
| anthropic::RequestContent::ToolResult { cache_control, .. } => {
|
||||
*cache_control = cache_control_value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_messages.push(anthropic::Message {
|
||||
role: anthropic_role,
|
||||
content: anthropic_message_content,
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.string_contents());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anthropic::Request {
|
||||
model,
|
||||
messages: new_messages,
|
||||
max_tokens: max_output_tokens,
|
||||
system: if system_message.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(anthropic::StringOrContents::String(system_message))
|
||||
},
|
||||
thinking: if request.thinking_allowed {
|
||||
match mode {
|
||||
AnthropicModelMode::Thinking { budget_tokens } => {
|
||||
Some(anthropic::Thinking::Enabled { budget_tokens })
|
||||
}
|
||||
AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
|
||||
AnthropicModelMode::Default => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| anthropic::Tool {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: tool.input_schema,
|
||||
eager_input_streaming: tool.use_input_streaming,
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: request.tool_choice.map(|choice| match choice {
|
||||
LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
|
||||
LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
|
||||
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
|
||||
}),
|
||||
metadata: None,
|
||||
output_config: if request.thinking_allowed
|
||||
&& matches!(mode, AnthropicModelMode::AdaptiveThinking)
|
||||
{
|
||||
request.thinking_effort.as_deref().and_then(|effort| {
|
||||
let effort = match effort {
|
||||
"low" => Some(anthropic::Effort::Low),
|
||||
"medium" => Some(anthropic::Effort::Medium),
|
||||
"high" => Some(anthropic::Effort::High),
|
||||
"max" => Some(anthropic::Effort::Max),
|
||||
_ => None,
|
||||
};
|
||||
effort.map(|effort| anthropic::OutputConfig {
|
||||
effort: Some(effort),
|
||||
})
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
stop_sequences: Vec::new(),
|
||||
speed: request.speed.map(From::from),
|
||||
temperature: request.temperature.or(Some(default_temperature)),
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AnthropicEventMapper {
|
||||
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
||||
usage: Usage,
|
||||
stop_reason: StopReason,
|
||||
}
|
||||
|
||||
impl AnthropicEventMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
usage: Usage::default(),
|
||||
stop_reason: StopReason::EndTurn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(error.into())],
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: Event,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
match event {
|
||||
Event::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => match content_block {
|
||||
ResponseContent::Text { text } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Text(text))]
|
||||
}
|
||||
ResponseContent::Thinking { thinking } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thinking,
|
||||
signature: None,
|
||||
})]
|
||||
}
|
||||
ResponseContent::RedactedThinking { data } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
|
||||
}
|
||||
ResponseContent::ToolUse { id, name, .. } => {
|
||||
self.tool_uses_by_index.insert(
|
||||
index,
|
||||
RawToolUse {
|
||||
id,
|
||||
name,
|
||||
input_json: String::new(),
|
||||
},
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
},
|
||||
Event::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Text(text))]
|
||||
}
|
||||
ContentDelta::ThinkingDelta { thinking } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thinking,
|
||||
signature: None,
|
||||
})]
|
||||
}
|
||||
ContentDelta::SignatureDelta { signature } => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".to_string(),
|
||||
signature: Some(signature),
|
||||
})]
|
||||
}
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
|
||||
tool_use.input_json.push_str(&partial_json);
|
||||
|
||||
// Try to convert invalid (incomplete) JSON into
|
||||
// valid JSON that serde can accept, e.g. by closing
|
||||
// unclosed delimiters. This way, we can update the
|
||||
// UI with whatever has been streamed back so far.
|
||||
if let Ok(input) =
|
||||
serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
|
||||
{
|
||||
return vec![Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.clone().into(),
|
||||
name: tool_use.name.clone().into(),
|
||||
is_input_complete: false,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
input,
|
||||
thought_signature: None,
|
||||
},
|
||||
))];
|
||||
}
|
||||
}
|
||||
vec![]
|
||||
}
|
||||
},
|
||||
Event::ContentBlockStop { index } => {
|
||||
if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
|
||||
let input_json = tool_use.input_json.trim();
|
||||
let event_result = match parse_tool_arguments(input_json) {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(json_parse_err) => {
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_use.id.into(),
|
||||
tool_name: tool_use.name.into(),
|
||||
raw_input: input_json.into(),
|
||||
json_parse_error: json_parse_err.to_string(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
vec![event_result]
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
Event::MessageStart { message } => {
|
||||
update_usage(&mut self.usage, &message.usage);
|
||||
vec![
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
|
||||
&self.usage,
|
||||
))),
|
||||
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||
message_id: message.id,
|
||||
}),
|
||||
]
|
||||
}
|
||||
Event::MessageDelta { delta, usage } => {
|
||||
update_usage(&mut self.usage, &usage);
|
||||
if let Some(stop_reason) = delta.stop_reason.as_deref() {
|
||||
self.stop_reason = match stop_reason {
|
||||
"end_turn" => StopReason::EndTurn,
|
||||
"max_tokens" => StopReason::MaxTokens,
|
||||
"tool_use" => StopReason::ToolUse,
|
||||
"refusal" => StopReason::Refusal,
|
||||
_ => {
|
||||
log::error!("Unexpected anthropic stop_reason: {stop_reason}");
|
||||
StopReason::EndTurn
|
||||
}
|
||||
};
|
||||
}
|
||||
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||
convert_usage(&self.usage),
|
||||
))]
|
||||
}
|
||||
Event::MessageStop => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
|
||||
}
|
||||
Event::Error { error } => {
|
||||
vec![Err(error.into())]
|
||||
}
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
/// Updates usage data by preferring counts from `new`.
|
||||
fn update_usage(usage: &mut Usage, new: &Usage) {
|
||||
if let Some(input_tokens) = new.input_tokens {
|
||||
usage.input_tokens = Some(input_tokens);
|
||||
}
|
||||
if let Some(output_tokens) = new.output_tokens {
|
||||
usage.output_tokens = Some(output_tokens);
|
||||
}
|
||||
if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
|
||||
usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
|
||||
}
|
||||
if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
|
||||
usage.cache_read_input_tokens = Some(cache_read_input_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
|
||||
language_model::TokenUsage {
|
||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||
output_tokens: usage.output_tokens.unwrap_or(0),
|
||||
cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
|
||||
cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<InputField>,
|
||||
state: Entity<State>,
|
||||
|
|
@ -1157,192 +597,3 @@ impl Render for ConfigurationView {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anthropic::AnthropicModelMode;
|
||||
use language_model::{LanguageModelRequestMessage, MessageContent};
|
||||
|
||||
#[test]
|
||||
fn test_cache_control_only_on_last_segment() {
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![
|
||||
MessageContent::Text("Some prompt".to_string()),
|
||||
MessageContent::Image(language_model::LanguageModelImage::empty()),
|
||||
MessageContent::Image(language_model::LanguageModelImage::empty()),
|
||||
MessageContent::Image(language_model::LanguageModelImage::empty()),
|
||||
MessageContent::Image(language_model::LanguageModelImage::empty()),
|
||||
],
|
||||
cache: true,
|
||||
reasoning_details: None,
|
||||
}],
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
thinking_allowed: true,
|
||||
thinking_effort: None,
|
||||
speed: None,
|
||||
};
|
||||
|
||||
let anthropic_request = into_anthropic(
|
||||
request,
|
||||
"claude-3-5-sonnet".to_string(),
|
||||
0.7,
|
||||
4096,
|
||||
AnthropicModelMode::Default,
|
||||
);
|
||||
|
||||
assert_eq!(anthropic_request.messages.len(), 1);
|
||||
|
||||
let message = &anthropic_request.messages[0];
|
||||
assert_eq!(message.content.len(), 5);
|
||||
|
||||
assert!(matches!(
|
||||
message.content[0],
|
||||
anthropic::RequestContent::Text {
|
||||
cache_control: None,
|
||||
..
|
||||
}
|
||||
));
|
||||
for i in 1..3 {
|
||||
assert!(matches!(
|
||||
message.content[i],
|
||||
anthropic::RequestContent::Image {
|
||||
cache_control: None,
|
||||
..
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
assert!(matches!(
|
||||
message.content[4],
|
||||
anthropic::RequestContent::Image {
|
||||
cache_control: Some(anthropic::CacheControl {
|
||||
cache_type: anthropic::CacheControlType::Ephemeral,
|
||||
}),
|
||||
..
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
fn request_with_assistant_content(
|
||||
assistant_content: Vec<MessageContent>,
|
||||
) -> anthropic::Request {
|
||||
let mut request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("Hello".to_string())],
|
||||
cache: false,
|
||||
reasoning_details: None,
|
||||
}],
|
||||
thinking_effort: None,
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
thinking_allowed: true,
|
||||
speed: None,
|
||||
};
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: assistant_content,
|
||||
cache: false,
|
||||
reasoning_details: None,
|
||||
});
|
||||
into_anthropic(
|
||||
request,
|
||||
"claude-sonnet-4-5".to_string(),
|
||||
1.0,
|
||||
16000,
|
||||
AnthropicModelMode::Thinking {
|
||||
budget_tokens: Some(10000),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsigned_thinking_blocks_stripped() {
|
||||
let result = request_with_assistant_content(vec![
|
||||
MessageContent::Thinking {
|
||||
text: "Cancelled mid-think, no signature".to_string(),
|
||||
signature: None,
|
||||
},
|
||||
MessageContent::Text("Some response text".to_string()),
|
||||
]);
|
||||
|
||||
let assistant_message = result
|
||||
.messages
|
||||
.iter()
|
||||
.find(|m| m.role == anthropic::Role::Assistant)
|
||||
.expect("assistant message should still exist");
|
||||
|
||||
assert_eq!(
|
||||
assistant_message.content.len(),
|
||||
1,
|
||||
"Only the text content should remain; unsigned thinking block should be stripped"
|
||||
);
|
||||
assert!(matches!(
|
||||
&assistant_message.content[0],
|
||||
anthropic::RequestContent::Text { text, .. } if text == "Some response text"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signed_thinking_blocks_preserved() {
|
||||
let result = request_with_assistant_content(vec![
|
||||
MessageContent::Thinking {
|
||||
text: "Completed thinking".to_string(),
|
||||
signature: Some("valid-signature".to_string()),
|
||||
},
|
||||
MessageContent::Text("Response".to_string()),
|
||||
]);
|
||||
|
||||
let assistant_message = result
|
||||
.messages
|
||||
.iter()
|
||||
.find(|m| m.role == anthropic::Role::Assistant)
|
||||
.expect("assistant message should exist");
|
||||
|
||||
assert_eq!(
|
||||
assistant_message.content.len(),
|
||||
2,
|
||||
"Both the signed thinking block and text should be preserved"
|
||||
);
|
||||
assert!(matches!(
|
||||
&assistant_message.content[0],
|
||||
anthropic::RequestContent::Thinking { thinking, signature, .. }
|
||||
if thinking == "Completed thinking" && signature == "valid-signature"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_only_unsigned_thinking_block_omits_entire_message() {
|
||||
let result = request_with_assistant_content(vec![MessageContent::Thinking {
|
||||
text: "Cancelled before any text or signature".to_string(),
|
||||
signature: None,
|
||||
}]);
|
||||
|
||||
let assistant_messages: Vec<_> = result
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|m| m.role == anthropic::Role::Assistant)
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
assistant_messages.len(),
|
||||
0,
|
||||
"An assistant message whose only content was an unsigned thinking block \
|
||||
should be omitted entirely"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ use ui_input::InputField;
|
|||
use util::ResultExt;
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
|
||||
use language_model::util::{fix_streamed_json, parse_tool_arguments};
|
||||
|
||||
actions!(bedrock, [Tab, TabPrev]);
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -32,7 +32,7 @@ use ui::prelude::*;
|
|||
use util::debug_panic;
|
||||
|
||||
use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic};
|
||||
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
|
||||
use language_model::util::{fix_streamed_json, parse_tool_arguments};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
|
||||
const PROVIDER_NAME: LanguageModelProviderName =
|
||||
|
|
@ -268,15 +268,15 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||
levels
|
||||
.iter()
|
||||
.map(|level| {
|
||||
let name: SharedString = match level.as_str() {
|
||||
let name = match level.as_str() {
|
||||
"low" => "Low".into(),
|
||||
"medium" => "Medium".into(),
|
||||
"high" => "High".into(),
|
||||
_ => SharedString::from(level.clone()),
|
||||
_ => language_model::SharedString::from(level.clone()),
|
||||
};
|
||||
LanguageModelEffortLevel {
|
||||
name,
|
||||
value: SharedString::from(level.clone()),
|
||||
value: language_model::SharedString::from(level.clone()),
|
||||
is_default: level == "high",
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
|||
use ui_input::InputField;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
|
||||
use language_model::util::{fix_streamed_json, parse_tool_arguments};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
|
||||
|
|
|
|||
|
|
@ -1,32 +1,25 @@
|
|||
use anyhow::{Context as _, Result};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
||||
use google_ai::{
|
||||
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
|
||||
ThinkingConfig, UsageMetadata,
|
||||
};
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google};
|
||||
use google_ai::{GenerateContentResponse, GoogleModelMode};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
|
||||
LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
|
||||
};
|
||||
use language_model::{
|
||||
GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use settings::GoogleAvailableModel as AvailableModel;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::sync::{
|
||||
Arc, LazyLock,
|
||||
atomic::{self, AtomicU64},
|
||||
};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
||||
use ui_input::InputField;
|
||||
|
|
@ -394,369 +387,6 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn into_google(
|
||||
mut request: LanguageModelRequest,
|
||||
model_id: String,
|
||||
mode: GoogleModelMode,
|
||||
) -> google_ai::GenerateContentRequest {
|
||||
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
|
||||
content
|
||||
.into_iter()
|
||||
.flat_map(|content| match content {
|
||||
language_model::MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
vec![Part::TextPart(google_ai::TextPart { text })]
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
language_model::MessageContent::Thinking {
|
||||
text: _,
|
||||
signature: Some(signature),
|
||||
} => {
|
||||
if !signature.is_empty() {
|
||||
vec![Part::ThoughtPart(google_ai::ThoughtPart {
|
||||
thought: true,
|
||||
thought_signature: signature,
|
||||
})]
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
language_model::MessageContent::Thinking { .. } => {
|
||||
vec![]
|
||||
}
|
||||
language_model::MessageContent::RedactedThinking(_) => vec![],
|
||||
language_model::MessageContent::Image(image) => {
|
||||
vec![Part::InlineDataPart(google_ai::InlineDataPart {
|
||||
inline_data: google_ai::GenerativeContentBlob {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
})]
|
||||
}
|
||||
language_model::MessageContent::ToolUse(tool_use) => {
|
||||
// Normalize empty string signatures to None
|
||||
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
|
||||
|
||||
vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
|
||||
function_call: google_ai::FunctionCall {
|
||||
name: tool_use.name.to_string(),
|
||||
args: tool_use.input,
|
||||
},
|
||||
thought_signature,
|
||||
})]
|
||||
}
|
||||
language_model::MessageContent::ToolResult(tool_result) => {
|
||||
match tool_result.content {
|
||||
language_model::LanguageModelToolResultContent::Text(text) => {
|
||||
vec![Part::FunctionResponsePart(
|
||||
google_ai::FunctionResponsePart {
|
||||
function_response: google_ai::FunctionResponse {
|
||||
name: tool_result.tool_name.to_string(),
|
||||
// The API expects a valid JSON object
|
||||
response: serde_json::json!({
|
||||
"output": text
|
||||
}),
|
||||
},
|
||||
},
|
||||
)]
|
||||
}
|
||||
language_model::LanguageModelToolResultContent::Image(image) => {
|
||||
vec![
|
||||
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
|
||||
function_response: google_ai::FunctionResponse {
|
||||
name: tool_result.tool_name.to_string(),
|
||||
// The API expects a valid JSON object
|
||||
response: serde_json::json!({
|
||||
"output": "Tool responded with an image"
|
||||
}),
|
||||
},
|
||||
}),
|
||||
Part::InlineDataPart(google_ai::InlineDataPart {
|
||||
inline_data: google_ai::GenerativeContentBlob {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
}),
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
let system_instructions = if request
|
||||
.messages
|
||||
.first()
|
||||
.is_some_and(|msg| matches!(msg.role, Role::System))
|
||||
{
|
||||
let message = request.messages.remove(0);
|
||||
Some(SystemInstruction {
|
||||
parts: map_content(message.content),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
google_ai::GenerateContentRequest {
|
||||
model: google_ai::ModelName { model_id },
|
||||
system_instruction: system_instructions,
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.filter_map(|message| {
|
||||
let parts = map_content(message.content);
|
||||
if parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(google_ai::Content {
|
||||
parts,
|
||||
role: match message.role {
|
||||
Role::User => google_ai::Role::User,
|
||||
Role::Assistant => google_ai::Role::Model,
|
||||
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
generation_config: Some(google_ai::GenerationConfig {
|
||||
candidate_count: Some(1),
|
||||
stop_sequences: Some(request.stop),
|
||||
max_output_tokens: None,
|
||||
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
|
||||
thinking_config: match (request.thinking_allowed, mode) {
|
||||
(true, GoogleModelMode::Thinking { budget_tokens }) => {
|
||||
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
}),
|
||||
safety_settings: None,
|
||||
tools: (!request.tools.is_empty()).then(|| {
|
||||
vec![google_ai::Tool {
|
||||
function_declarations: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| FunctionDeclaration {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.input_schema,
|
||||
})
|
||||
.collect(),
|
||||
}]
|
||||
}),
|
||||
tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
|
||||
function_calling_config: google_ai::FunctionCallingConfig {
|
||||
mode: match choice {
|
||||
LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
|
||||
LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
|
||||
LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
|
||||
},
|
||||
allowed_function_names: None,
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GoogleEventMapper {
|
||||
usage: UsageMetadata,
|
||||
stop_reason: StopReason,
|
||||
}
|
||||
|
||||
impl GoogleEventMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
usage: UsageMetadata::default(),
|
||||
stop_reason: StopReason::EndTurn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events
|
||||
.map(Some)
|
||||
.chain(futures::stream::once(async { None }))
|
||||
.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Some(Ok(event)) => self.map_event(event),
|
||||
Some(Err(error)) => {
|
||||
vec![Err(LanguageModelCompletionError::from(error))]
|
||||
}
|
||||
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: GenerateContentResponse,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
let mut events: Vec<_> = Vec::new();
|
||||
let mut wants_to_use_tool = false;
|
||||
if let Some(usage_metadata) = event.usage_metadata {
|
||||
update_usage(&mut self.usage, &usage_metadata);
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||
convert_usage(&self.usage),
|
||||
)))
|
||||
}
|
||||
|
||||
if let Some(prompt_feedback) = event.prompt_feedback
|
||||
&& let Some(block_reason) = prompt_feedback.block_reason.as_deref()
|
||||
{
|
||||
self.stop_reason = match block_reason {
|
||||
"SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
|
||||
StopReason::Refusal
|
||||
}
|
||||
_ => {
|
||||
log::error!("Unexpected Google block_reason: {block_reason}");
|
||||
StopReason::Refusal
|
||||
}
|
||||
};
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
if let Some(candidates) = event.candidates {
|
||||
for candidate in candidates {
|
||||
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
|
||||
self.stop_reason = match finish_reason {
|
||||
"STOP" => StopReason::EndTurn,
|
||||
"MAX_TOKENS" => StopReason::MaxTokens,
|
||||
_ => {
|
||||
log::error!("Unexpected google finish_reason: {finish_reason}");
|
||||
StopReason::EndTurn
|
||||
}
|
||||
};
|
||||
}
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.for_each(|part| match part {
|
||||
Part::TextPart(text_part) => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
|
||||
}
|
||||
Part::InlineDataPart(_) => {}
|
||||
Part::FunctionCallPart(function_call_part) => {
|
||||
wants_to_use_tool = true;
|
||||
let name: Arc<str> = function_call_part.function_call.name.into();
|
||||
let next_tool_id =
|
||||
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
let id: LanguageModelToolUseId =
|
||||
format!("{}-{}", name, next_tool_id).into();
|
||||
|
||||
// Normalize empty string signatures to None
|
||||
let thought_signature = function_call_part
|
||||
.thought_signature
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id,
|
||||
name,
|
||||
is_input_complete: true,
|
||||
raw_input: function_call_part.function_call.args.to_string(),
|
||||
input: function_call_part.function_call.args,
|
||||
thought_signature,
|
||||
},
|
||||
)));
|
||||
}
|
||||
Part::FunctionResponsePart(_) => {}
|
||||
Part::ThoughtPart(part) => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
|
||||
signature: Some(part.thought_signature),
|
||||
}));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Even when Gemini wants to use a Tool, the API
|
||||
// responds with `finish_reason: STOP`
|
||||
if wants_to_use_tool {
|
||||
self.stop_reason = StopReason::ToolUse;
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
|
||||
}
|
||||
events
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_google_tokens(
|
||||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
// We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
|
||||
// So we have to use tokenizer from tiktoken_rs to count tokens.
|
||||
cx.background_spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
|
||||
if let Some(prompt_token_count) = new.prompt_token_count {
|
||||
usage.prompt_token_count = Some(prompt_token_count);
|
||||
}
|
||||
if let Some(cached_content_token_count) = new.cached_content_token_count {
|
||||
usage.cached_content_token_count = Some(cached_content_token_count);
|
||||
}
|
||||
if let Some(candidates_token_count) = new.candidates_token_count {
|
||||
usage.candidates_token_count = Some(candidates_token_count);
|
||||
}
|
||||
if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
|
||||
usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
|
||||
}
|
||||
if let Some(thoughts_token_count) = new.thoughts_token_count {
|
||||
usage.thoughts_token_count = Some(thoughts_token_count);
|
||||
}
|
||||
if let Some(total_token_count) = new.total_token_count {
|
||||
usage.total_token_count = Some(total_token_count);
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
|
||||
let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
|
||||
let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
|
||||
let input_tokens = prompt_tokens - cached_tokens;
|
||||
let output_tokens = usage.candidates_token_count.unwrap_or(0);
|
||||
|
||||
language_model::TokenUsage {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_input_tokens: cached_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<InputField>,
|
||||
state: Entity<State>,
|
||||
|
|
@ -895,428 +525,3 @@ impl Render for ConfigurationView {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Tests for how thought signatures travel between Google responses
    //! (`GoogleEventMapper`) and Google requests (`into_google`).

    use super::*;
    use google_ai::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole, TextPart,
    };
    use language_model::{LanguageModelToolUseId, MessageContent, Role};
    use serde_json::json;

    /// Builds a function-call part with the given name, arguments and optional thought signature.
    fn function_call_part(
        name: &str,
        args: serde_json::Value,
        thought_signature: Option<&str>,
    ) -> Part {
        Part::FunctionCallPart(FunctionCallPart {
            function_call: FunctionCall {
                name: name.to_string(),
                args,
            },
            thought_signature: thought_signature.map(str::to_string),
        })
    }

    /// Wraps `parts` in a response with a single model candidate and every optional field unset.
    fn model_response(parts: Vec<Part>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts,
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    /// A response containing one `test_function` call carrying `thought_signature`.
    fn test_function_response(thought_signature: Option<&str>) -> GenerateContentResponse {
        model_response(vec![function_call_part(
            "test_function",
            json!({"arg": "value"}),
            thought_signature,
        )])
    }

    /// A completed `test_function` tool use carrying `thought_signature`.
    fn sample_tool_use(thought_signature: Option<&str>) -> language_model::LanguageModelToolUse {
        language_model::LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_function".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: thought_signature.map(str::to_string),
        }
    }

    /// Sends `tool_use` through `into_google` as the only content of a single
    /// assistant message and returns the parts of the first resulting content.
    fn assistant_tool_use_parts(tool_use: language_model::LanguageModelToolUse) -> Vec<Part> {
        let mut request = super::into_google(
            LanguageModelRequest {
                messages: vec![language_model::LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::ToolUse(tool_use)],
                    cache: false,
                    reasoning_details: None,
                }],
                ..Default::default()
            },
            "gemini-2.5-flash".to_string(),
            GoogleModelMode::Default,
        );
        std::mem::take(&mut request.contents[0].parts)
    }

    /// Returns the tool use carried by `event`, or `None` for any other event or an error.
    fn as_tool_use<E>(
        event: &Result<LanguageModelCompletionEvent, E>,
    ) -> Option<&language_model::LanguageModelToolUse> {
        match event {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => Some(tool_use),
            _ => None,
        }
    }

    /// Returns the function-call payload of `part`, or `None` for any other kind of part.
    fn as_function_call_part(part: &Part) -> Option<&FunctionCallPart> {
        match part {
            Part::FunctionCallPart(fc_part) => Some(fc_part),
            _ => None,
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(test_function_response(Some("test_signature_123")));

        assert_eq!(events.len(), 2); // ToolUse event + Stop event

        let tool_use = as_tool_use(&events[0]).expect("Expected ToolUse event");
        assert_eq!(tool_use.name.as_ref(), "test_function");
        assert_eq!(
            tool_use.thought_signature.as_deref(),
            Some("test_signature_123")
        );
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(test_function_response(None));

        let tool_use = as_tool_use(&events[0]).expect("Expected ToolUse event");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(test_function_response(Some("")));

        let tool_use = as_tool_use(&events[0]).expect("Expected ToolUse event");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_parallel_function_calls_preserve_signatures() {
        let mut mapper = GoogleEventMapper::new();

        let response = model_response(vec![
            function_call_part("function_1", json!({"arg": "value1"}), Some("signature_1")),
            function_call_part("function_2", json!({"arg": "value2"}), None),
        ]);

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event

        let first = as_tool_use(&events[0]).expect("Expected ToolUse event for function_1");
        assert_eq!(first.name.as_ref(), "function_1");
        assert_eq!(first.thought_signature.as_deref(), Some("signature_1"));

        let second = as_tool_use(&events[1]).expect("Expected ToolUse event for function_2");
        assert_eq!(second.name.as_ref(), "function_2");
        assert_eq!(second.thought_signature, None);
    }

    #[test]
    fn test_tool_use_with_signature_converts_to_function_call_part() {
        let parts = assistant_tool_use_parts(sample_tool_use(Some("test_signature_456")));

        assert_eq!(parts.len(), 1);
        let fc_part = as_function_call_part(&parts[0]).expect("Expected FunctionCallPart");
        assert_eq!(fc_part.function_call.name, "test_function");
        assert_eq!(
            fc_part.thought_signature.as_deref(),
            Some("test_signature_456")
        );
    }

    #[test]
    fn test_tool_use_without_signature_omits_field() {
        let parts = assistant_tool_use_parts(sample_tool_use(None));

        assert_eq!(parts.len(), 1);
        let fc_part = as_function_call_part(&parts[0]).expect("Expected FunctionCallPart");
        assert_eq!(fc_part.thought_signature, None);
    }

    #[test]
    fn test_empty_signature_in_tool_use_normalized_to_none() {
        let parts = assistant_tool_use_parts(sample_tool_use(Some("")));

        let fc_part = as_function_call_part(&parts[0]).expect("Expected FunctionCallPart");
        assert_eq!(fc_part.thought_signature, None);
    }

    #[test]
    fn test_round_trip_preserves_signature() {
        let mut mapper = GoogleEventMapper::new();

        // Simulate receiving a response from Google with a signature
        let events = mapper.map_event(test_function_response(Some("round_trip_sig")));

        let tool_use = as_tool_use(&events[0])
            .expect("Expected ToolUse event")
            .clone();

        // Convert back to Google format
        let parts = assistant_tool_use_parts(tool_use);

        // Verify signature is preserved
        let fc_part = as_function_call_part(&parts[0]).expect("Expected FunctionCallPart");
        assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
    }

    #[test]
    fn test_mixed_text_and_function_call_with_signature() {
        let mut mapper = GoogleEventMapper::new();

        let response = model_response(vec![
            Part::TextPart(TextPart {
                text: "I'll help with that.".to_string(),
            }),
            function_call_part("helper_function", json!({"query": "help"}), Some("mixed_sig")),
        ]);

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event

        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
            assert_eq!(text, "I'll help with that.");
        } else {
            panic!("Expected Text event");
        }

        let tool_use = as_tool_use(&events[1]).expect("Expected ToolUse event");
        assert_eq!(tool_use.name.as_ref(), "helper_function");
        assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
    }

    #[test]
    fn test_special_characters_in_signature_preserved() {
        let mut mapper = GoogleEventMapper::new();

        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]";

        let events = mapper.map_event(test_function_response(Some(signature_with_special_chars)));

        let tool_use = as_tool_use(&events[0]).expect("Expected ToolUse event");
        assert_eq!(
            tool_use.thought_signature.as_deref(),
            Some(signature_with_special_chars)
        );
    }
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ use ui::{
|
|||
use ui_input::InputField;
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::provider::util::parse_tool_arguments;
|
||||
use language_model::util::parse_tool_arguments;
|
||||
|
||||
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
|
||||
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
|||
use ui_input::InputField;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
|
||||
use language_model::util::{fix_streamed_json, parse_tool_arguments};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -402,7 +402,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
|
|||
self.model.capabilities.parallel_tool_calls,
|
||||
self.model.capabilities.prompt_cache_key,
|
||||
self.max_output_tokens(),
|
||||
self.model.reasoning_effort.clone(),
|
||||
self.model.reasoning_effort,
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
|
|
@ -417,7 +417,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
|
|||
self.model.capabilities.parallel_tool_calls,
|
||||
self.model.capabilities.prompt_cache_key,
|
||||
self.max_output_tokens(),
|
||||
self.model.reasoning_effort.clone(),
|
||||
self.model.reasoning_effort,
|
||||
);
|
||||
let completions = self.stream_response(request, cx);
|
||||
async move {
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
|||
use ui_input::InputField;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
|
||||
use language_model::util::{fix_streamed_json, parse_tool_arguments};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use language_model::{
|
|||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
|
||||
Role, env_var,
|
||||
env_var,
|
||||
};
|
||||
use open_ai::ResponseStreamEvent;
|
||||
pub use settings::XaiAvailableModel as AvailableModel;
|
||||
|
|
@ -19,7 +19,8 @@ use strum::IntoEnumIterator;
|
|||
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
||||
use ui_input::InputField;
|
||||
use util::ResultExt;
|
||||
use x_ai::{Model, XAI_API_URL};
|
||||
use x_ai::XAI_API_URL;
|
||||
pub use x_ai::completion::count_xai_tokens;
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
|
||||
|
|
@ -320,7 +321,9 @@ impl LanguageModel for XAiLanguageModel {
|
|||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
count_xai_tokens(request, self.model.clone(), cx)
|
||||
let model = self.model.clone();
|
||||
cx.background_spawn(async move { count_xai_tokens(request, model) })
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
|
|
@ -354,37 +357,6 @@ impl LanguageModel for XAiLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn count_xai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
model: Model,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
cx.background_spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let model_name = if model.max_token_count() >= 100_000 {
|
||||
"gpt-4o"
|
||||
} else {
|
||||
"gpt-4"
|
||||
};
|
||||
tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<InputField>,
|
||||
state: Entity<State>,
|
||||
|
|
|
|||
33
crates/language_models_cloud/Cargo.toml
Normal file
33
crates/language_models_cloud/Cargo.toml
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
[package]
|
||||
name = "language_models_cloud"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/language_models_cloud.rs"
|
||||
|
||||
[dependencies]
|
||||
anthropic = { workspace = true, features = ["schemars"] }
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
futures.workspace = true
|
||||
google_ai = { workspace = true, features = ["schemars"] }
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model.workspace = true
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
schemars.workspace = true
|
||||
semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
x_ai = { workspace = true, features = ["schemars"] }
|
||||
|
||||
[dev-dependencies]
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
1
crates/language_models_cloud/LICENSE-GPL
Symbolic link
1
crates/language_models_cloud/LICENSE-GPL
Symbolic link
|
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
||||
1059
crates/language_models_cloud/src/language_models_cloud.rs
Normal file
1059
crates/language_models_cloud/src/language_models_cloud.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -17,13 +17,18 @@ schemars = ["dep:schemars"]
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model_core.workspace = true
|
||||
rand.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
log.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions.workspace = true
|
||||
|
|
|
|||
1693
crates/open_ai/src/completion.rs
Normal file
1693
crates/open_ai/src/completion.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,5 @@
|
|||
pub mod batches;
|
||||
pub mod completion;
|
||||
pub mod responses;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
|
|
@ -7,9 +8,9 @@ use http_client::{
|
|||
AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
|
||||
http::{HeaderMap, HeaderValue},
|
||||
};
|
||||
pub use language_model_core::ReasoningEffort;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
pub use settings::OpenAiReasoningEffort as ReasoningEffort;
|
||||
use std::{convert::TryFrom, future::Future};
|
||||
use strum::EnumIter;
|
||||
use thiserror::Error;
|
||||
|
|
@ -717,3 +718,26 @@ pub fn embed<'a>(
|
|||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
// -- Conversions to `language_model_core` types --
|
||||
|
||||
impl From<RequestError> for language_model_core::LanguageModelCompletionError {
|
||||
fn from(error: RequestError) -> Self {
|
||||
match error {
|
||||
RequestError::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
body,
|
||||
headers,
|
||||
} => {
|
||||
let retry_after = headers
|
||||
.get(http_client::http::header::RETRY_AFTER)
|
||||
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
|
||||
.map(std::time::Duration::from_secs);
|
||||
|
||||
Self::from_http_status(provider.into(), status_code, body, retry_after)
|
||||
}
|
||||
RequestError::Other(e) => Self::Other(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ schemars = ["dep:schemars"]
|
|||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model_core.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
|
|
|||
|
|
@ -744,3 +744,71 @@ impl ApiErrorCode {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- Conversions to `language_model_core` types --
|
||||
|
||||
impl From<OpenRouterError> for language_model_core::LanguageModelCompletionError {
|
||||
fn from(error: OpenRouterError) -> Self {
|
||||
let provider = language_model_core::LanguageModelProviderName::new("OpenRouter");
|
||||
match error {
|
||||
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
OpenRouterError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
OpenRouterError::ApiError(api_error) => api_error.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ApiError> for language_model_core::LanguageModelCompletionError {
|
||||
fn from(error: ApiError) -> Self {
|
||||
use ApiErrorCode::*;
|
||||
let provider = language_model_core::LanguageModelProviderName::new("OpenRouter");
|
||||
match error.code {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PaymentRequiredError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: format!("Payment required: {}", error.message),
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
RequestTimedOut => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code: http_client::StatusCode::REQUEST_TIMEOUT,
|
||||
message: error.message,
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -412,7 +412,7 @@ impl PrettierStore {
|
|||
prettier_store
|
||||
.update(cx, |prettier_store, cx| {
|
||||
let name = if is_default {
|
||||
LanguageServerName("prettier (default)".to_string().into())
|
||||
LanguageServerName("prettier (default)".into())
|
||||
} else {
|
||||
let worktree_path = worktree_id
|
||||
.and_then(|id| {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ anyhow.workspace = true
|
|||
collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model_core.workspace = true
|
||||
log.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
use crate::merge_from::MergeFrom;
|
||||
use collections::HashMap;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings_macros::{MergeFrom, with_fallible_options};
|
||||
use strum::EnumString;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
@ -237,15 +237,12 @@ pub struct OpenAiAvailableModel {
|
|||
pub capabilities: OpenAiModelCapabilities,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, EnumString, JsonSchema, MergeFrom)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
pub enum OpenAiReasoningEffort {
|
||||
Minimal,
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
XHigh,
|
||||
pub use language_model_core::ReasoningEffort as OpenAiReasoningEffort;
|
||||
|
||||
impl MergeFrom for OpenAiReasoningEffort {
    /// Merging a reasoning effort replaces the current value with `other`.
    fn merge_from(&mut self, other: &Self) {
        let incoming = *other;
        *self = incoming;
    }
}
|
||||
|
||||
#[with_fallible_options]
|
||||
|
|
@ -479,15 +476,10 @@ pub struct LanguageModelCacheConfiguration {
|
|||
pub min_total_token: u64,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, MergeFrom,
|
||||
)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ModelMode {
|
||||
#[default]
|
||||
Default,
|
||||
Thinking {
|
||||
/// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
|
||||
budget_tokens: Option<u32>,
|
||||
},
|
||||
pub use language_model_core::ModelMode;
|
||||
|
||||
impl MergeFrom for ModelMode {
    /// Merging a model mode replaces the current value with `other`.
    fn merge_from(&mut self, other: &Self) {
        let incoming = *other;
        *self = incoming;
    }
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
futures.workspace = true
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ use std::sync::Arc;
|
|||
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
|
||||
use cloud_api_client::LlmApiToken;
|
||||
use cloud_api_types::OrganizationId;
|
||||
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use language_model::LlmApiToken;
|
||||
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||
|
||||
pub struct CloudWebSearchProvider {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ schemars = ["dep:schemars"]
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
language_model_core.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
strum.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
|
|
|
|||
30
crates/x_ai/src/completion.rs
Normal file
30
crates/x_ai/src/completion.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
use anyhow::Result;
|
||||
use language_model_core::{LanguageModelRequest, Role};
|
||||
|
||||
use crate::Model;
|
||||
|
||||
/// Count tokens for an xAI model using tiktoken. This is synchronous;
|
||||
/// callers should spawn it on a background thread if needed.
|
||||
pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let model_name = if model.max_token_count() >= 100_000 {
|
||||
"gpt-4o"
|
||||
} else {
|
||||
"gpt-4"
|
||||
};
|
||||
tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
|
||||
}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
pub mod completion;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::EnumIter;
|
||||
|
|
|
|||
Loading…
Reference in a new issue