mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
bedrock: Model streamlining and cleanup (#49287)
Release Notes: - Improved Bedrock error messages: region-locked models ask the user to try a different region, rate limits and access errors are reported cleanly instead of as raw API responses - Streamlined Bedrock model list to 39 curated models - Fixed API errors when using non-tool models in agent threads --------- Co-authored-by: Ona <no-reply@ona.com>
This commit is contained in:
parent
891f432f66
commit
6f8023530c
3 changed files with 716 additions and 715 deletions
|
|
@ -1,6 +1,6 @@
|
|||
mod models;
|
||||
|
||||
use anyhow::{Context, Error, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use aws_sdk_bedrockruntime as bedrock;
|
||||
pub use aws_sdk_bedrockruntime as bedrock_client;
|
||||
use aws_sdk_bedrockruntime::types::InferenceConfiguration;
|
||||
|
|
@ -37,7 +37,7 @@ pub const CONTEXT_1M_BETA_HEADER: &str = "context-1m-2025-08-07";
|
|||
pub async fn stream_completion(
|
||||
client: bedrock::Client,
|
||||
request: Request,
|
||||
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
|
||||
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, anyhow::Error>>, BedrockError> {
|
||||
let mut response = bedrock::Client::converse_stream(&client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into());
|
||||
|
|
@ -94,10 +94,30 @@ pub async fn stream_completion(
|
|||
}
|
||||
}
|
||||
|
||||
let output = response
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send API request to Bedrock");
|
||||
let output = response.send().await.map_err(|err| match err {
|
||||
bedrock::error::SdkError::ServiceError(ctx) => {
|
||||
use bedrock::operation::converse_stream::ConverseStreamError;
|
||||
let err = ctx.into_err();
|
||||
match &err {
|
||||
ConverseStreamError::ValidationException(e) => {
|
||||
BedrockError::Validation(e.message().unwrap_or("validation error").to_string())
|
||||
}
|
||||
ConverseStreamError::ThrottlingException(_) => BedrockError::RateLimited,
|
||||
ConverseStreamError::ServiceUnavailableException(_)
|
||||
| ConverseStreamError::ModelNotReadyException(_) => {
|
||||
BedrockError::ServiceUnavailable
|
||||
}
|
||||
ConverseStreamError::AccessDeniedException(e) => {
|
||||
BedrockError::AccessDenied(e.message().unwrap_or("access denied").to_string())
|
||||
}
|
||||
ConverseStreamError::InternalServerException(e) => BedrockError::InternalServer(
|
||||
e.message().unwrap_or("internal server error").to_string(),
|
||||
),
|
||||
_ => BedrockError::Other(err.into()),
|
||||
}
|
||||
}
|
||||
other => BedrockError::Other(other.into()),
|
||||
});
|
||||
|
||||
let stream = Box::pin(stream::unfold(
|
||||
output?.stream,
|
||||
|
|
@ -106,10 +126,10 @@ pub async fn stream_completion(
|
|||
Ok(Some(output)) => Some((Ok(output), stream)),
|
||||
Ok(None) => None,
|
||||
Err(err) => Some((
|
||||
Err(BedrockError::ClientError(anyhow!(
|
||||
Err(anyhow!(
|
||||
"{}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
))),
|
||||
)),
|
||||
stream,
|
||||
)),
|
||||
}
|
||||
|
|
@ -196,10 +216,16 @@ pub struct Metadata {
|
|||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum BedrockError {
|
||||
#[error("client error: {0}")]
|
||||
ClientError(anyhow::Error),
|
||||
#[error("extension error: {0}")]
|
||||
ExtensionError(anyhow::Error),
|
||||
#[error("{0}")]
|
||||
Validation(String),
|
||||
#[error("rate limited")]
|
||||
RateLimited,
|
||||
#[error("service unavailable")]
|
||||
ServiceUnavailable,
|
||||
#[error("{0}")]
|
||||
AccessDenied(String),
|
||||
#[error("{0}")]
|
||||
InternalServer(String),
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -600,18 +600,19 @@ impl BedrockModel {
|
|||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
|
||||
Result<BoxStream<'static, Result<BedrockStreamingResponse, anyhow::Error>>, BedrockError>,
|
||||
> {
|
||||
let Ok(runtime_client) = self
|
||||
.get_or_init_client(cx)
|
||||
.cloned()
|
||||
.context("Bedrock client not initialized")
|
||||
else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return futures::future::ready(Err(BedrockError::Other(anyhow!("App state dropped"))))
|
||||
.boxed();
|
||||
};
|
||||
|
||||
let task = Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request));
|
||||
async move { task.await.map_err(|err| anyhow!(err))? }.boxed()
|
||||
async move { task.await.map_err(|e| BedrockError::Other(e.into()))? }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -715,6 +716,7 @@ impl LanguageModel for BedrockModel {
|
|||
self.model.max_output_tokens(),
|
||||
self.model.mode(),
|
||||
self.model.supports_caching(),
|
||||
self.model.supports_tool_use(),
|
||||
use_extended_context,
|
||||
) {
|
||||
Ok(request) => request,
|
||||
|
|
@ -722,8 +724,44 @@ impl LanguageModel for BedrockModel {
|
|||
};
|
||||
|
||||
let request = self.stream_completion(request, cx);
|
||||
let display_name = self.model.display_name().to_string();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||
let response = request.await.map_err(|err| match err {
|
||||
BedrockError::Validation(ref msg) => {
|
||||
if msg.contains("model identifier is invalid") {
|
||||
LanguageModelCompletionError::Other(anyhow!(
|
||||
"{display_name} is not available in {region}. \
|
||||
Try switching to a region where this model is supported."
|
||||
))
|
||||
} else {
|
||||
LanguageModelCompletionError::BadRequestFormat {
|
||||
provider: PROVIDER_NAME,
|
||||
message: msg.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
BedrockError::RateLimited => LanguageModelCompletionError::RateLimitExceeded {
|
||||
provider: PROVIDER_NAME,
|
||||
retry_after: None,
|
||||
},
|
||||
BedrockError::ServiceUnavailable => {
|
||||
LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: PROVIDER_NAME,
|
||||
retry_after: None,
|
||||
}
|
||||
}
|
||||
BedrockError::AccessDenied(msg) => LanguageModelCompletionError::PermissionError {
|
||||
provider: PROVIDER_NAME,
|
||||
message: msg,
|
||||
},
|
||||
BedrockError::InternalServer(msg) => {
|
||||
LanguageModelCompletionError::ApiInternalServerError {
|
||||
provider: PROVIDER_NAME,
|
||||
message: msg,
|
||||
}
|
||||
}
|
||||
other => LanguageModelCompletionError::Other(anyhow!(other)),
|
||||
})?;
|
||||
let events = map_to_language_model_completion_events(response);
|
||||
|
||||
if deny_tool_calls {
|
||||
|
|
@ -771,6 +809,7 @@ pub fn into_bedrock(
|
|||
max_output_tokens: u64,
|
||||
mode: BedrockModelMode,
|
||||
supports_caching: bool,
|
||||
supports_tool_use: bool,
|
||||
allow_extended_context: bool,
|
||||
) -> Result<bedrock::Request> {
|
||||
let mut new_messages: Vec<BedrockMessage> = Vec::new();
|
||||
|
|
@ -965,28 +1004,32 @@ pub fn into_bedrock(
|
|||
}
|
||||
}
|
||||
|
||||
let mut tool_spec: Vec<BedrockTool> = request
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|tool| {
|
||||
Some(BedrockTool::ToolSpec(
|
||||
BedrockToolSpec::builder()
|
||||
.name(tool.name.clone())
|
||||
.description(tool.description.clone())
|
||||
.input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
|
||||
&tool.input_schema,
|
||||
)))
|
||||
.build()
|
||||
.log_err()?,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
let mut tool_spec: Vec<BedrockTool> = if supports_tool_use {
|
||||
request
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|tool| {
|
||||
Some(BedrockTool::ToolSpec(
|
||||
BedrockToolSpec::builder()
|
||||
.name(tool.name.clone())
|
||||
.description(tool.description.clone())
|
||||
.input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
|
||||
&tool.input_schema,
|
||||
)))
|
||||
.build()
|
||||
.log_err()?,
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// Bedrock requires toolConfig when messages contain tool use/result blocks.
|
||||
// If no tools are defined but messages contain tool content (e.g., when
|
||||
// summarising a conversation that used tools), add a dummy tool to satisfy
|
||||
// the API requirement.
|
||||
if tool_spec.is_empty() && messages_contain_tool_content {
|
||||
if supports_tool_use && tool_spec.is_empty() && messages_contain_tool_content {
|
||||
tool_spec.push(BedrockTool::ToolSpec(
|
||||
BedrockToolSpec::builder()
|
||||
.name("_placeholder")
|
||||
|
|
@ -1020,17 +1063,23 @@ pub fn into_bedrock(
|
|||
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
|
||||
}
|
||||
};
|
||||
let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
|
||||
.set_tools(Some(tool_spec))
|
||||
.tool_choice(tool_choice)
|
||||
.build()?;
|
||||
let tool_config = if tool_spec.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
BedrockToolConfig::builder()
|
||||
.set_tools(Some(tool_spec))
|
||||
.tool_choice(tool_choice)
|
||||
.build()?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(bedrock::Request {
|
||||
model,
|
||||
messages: new_messages,
|
||||
max_tokens: max_output_tokens,
|
||||
system: Some(system_message),
|
||||
tools: Some(tool_config),
|
||||
tools: tool_config,
|
||||
thinking: if request.thinking_allowed {
|
||||
match mode {
|
||||
BedrockModelMode::Thinking { budget_tokens } => {
|
||||
|
|
@ -1116,7 +1165,7 @@ pub fn get_bedrock_tokens(
|
|||
}
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, anyhow::Error>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
|
|
@ -1125,13 +1174,15 @@ pub fn map_to_language_model_completion_events(
|
|||
}
|
||||
|
||||
struct State {
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, anyhow::Error>>>>,
|
||||
tool_uses_by_index: HashMap<i32, RawToolUse>,
|
||||
emitted_tool_use: bool,
|
||||
}
|
||||
|
||||
let initial_state = State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
emitted_tool_use: false,
|
||||
};
|
||||
|
||||
futures::stream::unfold(initial_state, |mut state| async move {
|
||||
|
|
@ -1190,10 +1241,13 @@ pub fn map_to_language_model_completion_events(
|
|||
}
|
||||
None
|
||||
}
|
||||
ConverseStreamOutput::MessageStart(_) => None,
|
||||
ConverseStreamOutput::ContentBlockStop(cb_stop) => state
|
||||
.tool_uses_by_index
|
||||
.remove(&cb_stop.content_block_index)
|
||||
.map(|tool_use| {
|
||||
state.emitted_tool_use = true;
|
||||
|
||||
let input = parse_tool_arguments(&tool_use.input_json)
|
||||
.unwrap_or_else(|_| Value::Object(Default::default()));
|
||||
|
||||
|
|
@ -1223,9 +1277,16 @@ pub fn map_to_language_model_completion_events(
|
|||
}))
|
||||
}),
|
||||
ConverseStreamOutput::MessageStop(message_stop) => {
|
||||
let stop_reason = match message_stop.stop_reason {
|
||||
StopReason::ToolUse => language_model::StopReason::ToolUse,
|
||||
_ => language_model::StopReason::EndTurn,
|
||||
let stop_reason = if state.emitted_tool_use {
|
||||
// Some models (e.g. Kimi) send EndTurn even when
|
||||
// they've made tool calls. Trust the content over
|
||||
// the stop reason.
|
||||
language_model::StopReason::ToolUse
|
||||
} else {
|
||||
match message_stop.stop_reason {
|
||||
StopReason::ToolUse => language_model::StopReason::ToolUse,
|
||||
_ => language_model::StopReason::EndTurn,
|
||||
}
|
||||
};
|
||||
Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue