diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index ae5c5510764..55d7b62b99b 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1215,10 +1215,10 @@ impl Thread { stream: &ThreadEventStream, cx: &mut Context, ) { - // Extract saved output and status first, so they're available even if tool is not found let output = tool_result .as_ref() .and_then(|result| result.output.clone()); + let replay_content = tool_result.and_then(Self::tool_result_content_for_replay); let status = tool_result .as_ref() .map_or(acp::ToolCallStatus::Failed, |result| { @@ -1255,13 +1255,13 @@ impl Thread { .raw_input(tool_use.input.clone()), ))) .ok(); - stream.update_tool_call_fields( - &tool_use.id, - acp::ToolCallUpdateFields::new() - .status(status) - .raw_output(output), - None, - ); + let mut fields = acp::ToolCallUpdateFields::new() + .status(status) + .raw_output(output); + if let Some(content) = replay_content { + fields = fields.content(content); + } + stream.update_tool_call_fields(&tool_use.id, fields, None); return; }; @@ -1275,6 +1275,14 @@ impl Thread { tool_use.input.clone(), ); + if let Some(content) = replay_content { + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields::new().content(content), + None, + ); + } + if let Some(output) = output.clone() { // For replay, we use a dummy cancellation receiver since the tool already completed let (_cancellation_tx, cancellation_rx) = watch::channel(false); @@ -1297,6 +1305,45 @@ impl Thread { ); } + fn tool_result_content_for_replay( + tool_result: &LanguageModelToolResult, + ) -> Option> { + let has_image = tool_result + .content + .iter() + .any(|part| matches!(part, LanguageModelToolResultContent::Image(_))); + if !has_image && tool_result.output.is_some() { + return None; + } + + let content = tool_result + .content + .iter() + .filter_map(|part| match part { + LanguageModelToolResultContent::Text(text) => { + if text.is_empty() { + None + } else { + Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Text(acp::TextContent::new(text.to_string())), + ))) + } + } + LanguageModelToolResultContent::Image(image) => Some( + acp::ToolCallContent::Content(acp::Content::new(acp::ContentBlock::Image( + acp::ImageContent::new(image.source.clone(), "image/png"), + ))), + ), + }) + .collect::>(); + + if content.is_empty() { + None + } else { + Some(content) + } + } + pub fn from_db( id: acp::SessionId, db_thread: DbThread, @@ -4454,6 +4501,131 @@ mod tests { }) } + struct ReplayImageTool; + + impl AgentTool for ReplayImageTool { + type Input = (); + type Output = String; + + const NAME: &'static str = "registered_image_tool"; + + fn kind() -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { + "Registered Image Tool".into() + } + + fn run( + self: Arc, + _input: ToolInput, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(String::new())) + } + } + + #[gpui::test] + async fn test_replay_tool_call_replays_image_content(cx: &mut TestAppContext) { + let (thread, _event_stream) = setup_thread_for_test(cx).await; + + let registered_tool_use_id = LanguageModelToolUseId::from("registered_tool_id"); + let missing_tool_use_id = LanguageModelToolUseId::from("missing_tool_id"); + let image_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg=="; + let image = LanguageModelImage { + source: image_data.into(), + }; + + let mut replay_events = cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.add_tool(ReplayImageTool); + + let registered_tool_use = LanguageModelToolUse { + id: registered_tool_use_id.clone(), + name: ReplayImageTool::NAME.into(), + raw_input: "null".to_string(), + input: json!(null), + is_input_complete: true, + thought_signature: None, + }; + let missing_tool_use = LanguageModelToolUse { + id: missing_tool_use_id.clone(), + name: "missing_image_tool".into(), + raw_input: "{}".to_string(), + input: json!({}), + is_input_complete: true, + thought_signature: None, + }; + + let mut tool_results = IndexMap::default(); + tool_results.insert( + registered_tool_use_id.clone(), + LanguageModelToolResult { + tool_use_id: registered_tool_use_id.clone(), + tool_name: ReplayImageTool::NAME.into(), + is_error: false, + content: vec![ + LanguageModelToolResultContent::Text("before".into()), + LanguageModelToolResultContent::Image(image.clone()), + LanguageModelToolResultContent::Text("after".into()), + ], + output: Some(json!("raw output")), + }, + ); + tool_results.insert( + missing_tool_use_id.clone(), + LanguageModelToolResult { + tool_use_id: missing_tool_use_id.clone(), + tool_name: "missing_image_tool".into(), + is_error: false, + content: vec![LanguageModelToolResultContent::Image(image.clone())], + output: Some(json!("raw output")), + }, + ); + + thread.messages.push(Message::Agent(AgentMessage { + content: vec![ + AgentMessageContent::ToolUse(registered_tool_use), + AgentMessageContent::ToolUse(missing_tool_use), + ], + tool_results, + reasoning_details: None, + })); + + thread.replay(cx) + }) + }); + + let mut tool_use_ids_with_image_content = HashSet::default(); + while let Some(event) = replay_events.next().await { + let event = event.unwrap(); + if let ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) = + event + && let Some(content) = &update.fields.content + && content.iter().any(|content| { + matches!( + content, + acp::ToolCallContent::Content(acp::Content { + content: acp::ContentBlock::Image(_), + .. + }) + ) + }) + { + tool_use_ids_with_image_content.insert(update.tool_call_id.to_string()); + } + } + + assert!(tool_use_ids_with_image_content.contains(®istered_tool_use_id.to_string())); + assert!(tool_use_ids_with_image_content.contains(&missing_tool_use_id.to_string())); + } + #[gpui::test] async fn test_set_model_propagates_to_subagents(cx: &mut TestAppContext) { let (parent, _event_stream) = setup_thread_for_test(cx).await;