mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
agent: Replay image output (#57143)
Release Notes: - agent: Fix image output from tools not being reloaded when restoring thread
This commit is contained in:
parent
46b08f9d7d
commit
c352cad169
1 changed files with 180 additions and 8 deletions
|
|
@ -1215,10 +1215,10 @@ impl Thread {
|
|||
stream: &ThreadEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// Extract saved output and status first, so they're available even if tool is not found
|
||||
let output = tool_result
|
||||
.as_ref()
|
||||
.and_then(|result| result.output.clone());
|
||||
let replay_content = tool_result.and_then(Self::tool_result_content_for_replay);
|
||||
let status = tool_result
|
||||
.as_ref()
|
||||
.map_or(acp::ToolCallStatus::Failed, |result| {
|
||||
|
|
@ -1255,13 +1255,13 @@ impl Thread {
|
|||
.raw_input(tool_use.input.clone()),
|
||||
)))
|
||||
.ok();
|
||||
stream.update_tool_call_fields(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields::new()
|
||||
.status(status)
|
||||
.raw_output(output),
|
||||
None,
|
||||
);
|
||||
let mut fields = acp::ToolCallUpdateFields::new()
|
||||
.status(status)
|
||||
.raw_output(output);
|
||||
if let Some(content) = replay_content {
|
||||
fields = fields.content(content);
|
||||
}
|
||||
stream.update_tool_call_fields(&tool_use.id, fields, None);
|
||||
return;
|
||||
};
|
||||
|
||||
|
|
@ -1275,6 +1275,14 @@ impl Thread {
|
|||
tool_use.input.clone(),
|
||||
);
|
||||
|
||||
if let Some(content) = replay_content {
|
||||
stream.update_tool_call_fields(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields::new().content(content),
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(output) = output.clone() {
|
||||
// For replay, we use a dummy cancellation receiver since the tool already completed
|
||||
let (_cancellation_tx, cancellation_rx) = watch::channel(false);
|
||||
|
|
@ -1297,6 +1305,45 @@ impl Thread {
|
|||
);
|
||||
}
|
||||
|
||||
fn tool_result_content_for_replay(
|
||||
tool_result: &LanguageModelToolResult,
|
||||
) -> Option<Vec<acp::ToolCallContent>> {
|
||||
let has_image = tool_result
|
||||
.content
|
||||
.iter()
|
||||
.any(|part| matches!(part, LanguageModelToolResultContent::Image(_)));
|
||||
if !has_image && tool_result.output.is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let content = tool_result
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(acp::ToolCallContent::Content(acp::Content::new(
|
||||
acp::ContentBlock::Text(acp::TextContent::new(text.to_string())),
|
||||
)))
|
||||
}
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => Some(
|
||||
acp::ToolCallContent::Content(acp::Content::new(acp::ContentBlock::Image(
|
||||
acp::ImageContent::new(image.source.clone(), "image/png"),
|
||||
))),
|
||||
),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_db(
|
||||
id: acp::SessionId,
|
||||
db_thread: DbThread,
|
||||
|
|
@ -4454,6 +4501,131 @@ mod tests {
|
|||
})
|
||||
}
|
||||
|
||||
struct ReplayImageTool;
|
||||
|
||||
impl AgentTool for ReplayImageTool {
|
||||
type Input = ();
|
||||
type Output = String;
|
||||
|
||||
const NAME: &'static str = "registered_image_tool";
|
||||
|
||||
fn kind() -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(
|
||||
&self,
|
||||
_input: Result<Self::Input, serde_json::Value>,
|
||||
_cx: &mut App,
|
||||
) -> SharedString {
|
||||
"Registered Image Tool".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: ToolInput<Self::Input>,
|
||||
_event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<Self::Output, Self::Output>> {
|
||||
Task::ready(Ok(String::new()))
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replay_tool_call_replays_image_content(cx: &mut TestAppContext) {
|
||||
let (thread, _event_stream) = setup_thread_for_test(cx).await;
|
||||
|
||||
let registered_tool_use_id = LanguageModelToolUseId::from("registered_tool_id");
|
||||
let missing_tool_use_id = LanguageModelToolUseId::from("missing_tool_id");
|
||||
let image_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==";
|
||||
let image = LanguageModelImage {
|
||||
source: image_data.into(),
|
||||
};
|
||||
|
||||
let mut replay_events = cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(ReplayImageTool);
|
||||
|
||||
let registered_tool_use = LanguageModelToolUse {
|
||||
id: registered_tool_use_id.clone(),
|
||||
name: ReplayImageTool::NAME.into(),
|
||||
raw_input: "null".to_string(),
|
||||
input: json!(null),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
let missing_tool_use = LanguageModelToolUse {
|
||||
id: missing_tool_use_id.clone(),
|
||||
name: "missing_image_tool".into(),
|
||||
raw_input: "{}".to_string(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
|
||||
let mut tool_results = IndexMap::default();
|
||||
tool_results.insert(
|
||||
registered_tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: registered_tool_use_id.clone(),
|
||||
tool_name: ReplayImageTool::NAME.into(),
|
||||
is_error: false,
|
||||
content: vec![
|
||||
LanguageModelToolResultContent::Text("before".into()),
|
||||
LanguageModelToolResultContent::Image(image.clone()),
|
||||
LanguageModelToolResultContent::Text("after".into()),
|
||||
],
|
||||
output: Some(json!("raw output")),
|
||||
},
|
||||
);
|
||||
tool_results.insert(
|
||||
missing_tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: missing_tool_use_id.clone(),
|
||||
tool_name: "missing_image_tool".into(),
|
||||
is_error: false,
|
||||
content: vec![LanguageModelToolResultContent::Image(image.clone())],
|
||||
output: Some(json!("raw output")),
|
||||
},
|
||||
);
|
||||
|
||||
thread.messages.push(Message::Agent(AgentMessage {
|
||||
content: vec![
|
||||
AgentMessageContent::ToolUse(registered_tool_use),
|
||||
AgentMessageContent::ToolUse(missing_tool_use),
|
||||
],
|
||||
tool_results,
|
||||
reasoning_details: None,
|
||||
}));
|
||||
|
||||
thread.replay(cx)
|
||||
})
|
||||
});
|
||||
|
||||
let mut tool_use_ids_with_image_content = HashSet::default();
|
||||
while let Some(event) = replay_events.next().await {
|
||||
let event = event.unwrap();
|
||||
if let ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) =
|
||||
event
|
||||
&& let Some(content) = &update.fields.content
|
||||
&& content.iter().any(|content| {
|
||||
matches!(
|
||||
content,
|
||||
acp::ToolCallContent::Content(acp::Content {
|
||||
content: acp::ContentBlock::Image(_),
|
||||
..
|
||||
})
|
||||
)
|
||||
})
|
||||
{
|
||||
tool_use_ids_with_image_content.insert(update.tool_call_id.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
assert!(tool_use_ids_with_image_content.contains(®istered_tool_use_id.to_string()));
|
||||
assert!(tool_use_ids_with_image_content.contains(&missing_tool_use_id.to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_set_model_propagates_to_subagents(cx: &mut TestAppContext) {
|
||||
let (parent, _event_stream) = setup_thread_for_test(cx).await;
|
||||
|
|
|
|||
Loading…
Reference in a new issue