agent: Fix issue with streaming tools when model produces invalid JSON (#52891)

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2026-04-06 22:26:26 +02:00 committed by GitHub
parent d2257dbc39
commit e2bba5526a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 957 additions and 372 deletions

View file

@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
);
});
}
#[gpui::test]
async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes(
cx: &mut TestAppContext,
) {
super::init_test(cx);
super::always_allow_tools(cx);
// Enable the streaming edit file tool feature flag.
cx.update(|cx| {
cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]);
});
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/project"),
json!({
"src": {
"main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
}
}),
)
.await;
let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let project_context = cx.new(|_cx| ProjectContext::default());
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
let context_server_registry =
cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
let model = Arc::new(FakeLanguageModel::default());
model.as_fake().set_supports_streaming_tools(true);
let fake_model = model.as_fake();
let thread = cx.new(|cx| {
let mut thread = crate::Thread::new(
project.clone(),
project_context,
context_server_registry,
crate::Templates::new(),
Some(model.clone()),
cx,
);
let language_registry = project.read(cx).languages().clone();
thread.add_tool(crate::StreamingEditFileTool::new(
project.clone(),
cx.weak_entity(),
thread.action_log().clone(),
language_registry,
));
thread
});
let _events = thread
.update(cx, |thread, cx| {
thread.send(
UserMessageId::new(),
["Write new content to src/main.rs"],
cx,
)
})
.unwrap();
cx.run_until_parked();
let tool_use_id = "edit_1";
let partial_1 = LanguageModelToolUse {
id: tool_use_id.into(),
name: EditFileTool::NAME.into(),
raw_input: json!({
"display_description": "Rewrite main.rs",
"path": "project/src/main.rs",
"mode": "write"
})
.to_string(),
input: json!({
"display_description": "Rewrite main.rs",
"path": "project/src/main.rs",
"mode": "write"
}),
is_input_complete: false,
thought_signature: None,
};
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1));
cx.run_until_parked();
let partial_2 = LanguageModelToolUse {
id: tool_use_id.into(),
name: EditFileTool::NAME.into(),
raw_input: json!({
"display_description": "Rewrite main.rs",
"path": "project/src/main.rs",
"mode": "write",
"content": "fn main() { /* rewritten */ }"
})
.to_string(),
input: json!({
"display_description": "Rewrite main.rs",
"path": "project/src/main.rs",
"mode": "write",
"content": "fn main() { /* rewritten */ }"
}),
is_input_complete: false,
thought_signature: None,
};
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2));
cx.run_until_parked();
// Now send a json parse error. At this point we have started writing content to the buffer.
fake_model.send_last_completion_stream_event(
LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_use_id.into(),
tool_name: EditFileTool::NAME.into(),
raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(),
json_parse_error: "EOF while parsing a string at line 1 column 95".into(),
},
);
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
fake_model.end_last_completion_stream();
cx.run_until_parked();
// cx.executor().advance_clock(Duration::from_secs(5));
// cx.run_until_parked();
assert!(
!fake_model.pending_completions().is_empty(),
"Thread should have retried after the error"
);
// Respond with a new, well-formed, complete edit_file tool use.
let tool_use = LanguageModelToolUse {
id: "edit_2".into(),
name: EditFileTool::NAME.into(),
raw_input: json!({
"display_description": "Rewrite main.rs",
"path": "project/src/main.rs",
"mode": "write",
"content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
})
.to_string(),
input: json!({
"display_description": "Rewrite main.rs",
"path": "project/src/main.rs",
"mode": "write",
"content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
}),
is_input_complete: true,
thought_signature: None,
};
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let pending_completions = fake_model.pending_completions();
assert!(
pending_completions.len() == 1,
"Expected only the follow-up completion containing the successful tool result"
);
let completion = pending_completions
.into_iter()
.last()
.expect("Expected a completion containing the tool result for edit_2");
let tool_result = completion
.messages
.iter()
.flat_map(|msg| &msg.content)
.find_map(|content| match content {
language_model::MessageContent::ToolResult(result)
if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") =>
{
Some(result)
}
_ => None,
})
.expect("Should have a tool result for edit_2");
// Ensure that the second tool call completed successfully and edits were applied.
assert!(
!tool_result.is_error,
"Tool result should succeed, got: {:?}",
tool_result
);
let content_text = match &tool_result.content {
language_model::LanguageModelToolResultContent::Text(t) => t.to_string(),
other => panic!("Expected text content, got: {:?}", other),
};
assert!(
!content_text.contains("file has been modified since you last read it"),
"Did not expect a stale last-read error, got: {content_text}"
);
assert!(
!content_text.contains("This file has unsaved changes"),
"Did not expect an unsaved-changes error, got: {content_text}"
);
let file_content = fs
.load(path!("/project/src/main.rs").as_ref())
.await
.expect("file should exist");
super::assert_eq!(
file_content,
"fn main() {\n println!(\"Hello, rewritten!\");\n}\n",
"The second edit should be applied and saved gracefully"
);
fake_model.end_last_completion_stream();
cx.run_until_parked();
}

View file

@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
});
}
#[gpui::test]
async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool(
cx: &mut TestAppContext,
) {
init_test(cx);
always_allow_tools(cx);
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
thread.update(cx, |thread, _cx| {
thread.add_tool(StreamingJsonErrorContextTool);
});
let _events = thread
.update(cx, |thread, cx| {
thread.send(
UserMessageId::new(),
["Use the streaming_json_error_context tool"],
cx,
)
})
.unwrap();
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_1".into(),
name: StreamingJsonErrorContextTool::NAME.into(),
raw_input: r#"{"text": "partial"#.into(),
input: json!({"text": "partial"}),
is_input_complete: false,
thought_signature: None,
};
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
cx.run_until_parked();
fake_model.send_last_completion_stream_event(
LanguageModelCompletionEvent::ToolUseJsonParseError {
id: "tool_1".into(),
tool_name: StreamingJsonErrorContextTool::NAME.into(),
raw_input: r#"{"text": "partial"#.into(),
json_parse_error: "EOF while parsing a string at line 1 column 17".into(),
},
);
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
fake_model.end_last_completion_stream();
cx.run_until_parked();
cx.executor().advance_clock(Duration::from_secs(5));
cx.run_until_parked();
let completion = fake_model
.pending_completions()
.pop()
.expect("No running turn");
let tool_results: Vec<_> = completion
.messages
.iter()
.flat_map(|message| &message.content)
.filter_map(|content| match content {
MessageContent::ToolResult(result)
if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") =>
{
Some(result)
}
_ => None,
})
.collect();
assert_eq!(
tool_results.len(),
1,
"Expected exactly 1 tool result for tool_1, got {}: {:#?}",
tool_results.len(),
tool_results
);
let result = tool_results[0];
assert!(result.is_error);
let content_text = match &result.content {
language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
other => panic!("Expected text content, got {:?}", other),
};
assert!(
content_text.contains("Saw partial text 'partial' before invalid JSON"),
"Expected tool-enriched partial context, got: {content_text}"
);
assert!(
content_text
.contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"),
"Expected forwarded JSON parse error, got: {content_text}"
);
assert!(
!content_text.contains("tool input was not fully received"),
"Should not contain orphaned sender error, got: {content_text}"
);
fake_model.send_last_completion_stream_text_chunk("Done");
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _cx| {
assert!(
thread.is_turn_complete(),
"Thread should not be stuck; the turn should have completed",
);
});
}
/// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events
@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
InfiniteTool::NAME: true,
CancellationAwareTool::NAME: true,
StreamingEchoTool::NAME: true,
StreamingJsonErrorContextTool::NAME: true,
StreamingFailingEchoTool::NAME: true,
TerminalTool::NAME: true,
UpdatePlanTool::NAME: true,

View file

@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool {
fn run(
self: Arc<Self>,
mut input: ToolInput<Self::Input>,
input: ToolInput<Self::Input>,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String, String>> {
let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
cx.spawn(async move |_cx| {
while input.recv_partial().await.is_some() {}
let input = input
.recv()
.await
@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool {
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct StreamingJsonErrorContextToolInput {
/// The text to echo.
pub text: String,
}
pub struct StreamingJsonErrorContextTool;
impl AgentTool for StreamingJsonErrorContextTool {
type Input = StreamingJsonErrorContextToolInput;
type Output = String;
const NAME: &'static str = "streaming_json_error_context";
fn supports_input_streaming() -> bool {
true
}
fn kind() -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"Streaming JSON Error Context".into()
}
fn run(
self: Arc<Self>,
mut input: ToolInput<Self::Input>,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String, String>> {
cx.spawn(async move |_cx| {
let mut last_partial_text = None;
loop {
match input.next().await {
Ok(ToolInputPayload::Partial(partial)) => {
if let Some(text) = partial.get("text").and_then(|value| value.as_str()) {
last_partial_text = Some(text.to_string());
}
}
Ok(ToolInputPayload::Full(input)) => return Ok(input.text),
Ok(ToolInputPayload::InvalidJson { error_message }) => {
let partial_text = last_partial_text.unwrap_or_default();
return Err(format!(
"Saw partial text '{partial_text}' before invalid JSON: {error_message}"
));
}
Err(error) => {
return Err(format!("Failed to receive tool input: {error}"));
}
}
}
})
}
}
/// A streaming tool that echoes its input, used to test streaming tool
/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
/// before `is_input_complete`).
@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool {
) -> Task<Result<Self::Output, Self::Output>> {
cx.spawn(async move |_cx| {
for _ in 0..self.receive_chunks_until_failure {
let _ = input.recv_partial().await;
let _ = input.next().await;
}
Err("failed".into())
})

View file

@ -22,13 +22,13 @@ use client::UserStore;
use cloud_api_types::Plan;
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::stream;
use futures::{
FutureExt,
channel::{mpsc, oneshot},
future::Shared,
stream::FuturesUnordered,
};
use futures::{StreamExt, stream};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file};
use smol::stream::StreamExt;
use std::{
collections::BTreeMap,
marker::PhantomData,
@ -2095,7 +2094,7 @@ impl Thread {
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
.insert(tool_result.tool_use_id.clone(), tool_result);
.insert(tool_result.tool_use_id.clone(), tool_result)
})?;
Ok(())
}
@ -2195,15 +2194,15 @@ impl Thread {
raw_input,
json_parse_error,
} => {
return Ok(Some(Task::ready(
self.handle_tool_use_json_parse_error_event(
id,
tool_name,
raw_input,
json_parse_error,
event_stream,
),
)));
return Ok(self.handle_tool_use_json_parse_error_event(
id,
tool_name,
raw_input,
json_parse_error,
event_stream,
cancellation_rx,
cx,
));
}
UsageUpdate(usage) => {
telemetry::event!(
@ -2304,12 +2303,12 @@ impl Thread {
if !tool_use.is_input_complete {
if tool.supports_input_streaming() {
let running_turn = self.running_turn.as_mut()?;
if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) {
if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) {
sender.send_partial(tool_use.input);
return None;
}
let (sender, tool_input) = ToolInputSender::channel();
let (mut sender, tool_input) = ToolInputSender::channel();
sender.send_partial(tool_use.input);
running_turn
.streaming_tool_inputs
@ -2331,13 +2330,13 @@ impl Thread {
}
}
if let Some(sender) = self
if let Some(mut sender) = self
.running_turn
.as_mut()?
.streaming_tool_inputs
.remove(&tool_use.id)
{
sender.send_final(tool_use.input);
sender.send_full(tool_use.input);
return None;
}
@ -2410,10 +2409,12 @@ impl Thread {
raw_input: Arc<str>,
json_parse_error: String,
event_stream: &ThreadEventStream,
) -> LanguageModelToolResult {
cancellation_rx: watch::Receiver<bool>,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
let tool_use = LanguageModelToolUse {
id: tool_use_id.clone(),
name: tool_name.clone(),
id: tool_use_id,
name: tool_name,
raw_input: raw_input.to_string(),
input: serde_json::json!({}),
is_input_complete: true,
@ -2426,14 +2427,43 @@ impl Thread {
event_stream,
);
let tool_output = format!("Error parsing input JSON: {json_parse_error}");
LanguageModelToolResult {
tool_use_id,
tool_name,
is_error: true,
content: LanguageModelToolResultContent::Text(tool_output.into()),
output: Some(serde_json::Value::String(raw_input.to_string())),
let tool = self.tool(tool_use.name.as_ref());
let Some(tool) = tool else {
let content = format!("No tool named {} exists", tool_use.name);
return Some(Task::ready(LanguageModelToolResult {
content: LanguageModelToolResultContent::Text(Arc::from(content)),
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
output: None,
}));
};
let error_message = format!("Error parsing input JSON: {json_parse_error}");
if tool.supports_input_streaming()
&& let Some(mut sender) = self
.running_turn
.as_mut()?
.streaming_tool_inputs
.remove(&tool_use.id)
{
sender.send_invalid_json(error_message);
return None;
}
log::debug!("Running tool {}. Received invalid JSON", tool_use.name);
let tool_input = ToolInput::invalid_json(error_message);
Some(self.run_tool(
tool,
tool_input,
tool_use.id,
tool_use.name,
event_stream,
cancellation_rx,
cx,
))
}
fn send_or_update_tool_use(
@ -3114,8 +3144,7 @@ impl EventEmitter<TitleUpdated> for Thread {}
/// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams
/// them, followed by the final complete input available through `.recv()`.
pub struct ToolInput<T> {
partial_rx: mpsc::UnboundedReceiver<serde_json::Value>,
final_rx: oneshot::Receiver<serde_json::Value>,
rx: mpsc::UnboundedReceiver<ToolInputPayload<serde_json::Value>>,
_phantom: PhantomData<T>,
}
@ -3127,13 +3156,20 @@ impl<T: DeserializeOwned> ToolInput<T> {
}
pub fn ready(value: serde_json::Value) -> Self {
let (partial_tx, partial_rx) = mpsc::unbounded();
drop(partial_tx);
let (final_tx, final_rx) = oneshot::channel();
final_tx.send(value).ok();
let (tx, rx) = mpsc::unbounded();
tx.unbounded_send(ToolInputPayload::Full(value)).ok();
Self {
partial_rx,
final_rx,
rx,
_phantom: PhantomData,
}
}
pub fn invalid_json(error_message: String) -> Self {
let (tx, rx) = mpsc::unbounded();
tx.unbounded_send(ToolInputPayload::InvalidJson { error_message })
.ok();
Self {
rx,
_phantom: PhantomData,
}
}
@ -3147,65 +3183,89 @@ impl<T: DeserializeOwned> ToolInput<T> {
/// Wait for the final deserialized input, ignoring all partial updates.
/// Non-streaming tools can use this to wait until the whole input is available.
pub async fn recv(mut self) -> Result<T> {
// Drain any remaining partials
while self.partial_rx.next().await.is_some() {}
let value = self
.final_rx
.await
.map_err(|_| anyhow!("tool input was not fully received"))?;
serde_json::from_value(value).map_err(Into::into)
while let Ok(value) = self.next().await {
match value {
ToolInputPayload::Full(value) => return Ok(value),
ToolInputPayload::Partial(_) => {}
ToolInputPayload::InvalidJson { error_message } => {
return Err(anyhow!(error_message));
}
}
}
Err(anyhow!("tool input was not fully received"))
}
/// Returns the next partial JSON snapshot, or `None` when input is complete.
/// Once this returns `None`, call `recv()` to get the final input.
pub async fn recv_partial(&mut self) -> Option<serde_json::Value> {
self.partial_rx.next().await
pub async fn next(&mut self) -> Result<ToolInputPayload<T>> {
let value = self
.rx
.next()
.await
.ok_or_else(|| anyhow!("tool input was not fully received"))?;
Ok(match value {
ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload),
ToolInputPayload::Full(payload) => {
ToolInputPayload::Full(serde_json::from_value(payload)?)
}
ToolInputPayload::InvalidJson { error_message } => {
ToolInputPayload::InvalidJson { error_message }
}
})
}
fn cast<U: DeserializeOwned>(self) -> ToolInput<U> {
ToolInput {
partial_rx: self.partial_rx,
final_rx: self.final_rx,
rx: self.rx,
_phantom: PhantomData,
}
}
}
pub enum ToolInputPayload<T> {
Partial(serde_json::Value),
Full(T),
InvalidJson { error_message: String },
}
pub struct ToolInputSender {
partial_tx: mpsc::UnboundedSender<serde_json::Value>,
final_tx: Option<oneshot::Sender<serde_json::Value>>,
has_received_final: bool,
tx: mpsc::UnboundedSender<ToolInputPayload<serde_json::Value>>,
}
impl ToolInputSender {
pub(crate) fn channel() -> (Self, ToolInput<serde_json::Value>) {
let (partial_tx, partial_rx) = mpsc::unbounded();
let (final_tx, final_rx) = oneshot::channel();
let (tx, rx) = mpsc::unbounded();
let sender = Self {
partial_tx,
final_tx: Some(final_tx),
tx,
has_received_final: false,
};
let input = ToolInput {
partial_rx,
final_rx,
rx,
_phantom: PhantomData,
};
(sender, input)
}
pub(crate) fn has_received_final(&self) -> bool {
self.final_tx.is_none()
self.has_received_final
}
pub(crate) fn send_partial(&self, value: serde_json::Value) {
self.partial_tx.unbounded_send(value).ok();
pub fn send_partial(&mut self, payload: serde_json::Value) {
self.tx
.unbounded_send(ToolInputPayload::Partial(payload))
.ok();
}
pub(crate) fn send_final(mut self, value: serde_json::Value) {
// Close the partial channel so recv_partial() returns None
self.partial_tx.close_channel();
if let Some(final_tx) = self.final_tx.take() {
final_tx.send(value).ok();
}
pub fn send_full(&mut self, payload: serde_json::Value) {
self.has_received_final = true;
self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok();
}
pub fn send_invalid_json(&mut self, error_message: String) {
self.has_received_final = true;
self.tx
.unbounded_send(ToolInputPayload::InvalidJson { error_message })
.ok();
}
}
@ -4251,68 +4311,78 @@ mod tests {
) {
let (thread, event_stream) = setup_thread_for_test(cx).await;
cx.update(|cx| {
thread.update(cx, |thread, _cx| {
let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
let tool_name: Arc<str> = Arc::from("test_tool");
let raw_input: Arc<str> = Arc::from("{invalid json");
let json_parse_error = "expected value at line 1 column 1".to_string();
let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
let tool_name: Arc<str> = Arc::from("test_tool");
let raw_input: Arc<str> = Arc::from("{invalid json");
let json_parse_error = "expected value at line 1 column 1".to_string();
// Call the function under test
let result = thread.handle_tool_use_json_parse_error_event(
tool_use_id.clone(),
tool_name.clone(),
raw_input.clone(),
json_parse_error,
&event_stream,
);
let (_cancellation_tx, cancellation_rx) = watch::channel(false);
// Verify the result is an error
assert!(result.is_error);
assert_eq!(result.tool_use_id, tool_use_id);
assert_eq!(result.tool_name, tool_name);
assert!(matches!(
result.content,
LanguageModelToolResultContent::Text(_)
));
let result = cx
.update(|cx| {
thread.update(cx, |thread, cx| {
// Call the function under test
thread
.handle_tool_use_json_parse_error_event(
tool_use_id.clone(),
tool_name.clone(),
raw_input.clone(),
json_parse_error,
&event_stream,
cancellation_rx,
cx,
)
.unwrap()
})
})
.await;
// Verify the tool use was added to the message content
{
let last_message = thread.pending_message();
assert_eq!(
last_message.content.len(),
1,
"Should have one tool_use in content"
);
// Verify the result is an error
assert!(result.is_error);
assert_eq!(result.tool_use_id, tool_use_id);
assert_eq!(result.tool_name, tool_name);
assert!(matches!(
result.content,
LanguageModelToolResultContent::Text(_)
));
match &last_message.content[0] {
AgentMessageContent::ToolUse(tool_use) => {
assert_eq!(tool_use.id, tool_use_id);
assert_eq!(tool_use.name, tool_name);
assert_eq!(tool_use.raw_input, raw_input.to_string());
assert!(tool_use.is_input_complete);
// Should fall back to empty object for invalid JSON
assert_eq!(tool_use.input, json!({}));
}
_ => panic!("Expected ToolUse content"),
}
}
// Insert the tool result (simulating what the caller does)
thread
.pending_message()
.tool_results
.insert(result.tool_use_id.clone(), result);
// Verify the tool result was added
thread.update(cx, |thread, _cx| {
// Verify the tool use was added to the message content
{
let last_message = thread.pending_message();
assert_eq!(
last_message.tool_results.len(),
last_message.content.len(),
1,
"Should have one tool_result"
"Should have one tool_use in content"
);
assert!(last_message.tool_results.contains_key(&tool_use_id));
});
});
match &last_message.content[0] {
AgentMessageContent::ToolUse(tool_use) => {
assert_eq!(tool_use.id, tool_use_id);
assert_eq!(tool_use.name, tool_name);
assert_eq!(tool_use.raw_input, raw_input.to_string());
assert!(tool_use.is_input_complete);
// Should fall back to empty object for invalid JSON
assert_eq!(tool_use.input, json!({}));
}
_ => panic!("Expected ToolUse content"),
}
}
// Insert the tool result (simulating what the caller does)
thread
.pending_message()
.tool_results
.insert(result.tool_use_id.clone(), result);
// Verify the tool result was added
let last_message = thread.pending_message();
assert_eq!(
last_message.tool_results.len(),
1,
"Should have one tool_result"
);
assert!(last_message.tool_results.contains_key(&tool_use_id));
})
}
}

File diff suppressed because it is too large Load diff

View file

@ -125,6 +125,7 @@ pub struct FakeLanguageModel {
>,
forbid_requests: AtomicBool,
supports_thinking: AtomicBool,
supports_streaming_tools: AtomicBool,
}
impl Default for FakeLanguageModel {
@ -137,6 +138,7 @@ impl Default for FakeLanguageModel {
current_completion_txs: Mutex::new(Vec::new()),
forbid_requests: AtomicBool::new(false),
supports_thinking: AtomicBool::new(false),
supports_streaming_tools: AtomicBool::new(false),
}
}
}
@ -169,6 +171,10 @@ impl FakeLanguageModel {
self.supports_thinking.store(supports, SeqCst);
}
pub fn set_supports_streaming_tools(&self, supports: bool) {
self.supports_streaming_tools.store(supports, SeqCst);
}
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs
.lock()
@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel {
self.supports_thinking.load(SeqCst)
}
fn supports_streaming_tools(&self) -> bool {
self.supports_streaming_tools.load(SeqCst)
}
fn telemetry_id(&self) -> String {
"fake".to_string()
}