mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
agent: Fix issue with streaming tools when model produces invalid JSON (#52891)
Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A
This commit is contained in:
parent
d2257dbc39
commit
e2bba5526a
6 changed files with 957 additions and 372 deletions
|
|
@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
|
|||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes(
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
super::init_test(cx);
|
||||
super::always_allow_tools(cx);
|
||||
|
||||
// Enable the streaming edit file tool feature flag.
|
||||
cx.update(|cx| {
|
||||
cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/project"),
|
||||
json!({
|
||||
"src": {
|
||||
"main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
model.as_fake().set_supports_streaming_tools(true);
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = crate::Thread::new(
|
||||
project.clone(),
|
||||
project_context,
|
||||
context_server_registry,
|
||||
crate::Templates::new(),
|
||||
Some(model.clone()),
|
||||
cx,
|
||||
);
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
thread.add_tool(crate::StreamingEditFileTool::new(
|
||||
project.clone(),
|
||||
cx.weak_entity(),
|
||||
thread.action_log().clone(),
|
||||
language_registry,
|
||||
));
|
||||
thread
|
||||
});
|
||||
|
||||
let _events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
UserMessageId::new(),
|
||||
["Write new content to src/main.rs"],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let tool_use_id = "edit_1";
|
||||
let partial_1 = LanguageModelToolUse {
|
||||
id: tool_use_id.into(),
|
||||
name: EditFileTool::NAME.into(),
|
||||
raw_input: json!({
|
||||
"display_description": "Rewrite main.rs",
|
||||
"path": "project/src/main.rs",
|
||||
"mode": "write"
|
||||
})
|
||||
.to_string(),
|
||||
input: json!({
|
||||
"display_description": "Rewrite main.rs",
|
||||
"path": "project/src/main.rs",
|
||||
"mode": "write"
|
||||
}),
|
||||
is_input_complete: false,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1));
|
||||
cx.run_until_parked();
|
||||
|
||||
let partial_2 = LanguageModelToolUse {
|
||||
id: tool_use_id.into(),
|
||||
name: EditFileTool::NAME.into(),
|
||||
raw_input: json!({
|
||||
"display_description": "Rewrite main.rs",
|
||||
"path": "project/src/main.rs",
|
||||
"mode": "write",
|
||||
"content": "fn main() { /* rewritten */ }"
|
||||
})
|
||||
.to_string(),
|
||||
input: json!({
|
||||
"display_description": "Rewrite main.rs",
|
||||
"path": "project/src/main.rs",
|
||||
"mode": "write",
|
||||
"content": "fn main() { /* rewritten */ }"
|
||||
}),
|
||||
is_input_complete: false,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2));
|
||||
cx.run_until_parked();
|
||||
|
||||
// Now send a json parse error. At this point we have started writing content to the buffer.
|
||||
fake_model.send_last_completion_stream_event(
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_use_id.into(),
|
||||
tool_name: EditFileTool::NAME.into(),
|
||||
raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(),
|
||||
json_parse_error: "EOF while parsing a string at line 1 column 95".into(),
|
||||
},
|
||||
);
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
// cx.executor().advance_clock(Duration::from_secs(5));
|
||||
// cx.run_until_parked();
|
||||
|
||||
assert!(
|
||||
!fake_model.pending_completions().is_empty(),
|
||||
"Thread should have retried after the error"
|
||||
);
|
||||
|
||||
// Respond with a new, well-formed, complete edit_file tool use.
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: "edit_2".into(),
|
||||
name: EditFileTool::NAME.into(),
|
||||
raw_input: json!({
|
||||
"display_description": "Rewrite main.rs",
|
||||
"path": "project/src/main.rs",
|
||||
"mode": "write",
|
||||
"content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
|
||||
})
|
||||
.to_string(),
|
||||
input: json!({
|
||||
"display_description": "Rewrite main.rs",
|
||||
"path": "project/src/main.rs",
|
||||
"mode": "write",
|
||||
"content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
|
||||
}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
let pending_completions = fake_model.pending_completions();
|
||||
assert!(
|
||||
pending_completions.len() == 1,
|
||||
"Expected only the follow-up completion containing the successful tool result"
|
||||
);
|
||||
|
||||
let completion = pending_completions
|
||||
.into_iter()
|
||||
.last()
|
||||
.expect("Expected a completion containing the tool result for edit_2");
|
||||
|
||||
let tool_result = completion
|
||||
.messages
|
||||
.iter()
|
||||
.flat_map(|msg| &msg.content)
|
||||
.find_map(|content| match content {
|
||||
language_model::MessageContent::ToolResult(result)
|
||||
if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") =>
|
||||
{
|
||||
Some(result)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.expect("Should have a tool result for edit_2");
|
||||
|
||||
// Ensure that the second tool call completed successfully and edits were applied.
|
||||
assert!(
|
||||
!tool_result.is_error,
|
||||
"Tool result should succeed, got: {:?}",
|
||||
tool_result
|
||||
);
|
||||
let content_text = match &tool_result.content {
|
||||
language_model::LanguageModelToolResultContent::Text(t) => t.to_string(),
|
||||
other => panic!("Expected text content, got: {:?}", other),
|
||||
};
|
||||
assert!(
|
||||
!content_text.contains("file has been modified since you last read it"),
|
||||
"Did not expect a stale last-read error, got: {content_text}"
|
||||
);
|
||||
assert!(
|
||||
!content_text.contains("This file has unsaved changes"),
|
||||
"Did not expect an unsaved-changes error, got: {content_text}"
|
||||
);
|
||||
|
||||
let file_content = fs
|
||||
.load(path!("/project/src/main.rs").as_ref())
|
||||
.await
|
||||
.expect("file should exist");
|
||||
super::assert_eq!(
|
||||
file_content,
|
||||
"fn main() {\n println!(\"Hello, rewritten!\");\n}\n",
|
||||
"The second edit should be applied and saved gracefully"
|
||||
);
|
||||
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool(
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
init_test(cx);
|
||||
always_allow_tools(cx);
|
||||
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.add_tool(StreamingJsonErrorContextTool);
|
||||
});
|
||||
|
||||
let _events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
UserMessageId::new(),
|
||||
["Use the streaming_json_error_context tool"],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: "tool_1".into(),
|
||||
name: StreamingJsonErrorContextTool::NAME.into(),
|
||||
raw_input: r#"{"text": "partial"#.into(),
|
||||
input: json!({"text": "partial"}),
|
||||
is_input_complete: false,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_event(
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: "tool_1".into(),
|
||||
tool_name: StreamingJsonErrorContextTool::NAME.into(),
|
||||
raw_input: r#"{"text": "partial"#.into(),
|
||||
json_parse_error: "EOF while parsing a string at line 1 column 17".into(),
|
||||
},
|
||||
);
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
cx.executor().advance_clock(Duration::from_secs(5));
|
||||
cx.run_until_parked();
|
||||
|
||||
let completion = fake_model
|
||||
.pending_completions()
|
||||
.pop()
|
||||
.expect("No running turn");
|
||||
|
||||
let tool_results: Vec<_> = completion
|
||||
.messages
|
||||
.iter()
|
||||
.flat_map(|message| &message.content)
|
||||
.filter_map(|content| match content {
|
||||
MessageContent::ToolResult(result)
|
||||
if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") =>
|
||||
{
|
||||
Some(result)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
tool_results.len(),
|
||||
1,
|
||||
"Expected exactly 1 tool result for tool_1, got {}: {:#?}",
|
||||
tool_results.len(),
|
||||
tool_results
|
||||
);
|
||||
|
||||
let result = tool_results[0];
|
||||
assert!(result.is_error);
|
||||
let content_text = match &result.content {
|
||||
language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
|
||||
other => panic!("Expected text content, got {:?}", other),
|
||||
};
|
||||
assert!(
|
||||
content_text.contains("Saw partial text 'partial' before invalid JSON"),
|
||||
"Expected tool-enriched partial context, got: {content_text}"
|
||||
);
|
||||
assert!(
|
||||
content_text
|
||||
.contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"),
|
||||
"Expected forwarded JSON parse error, got: {content_text}"
|
||||
);
|
||||
assert!(
|
||||
!content_text.contains("tool input was not fully received"),
|
||||
"Should not contain orphaned sender error, got: {content_text}"
|
||||
);
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Done");
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert!(
|
||||
thread.is_turn_complete(),
|
||||
"Thread should not be stuck; the turn should have completed",
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
|
||||
result_events
|
||||
|
|
@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
InfiniteTool::NAME: true,
|
||||
CancellationAwareTool::NAME: true,
|
||||
StreamingEchoTool::NAME: true,
|
||||
StreamingJsonErrorContextTool::NAME: true,
|
||||
StreamingFailingEchoTool::NAME: true,
|
||||
TerminalTool::NAME: true,
|
||||
UpdatePlanTool::NAME: true,
|
||||
|
|
|
|||
|
|
@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool {
|
|||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
mut input: ToolInput<Self::Input>,
|
||||
input: ToolInput<Self::Input>,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String, String>> {
|
||||
let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
|
||||
cx.spawn(async move |_cx| {
|
||||
while input.recv_partial().await.is_some() {}
|
||||
let input = input
|
||||
.recv()
|
||||
.await
|
||||
|
|
@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct StreamingJsonErrorContextToolInput {
|
||||
/// The text to echo.
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub struct StreamingJsonErrorContextTool;
|
||||
|
||||
impl AgentTool for StreamingJsonErrorContextTool {
|
||||
type Input = StreamingJsonErrorContextToolInput;
|
||||
type Output = String;
|
||||
|
||||
const NAME: &'static str = "streaming_json_error_context";
|
||||
|
||||
fn supports_input_streaming() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn kind() -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(
|
||||
&self,
|
||||
_input: Result<Self::Input, serde_json::Value>,
|
||||
_cx: &mut App,
|
||||
) -> SharedString {
|
||||
"Streaming JSON Error Context".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
mut input: ToolInput<Self::Input>,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String, String>> {
|
||||
cx.spawn(async move |_cx| {
|
||||
let mut last_partial_text = None;
|
||||
|
||||
loop {
|
||||
match input.next().await {
|
||||
Ok(ToolInputPayload::Partial(partial)) => {
|
||||
if let Some(text) = partial.get("text").and_then(|value| value.as_str()) {
|
||||
last_partial_text = Some(text.to_string());
|
||||
}
|
||||
}
|
||||
Ok(ToolInputPayload::Full(input)) => return Ok(input.text),
|
||||
Ok(ToolInputPayload::InvalidJson { error_message }) => {
|
||||
let partial_text = last_partial_text.unwrap_or_default();
|
||||
return Err(format!(
|
||||
"Saw partial text '{partial_text}' before invalid JSON: {error_message}"
|
||||
));
|
||||
}
|
||||
Err(error) => {
|
||||
return Err(format!("Failed to receive tool input: {error}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A streaming tool that echoes its input, used to test streaming tool
|
||||
/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
|
||||
/// before `is_input_complete`).
|
||||
|
|
@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool {
|
|||
) -> Task<Result<Self::Output, Self::Output>> {
|
||||
cx.spawn(async move |_cx| {
|
||||
for _ in 0..self.receive_chunks_until_failure {
|
||||
let _ = input.recv_partial().await;
|
||||
let _ = input.next().await;
|
||||
}
|
||||
Err("failed".into())
|
||||
})
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ use client::UserStore;
|
|||
use cloud_api_types::Plan;
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::stream;
|
||||
use futures::{
|
||||
FutureExt,
|
||||
channel::{mpsc, oneshot},
|
||||
future::Shared,
|
||||
stream::FuturesUnordered,
|
||||
};
|
||||
use futures::{StreamExt, stream};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
|
||||
};
|
||||
|
|
@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema};
|
|||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
marker::PhantomData,
|
||||
|
|
@ -2095,7 +2094,7 @@ impl Thread {
|
|||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -2195,15 +2194,15 @@ impl Thread {
|
|||
raw_input,
|
||||
json_parse_error,
|
||||
} => {
|
||||
return Ok(Some(Task::ready(
|
||||
self.handle_tool_use_json_parse_error_event(
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
event_stream,
|
||||
),
|
||||
)));
|
||||
return Ok(self.handle_tool_use_json_parse_error_event(
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
event_stream,
|
||||
cancellation_rx,
|
||||
cx,
|
||||
));
|
||||
}
|
||||
UsageUpdate(usage) => {
|
||||
telemetry::event!(
|
||||
|
|
@ -2304,12 +2303,12 @@ impl Thread {
|
|||
if !tool_use.is_input_complete {
|
||||
if tool.supports_input_streaming() {
|
||||
let running_turn = self.running_turn.as_mut()?;
|
||||
if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) {
|
||||
if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) {
|
||||
sender.send_partial(tool_use.input);
|
||||
return None;
|
||||
}
|
||||
|
||||
let (sender, tool_input) = ToolInputSender::channel();
|
||||
let (mut sender, tool_input) = ToolInputSender::channel();
|
||||
sender.send_partial(tool_use.input);
|
||||
running_turn
|
||||
.streaming_tool_inputs
|
||||
|
|
@ -2331,13 +2330,13 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(sender) = self
|
||||
if let Some(mut sender) = self
|
||||
.running_turn
|
||||
.as_mut()?
|
||||
.streaming_tool_inputs
|
||||
.remove(&tool_use.id)
|
||||
{
|
||||
sender.send_final(tool_use.input);
|
||||
sender.send_full(tool_use.input);
|
||||
return None;
|
||||
}
|
||||
|
||||
|
|
@ -2410,10 +2409,12 @@ impl Thread {
|
|||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
event_stream: &ThreadEventStream,
|
||||
) -> LanguageModelToolResult {
|
||||
cancellation_rx: watch::Receiver<bool>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: tool_use_id.clone(),
|
||||
name: tool_name.clone(),
|
||||
id: tool_use_id,
|
||||
name: tool_name,
|
||||
raw_input: raw_input.to_string(),
|
||||
input: serde_json::json!({}),
|
||||
is_input_complete: true,
|
||||
|
|
@ -2426,14 +2427,43 @@ impl Thread {
|
|||
event_stream,
|
||||
);
|
||||
|
||||
let tool_output = format!("Error parsing input JSON: {json_parse_error}");
|
||||
LanguageModelToolResult {
|
||||
tool_use_id,
|
||||
tool_name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(tool_output.into()),
|
||||
output: Some(serde_json::Value::String(raw_input.to_string())),
|
||||
let tool = self.tool(tool_use.name.as_ref());
|
||||
|
||||
let Some(tool) = tool else {
|
||||
let content = format!("No tool named {} exists", tool_use.name);
|
||||
return Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}));
|
||||
};
|
||||
|
||||
let error_message = format!("Error parsing input JSON: {json_parse_error}");
|
||||
|
||||
if tool.supports_input_streaming()
|
||||
&& let Some(mut sender) = self
|
||||
.running_turn
|
||||
.as_mut()?
|
||||
.streaming_tool_inputs
|
||||
.remove(&tool_use.id)
|
||||
{
|
||||
sender.send_invalid_json(error_message);
|
||||
return None;
|
||||
}
|
||||
|
||||
log::debug!("Running tool {}. Received invalid JSON", tool_use.name);
|
||||
let tool_input = ToolInput::invalid_json(error_message);
|
||||
Some(self.run_tool(
|
||||
tool,
|
||||
tool_input,
|
||||
tool_use.id,
|
||||
tool_use.name,
|
||||
event_stream,
|
||||
cancellation_rx,
|
||||
cx,
|
||||
))
|
||||
}
|
||||
|
||||
fn send_or_update_tool_use(
|
||||
|
|
@ -3114,8 +3144,7 @@ impl EventEmitter<TitleUpdated> for Thread {}
|
|||
/// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams
|
||||
/// them, followed by the final complete input available through `.recv()`.
|
||||
pub struct ToolInput<T> {
|
||||
partial_rx: mpsc::UnboundedReceiver<serde_json::Value>,
|
||||
final_rx: oneshot::Receiver<serde_json::Value>,
|
||||
rx: mpsc::UnboundedReceiver<ToolInputPayload<serde_json::Value>>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
|
|
@ -3127,13 +3156,20 @@ impl<T: DeserializeOwned> ToolInput<T> {
|
|||
}
|
||||
|
||||
pub fn ready(value: serde_json::Value) -> Self {
|
||||
let (partial_tx, partial_rx) = mpsc::unbounded();
|
||||
drop(partial_tx);
|
||||
let (final_tx, final_rx) = oneshot::channel();
|
||||
final_tx.send(value).ok();
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
tx.unbounded_send(ToolInputPayload::Full(value)).ok();
|
||||
Self {
|
||||
partial_rx,
|
||||
final_rx,
|
||||
rx,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn invalid_json(error_message: String) -> Self {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
tx.unbounded_send(ToolInputPayload::InvalidJson { error_message })
|
||||
.ok();
|
||||
Self {
|
||||
rx,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
@ -3147,65 +3183,89 @@ impl<T: DeserializeOwned> ToolInput<T> {
|
|||
/// Wait for the final deserialized input, ignoring all partial updates.
|
||||
/// Non-streaming tools can use this to wait until the whole input is available.
|
||||
pub async fn recv(mut self) -> Result<T> {
|
||||
// Drain any remaining partials
|
||||
while self.partial_rx.next().await.is_some() {}
|
||||
let value = self
|
||||
.final_rx
|
||||
.await
|
||||
.map_err(|_| anyhow!("tool input was not fully received"))?;
|
||||
serde_json::from_value(value).map_err(Into::into)
|
||||
while let Ok(value) = self.next().await {
|
||||
match value {
|
||||
ToolInputPayload::Full(value) => return Ok(value),
|
||||
ToolInputPayload::Partial(_) => {}
|
||||
ToolInputPayload::InvalidJson { error_message } => {
|
||||
return Err(anyhow!(error_message));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(anyhow!("tool input was not fully received"))
|
||||
}
|
||||
|
||||
/// Returns the next partial JSON snapshot, or `None` when input is complete.
|
||||
/// Once this returns `None`, call `recv()` to get the final input.
|
||||
pub async fn recv_partial(&mut self) -> Option<serde_json::Value> {
|
||||
self.partial_rx.next().await
|
||||
pub async fn next(&mut self) -> Result<ToolInputPayload<T>> {
|
||||
let value = self
|
||||
.rx
|
||||
.next()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("tool input was not fully received"))?;
|
||||
|
||||
Ok(match value {
|
||||
ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload),
|
||||
ToolInputPayload::Full(payload) => {
|
||||
ToolInputPayload::Full(serde_json::from_value(payload)?)
|
||||
}
|
||||
ToolInputPayload::InvalidJson { error_message } => {
|
||||
ToolInputPayload::InvalidJson { error_message }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn cast<U: DeserializeOwned>(self) -> ToolInput<U> {
|
||||
ToolInput {
|
||||
partial_rx: self.partial_rx,
|
||||
final_rx: self.final_rx,
|
||||
rx: self.rx,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ToolInputPayload<T> {
|
||||
Partial(serde_json::Value),
|
||||
Full(T),
|
||||
InvalidJson { error_message: String },
|
||||
}
|
||||
|
||||
pub struct ToolInputSender {
|
||||
partial_tx: mpsc::UnboundedSender<serde_json::Value>,
|
||||
final_tx: Option<oneshot::Sender<serde_json::Value>>,
|
||||
has_received_final: bool,
|
||||
tx: mpsc::UnboundedSender<ToolInputPayload<serde_json::Value>>,
|
||||
}
|
||||
|
||||
impl ToolInputSender {
|
||||
pub(crate) fn channel() -> (Self, ToolInput<serde_json::Value>) {
|
||||
let (partial_tx, partial_rx) = mpsc::unbounded();
|
||||
let (final_tx, final_rx) = oneshot::channel();
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let sender = Self {
|
||||
partial_tx,
|
||||
final_tx: Some(final_tx),
|
||||
tx,
|
||||
has_received_final: false,
|
||||
};
|
||||
let input = ToolInput {
|
||||
partial_rx,
|
||||
final_rx,
|
||||
rx,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
(sender, input)
|
||||
}
|
||||
|
||||
pub(crate) fn has_received_final(&self) -> bool {
|
||||
self.final_tx.is_none()
|
||||
self.has_received_final
|
||||
}
|
||||
|
||||
pub(crate) fn send_partial(&self, value: serde_json::Value) {
|
||||
self.partial_tx.unbounded_send(value).ok();
|
||||
pub fn send_partial(&mut self, payload: serde_json::Value) {
|
||||
self.tx
|
||||
.unbounded_send(ToolInputPayload::Partial(payload))
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub(crate) fn send_final(mut self, value: serde_json::Value) {
|
||||
// Close the partial channel so recv_partial() returns None
|
||||
self.partial_tx.close_channel();
|
||||
if let Some(final_tx) = self.final_tx.take() {
|
||||
final_tx.send(value).ok();
|
||||
}
|
||||
pub fn send_full(&mut self, payload: serde_json::Value) {
|
||||
self.has_received_final = true;
|
||||
self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok();
|
||||
}
|
||||
|
||||
pub fn send_invalid_json(&mut self, error_message: String) {
|
||||
self.has_received_final = true;
|
||||
self.tx
|
||||
.unbounded_send(ToolInputPayload::InvalidJson { error_message })
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4251,68 +4311,78 @@ mod tests {
|
|||
) {
|
||||
let (thread, event_stream) = setup_thread_for_test(cx).await;
|
||||
|
||||
cx.update(|cx| {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
|
||||
let tool_name: Arc<str> = Arc::from("test_tool");
|
||||
let raw_input: Arc<str> = Arc::from("{invalid json");
|
||||
let json_parse_error = "expected value at line 1 column 1".to_string();
|
||||
let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
|
||||
let tool_name: Arc<str> = Arc::from("test_tool");
|
||||
let raw_input: Arc<str> = Arc::from("{invalid json");
|
||||
let json_parse_error = "expected value at line 1 column 1".to_string();
|
||||
|
||||
// Call the function under test
|
||||
let result = thread.handle_tool_use_json_parse_error_event(
|
||||
tool_use_id.clone(),
|
||||
tool_name.clone(),
|
||||
raw_input.clone(),
|
||||
json_parse_error,
|
||||
&event_stream,
|
||||
);
|
||||
let (_cancellation_tx, cancellation_rx) = watch::channel(false);
|
||||
|
||||
// Verify the result is an error
|
||||
assert!(result.is_error);
|
||||
assert_eq!(result.tool_use_id, tool_use_id);
|
||||
assert_eq!(result.tool_name, tool_name);
|
||||
assert!(matches!(
|
||||
result.content,
|
||||
LanguageModelToolResultContent::Text(_)
|
||||
));
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
// Call the function under test
|
||||
thread
|
||||
.handle_tool_use_json_parse_error_event(
|
||||
tool_use_id.clone(),
|
||||
tool_name.clone(),
|
||||
raw_input.clone(),
|
||||
json_parse_error,
|
||||
&event_stream,
|
||||
cancellation_rx,
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
// Verify the tool use was added to the message content
|
||||
{
|
||||
let last_message = thread.pending_message();
|
||||
assert_eq!(
|
||||
last_message.content.len(),
|
||||
1,
|
||||
"Should have one tool_use in content"
|
||||
);
|
||||
// Verify the result is an error
|
||||
assert!(result.is_error);
|
||||
assert_eq!(result.tool_use_id, tool_use_id);
|
||||
assert_eq!(result.tool_name, tool_name);
|
||||
assert!(matches!(
|
||||
result.content,
|
||||
LanguageModelToolResultContent::Text(_)
|
||||
));
|
||||
|
||||
match &last_message.content[0] {
|
||||
AgentMessageContent::ToolUse(tool_use) => {
|
||||
assert_eq!(tool_use.id, tool_use_id);
|
||||
assert_eq!(tool_use.name, tool_name);
|
||||
assert_eq!(tool_use.raw_input, raw_input.to_string());
|
||||
assert!(tool_use.is_input_complete);
|
||||
// Should fall back to empty object for invalid JSON
|
||||
assert_eq!(tool_use.input, json!({}));
|
||||
}
|
||||
_ => panic!("Expected ToolUse content"),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert the tool result (simulating what the caller does)
|
||||
thread
|
||||
.pending_message()
|
||||
.tool_results
|
||||
.insert(result.tool_use_id.clone(), result);
|
||||
|
||||
// Verify the tool result was added
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Verify the tool use was added to the message content
|
||||
{
|
||||
let last_message = thread.pending_message();
|
||||
assert_eq!(
|
||||
last_message.tool_results.len(),
|
||||
last_message.content.len(),
|
||||
1,
|
||||
"Should have one tool_result"
|
||||
"Should have one tool_use in content"
|
||||
);
|
||||
assert!(last_message.tool_results.contains_key(&tool_use_id));
|
||||
});
|
||||
});
|
||||
|
||||
match &last_message.content[0] {
|
||||
AgentMessageContent::ToolUse(tool_use) => {
|
||||
assert_eq!(tool_use.id, tool_use_id);
|
||||
assert_eq!(tool_use.name, tool_name);
|
||||
assert_eq!(tool_use.raw_input, raw_input.to_string());
|
||||
assert!(tool_use.is_input_complete);
|
||||
// Should fall back to empty object for invalid JSON
|
||||
assert_eq!(tool_use.input, json!({}));
|
||||
}
|
||||
_ => panic!("Expected ToolUse content"),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert the tool result (simulating what the caller does)
|
||||
thread
|
||||
.pending_message()
|
||||
.tool_results
|
||||
.insert(result.tool_use_id.clone(), result);
|
||||
|
||||
// Verify the tool result was added
|
||||
let last_message = thread.pending_message();
|
||||
assert_eq!(
|
||||
last_message.tool_results.len(),
|
||||
1,
|
||||
"Should have one tool_result"
|
||||
);
|
||||
assert!(last_message.tool_results.contains_key(&tool_use_id));
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -125,6 +125,7 @@ pub struct FakeLanguageModel {
|
|||
>,
|
||||
forbid_requests: AtomicBool,
|
||||
supports_thinking: AtomicBool,
|
||||
supports_streaming_tools: AtomicBool,
|
||||
}
|
||||
|
||||
impl Default for FakeLanguageModel {
|
||||
|
|
@ -137,6 +138,7 @@ impl Default for FakeLanguageModel {
|
|||
current_completion_txs: Mutex::new(Vec::new()),
|
||||
forbid_requests: AtomicBool::new(false),
|
||||
supports_thinking: AtomicBool::new(false),
|
||||
supports_streaming_tools: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -169,6 +171,10 @@ impl FakeLanguageModel {
|
|||
self.supports_thinking.store(supports, SeqCst);
|
||||
}
|
||||
|
||||
pub fn set_supports_streaming_tools(&self, supports: bool) {
|
||||
self.supports_streaming_tools.store(supports, SeqCst);
|
||||
}
|
||||
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
|
|
@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel {
|
|||
self.supports_thinking.load(SeqCst)
|
||||
}
|
||||
|
||||
fn supports_streaming_tools(&self) -> bool {
|
||||
self.supports_streaming_tools.load(SeqCst)
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
"fake".to_string()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue