mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Subagents PR 2: Thread spawning + execution (#46187)
This PR implements the behind-the-scenes subagent execution logic: ### Core subagent execution - Add `Thread::new_subagent()` constructor for creating subagent threads - Implement `SubagentTool::run()` to spawn and manage subagent lifecycle - Add `SubagentContext` for parent-child thread relationship - Implement `submit_user_message()`, `interrupt_for_summary()`, `request_final_summary()` - Add timeout support and context-low detection (25% threshold) - Propagate cancellation from parent to child threads ### Thread management - Add `MAX_SUBAGENT_DEPTH` (4) and `MAX_PARALLEL_SUBAGENTS` (8) limits - Add `register/unregister_running_subagent()` for tracking - Add `restrict_tools()` for allowed_tools filtering - Add `is_subagent()`, `depth()`, `is_turn_complete()` accessors ### Thread changes - Add `ToolCallContent::SubagentThread` variant - Add `ToolCallUpdateSubagentThread` for UI updates - Add `tool_name` field for subagent detection - Add `is_subagent()` method on `ToolCall` - Add image support in `ContentBlock` Release Notes: - N/A --------- Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
816e5f5c73
commit
756637fef5
6 changed files with 1891 additions and 30 deletions
|
|
@ -4,6 +4,26 @@ mod mention;
|
|||
mod terminal;
|
||||
|
||||
use agent_settings::AgentSettings;
|
||||
|
||||
/// Key used in ACP ToolCall meta to store the tool's programmatic name.
|
||||
/// This is a workaround since ACP's ToolCall doesn't have a dedicated name field.
|
||||
pub const TOOL_NAME_META_KEY: &str = "tool_name";
|
||||
|
||||
/// The tool name for subagent spawning
|
||||
pub const SUBAGENT_TOOL_NAME: &str = "subagent";
|
||||
|
||||
/// Helper to extract tool name from ACP meta
|
||||
pub fn tool_name_from_meta(meta: &Option<acp::Meta>) -> Option<SharedString> {
|
||||
meta.as_ref()
|
||||
.and_then(|m| m.get(TOOL_NAME_META_KEY))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| SharedString::from(s.to_owned()))
|
||||
}
|
||||
|
||||
/// Helper to create meta with tool name
|
||||
pub fn meta_with_tool_name(tool_name: &str) -> acp::Meta {
|
||||
acp::Meta::from_iter([(TOOL_NAME_META_KEY.into(), tool_name.into())])
|
||||
}
|
||||
use collections::HashSet;
|
||||
pub use connection::*;
|
||||
pub use diff::*;
|
||||
|
|
@ -195,6 +215,7 @@ pub struct ToolCall {
|
|||
pub raw_input: Option<serde_json::Value>,
|
||||
pub raw_input_markdown: Option<Entity<Markdown>>,
|
||||
pub raw_output: Option<serde_json::Value>,
|
||||
pub tool_name: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl ToolCall {
|
||||
|
|
@ -229,6 +250,8 @@ impl ToolCall {
|
|||
.as_ref()
|
||||
.and_then(|input| markdown_for_raw_output(input, &language_registry, cx));
|
||||
|
||||
let tool_name = tool_name_from_meta(&tool_call.meta);
|
||||
|
||||
let result = Self {
|
||||
id: tool_call.tool_call_id,
|
||||
label: cx
|
||||
|
|
@ -241,6 +264,7 @@ impl ToolCall {
|
|||
raw_input: tool_call.raw_input,
|
||||
raw_input_markdown,
|
||||
raw_output: tool_call.raw_output,
|
||||
tool_name,
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
|
|
@ -338,6 +362,7 @@ impl ToolCall {
|
|||
ToolCallContent::Diff(diff) => Some(diff),
|
||||
ToolCallContent::ContentBlock(_) => None,
|
||||
ToolCallContent::Terminal(_) => None,
|
||||
ToolCallContent::SubagentThread(_) => None,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -346,9 +371,26 @@ impl ToolCall {
|
|||
ToolCallContent::Terminal(terminal) => Some(terminal),
|
||||
ToolCallContent::ContentBlock(_) => None,
|
||||
ToolCallContent::Diff(_) => None,
|
||||
ToolCallContent::SubagentThread(_) => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn subagent_thread(&self) -> Option<&Entity<AcpThread>> {
|
||||
self.content.iter().find_map(|content| match content {
|
||||
ToolCallContent::SubagentThread(thread) => Some(thread),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_subagent(&self) -> bool {
|
||||
matches!(self.kind, acp::ToolKind::Other)
|
||||
&& self
|
||||
.tool_name
|
||||
.as_ref()
|
||||
.map(|n| n.as_ref() == SUBAGENT_TOOL_NAME)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
let mut markdown = format!(
|
||||
"**Tool Call: {}**\nStatus: {}\n\n",
|
||||
|
|
@ -642,6 +684,7 @@ pub enum ToolCallContent {
|
|||
ContentBlock(ContentBlock),
|
||||
Diff(Entity<Diff>),
|
||||
Terminal(Entity<Terminal>),
|
||||
SubagentThread(Entity<AcpThread>),
|
||||
}
|
||||
|
||||
impl ToolCallContent {
|
||||
|
|
@ -713,6 +756,7 @@ impl ToolCallContent {
|
|||
Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
|
||||
Self::Diff(diff) => diff.read(cx).to_markdown(cx),
|
||||
Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
|
||||
Self::SubagentThread(thread) => thread.read(cx).to_markdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -722,6 +766,13 @@ impl ToolCallContent {
|
|||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subagent_thread(&self) -> Option<&Entity<AcpThread>> {
|
||||
match self {
|
||||
Self::SubagentThread(thread) => Some(thread),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
|
|
@ -729,6 +780,7 @@ pub enum ToolCallUpdate {
|
|||
UpdateFields(acp::ToolCallUpdate),
|
||||
UpdateDiff(ToolCallUpdateDiff),
|
||||
UpdateTerminal(ToolCallUpdateTerminal),
|
||||
UpdateSubagentThread(ToolCallUpdateSubagentThread),
|
||||
}
|
||||
|
||||
impl ToolCallUpdate {
|
||||
|
|
@ -737,6 +789,7 @@ impl ToolCallUpdate {
|
|||
Self::UpdateFields(update) => &update.tool_call_id,
|
||||
Self::UpdateDiff(diff) => &diff.id,
|
||||
Self::UpdateTerminal(terminal) => &terminal.id,
|
||||
Self::UpdateSubagentThread(subagent) => &subagent.id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -771,6 +824,18 @@ pub struct ToolCallUpdateTerminal {
|
|||
pub terminal: Entity<Terminal>,
|
||||
}
|
||||
|
||||
impl From<ToolCallUpdateSubagentThread> for ToolCallUpdate {
|
||||
fn from(subagent: ToolCallUpdateSubagentThread) -> Self {
|
||||
Self::UpdateSubagentThread(subagent)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct ToolCallUpdateSubagentThread {
|
||||
pub id: acp::ToolCallId,
|
||||
pub thread: Entity<AcpThread>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Plan {
|
||||
pub entries: Vec<PlanEntry>,
|
||||
|
|
@ -1425,6 +1490,7 @@ impl AcpThread {
|
|||
raw_input: None,
|
||||
raw_input_markdown: None,
|
||||
raw_output: None,
|
||||
tool_name: None,
|
||||
};
|
||||
self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
|
||||
return Ok(());
|
||||
|
|
@ -1451,6 +1517,11 @@ impl AcpThread {
|
|||
call.content
|
||||
.push(ToolCallContent::Terminal(update.terminal));
|
||||
}
|
||||
ToolCallUpdate::UpdateSubagentThread(update) => {
|
||||
call.content.clear();
|
||||
call.content
|
||||
.push(ToolCallContent::SubagentThread(update.thread));
|
||||
}
|
||||
}
|
||||
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -58,6 +58,27 @@ use uuid::Uuid;
|
|||
|
||||
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
|
||||
pub const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
pub const MAX_SUBAGENT_DEPTH: u8 = 4;
|
||||
pub const MAX_PARALLEL_SUBAGENTS: usize = 8;
|
||||
|
||||
/// Context passed to a subagent thread for lifecycle management
|
||||
#[derive(Clone)]
|
||||
pub struct SubagentContext {
|
||||
/// ID of the parent thread
|
||||
pub parent_thread_id: acp::SessionId,
|
||||
|
||||
/// ID of the tool call that spawned this subagent
|
||||
pub tool_use_id: LanguageModelToolUseId,
|
||||
|
||||
/// Current depth level (0 = root agent, 1 = first-level subagent, etc.)
|
||||
pub depth: u8,
|
||||
|
||||
/// Prompt to send when subagent completes successfully
|
||||
pub summary_prompt: String,
|
||||
|
||||
/// Prompt to send when context is running low (≤25% remaining)
|
||||
pub context_low_prompt: String,
|
||||
}
|
||||
|
||||
/// The ID of the user prompt that initiated a request.
|
||||
///
|
||||
|
|
@ -626,6 +647,10 @@ pub struct Thread {
|
|||
pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
|
||||
/// True if this thread was imported from a shared thread and can be synced.
|
||||
imported: bool,
|
||||
/// If this is a subagent thread, contains context about the parent
|
||||
subagent_context: Option<SubagentContext>,
|
||||
/// Weak references to running subagent threads for cancellation propagation
|
||||
running_subagents: Vec<WeakEntity<Thread>>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
|
|
@ -683,6 +708,56 @@ impl Thread {
|
|||
action_log,
|
||||
file_read_times: HashMap::default(),
|
||||
imported: false,
|
||||
subagent_context: None,
|
||||
running_subagents: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_subagent(
|
||||
project: Entity<Project>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
templates: Arc<Templates>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
subagent_context: SubagentContext,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
|
||||
let (prompt_capabilities_tx, prompt_capabilities_rx) =
|
||||
watch::channel(Self::prompt_capabilities(Some(model.as_ref())));
|
||||
Self {
|
||||
id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()),
|
||||
prompt_id: PromptId::new(),
|
||||
updated_at: Utc::now(),
|
||||
title: None,
|
||||
pending_title_generation: None,
|
||||
pending_summary_generation: None,
|
||||
summary: None,
|
||||
messages: Vec::new(),
|
||||
user_store: project.read(cx).user_store(),
|
||||
completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
|
||||
running_turn: None,
|
||||
pending_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
tool_use_limit_reached: false,
|
||||
request_token_usage: HashMap::default(),
|
||||
cumulative_token_usage: TokenUsage::default(),
|
||||
initial_project_snapshot: Task::ready(None).shared(),
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
templates,
|
||||
model: Some(model),
|
||||
summarization_model: None,
|
||||
prompt_capabilities_tx,
|
||||
prompt_capabilities_rx,
|
||||
project,
|
||||
action_log,
|
||||
file_read_times: HashMap::default(),
|
||||
imported: false,
|
||||
subagent_context: Some(subagent_context),
|
||||
running_subagents: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -880,6 +955,8 @@ impl Thread {
|
|||
prompt_capabilities_rx,
|
||||
file_read_times: HashMap::default(),
|
||||
imported: db_thread.imported,
|
||||
subagent_context: None,
|
||||
running_subagents: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -984,7 +1061,6 @@ impl Thread {
|
|||
cx.notify()
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn last_message(&self) -> Option<Message> {
|
||||
if let Some(message) = self.pending_message.clone() {
|
||||
Some(Message::Agent(message))
|
||||
|
|
@ -1030,8 +1106,17 @@ impl Thread {
|
|||
self.add_tool(ThinkingTool);
|
||||
self.add_tool(WebSearchTool);
|
||||
|
||||
if cx.has_flag::<SubagentsFeatureFlag>() {
|
||||
self.add_tool(SubagentTool::new());
|
||||
if cx.has_flag::<SubagentsFeatureFlag>() && self.depth() < MAX_SUBAGENT_DEPTH {
|
||||
let tool_names = self.registered_tool_names();
|
||||
self.add_tool(SubagentTool::new(
|
||||
cx.weak_entity(),
|
||||
self.project.clone(),
|
||||
self.project_context.clone(),
|
||||
self.context_server_registry.clone(),
|
||||
self.templates.clone(),
|
||||
self.depth(),
|
||||
tool_names,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1043,6 +1128,10 @@ impl Thread {
|
|||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
pub fn restrict_tools(&mut self, allowed: &collections::HashSet<SharedString>) {
|
||||
self.tools.retain(|name, _| allowed.contains(name));
|
||||
}
|
||||
|
||||
pub fn profile(&self) -> &AgentProfileId {
|
||||
&self.profile_id
|
||||
}
|
||||
|
|
@ -1061,6 +1150,12 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
|
||||
for subagent in self.running_subagents.drain(..) {
|
||||
if let Some(subagent) = subagent.upgrade() {
|
||||
subagent.update(cx, |thread, cx| thread.cancel(cx)).detach();
|
||||
}
|
||||
}
|
||||
|
||||
let Some(running_turn) = self.running_turn.take() else {
|
||||
self.flush_pending_message(cx);
|
||||
return Task::ready(());
|
||||
|
|
@ -2138,6 +2233,82 @@ impl Thread {
|
|||
.is_some_and(|turn| turn.tools.contains_key(name))
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn has_registered_tool(&self, name: &str) -> bool {
|
||||
self.tools.contains_key(name)
|
||||
}
|
||||
|
||||
pub fn registered_tool_names(&self) -> Vec<SharedString> {
|
||||
self.tools.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn register_running_subagent(&mut self, subagent: WeakEntity<Thread>) {
|
||||
self.running_subagents.push(subagent);
|
||||
}
|
||||
|
||||
pub fn unregister_running_subagent(&mut self, subagent: &WeakEntity<Thread>) {
|
||||
self.running_subagents
|
||||
.retain(|s| s.entity_id() != subagent.entity_id());
|
||||
}
|
||||
|
||||
pub fn running_subagent_count(&self) -> usize {
|
||||
self.running_subagents
|
||||
.iter()
|
||||
.filter(|s| s.upgrade().is_some())
|
||||
.count()
|
||||
}
|
||||
|
||||
pub fn is_subagent(&self) -> bool {
|
||||
self.subagent_context.is_some()
|
||||
}
|
||||
|
||||
pub fn depth(&self) -> u8 {
|
||||
self.subagent_context.as_ref().map(|c| c.depth).unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn is_turn_complete(&self) -> bool {
|
||||
self.running_turn.is_none()
|
||||
}
|
||||
|
||||
pub fn submit_user_message(
|
||||
&mut self,
|
||||
content: impl Into<String>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
|
||||
let content = content.into();
|
||||
self.messages.push(Message::User(UserMessage {
|
||||
id: UserMessageId::new(),
|
||||
content: vec![UserMessageContent::Text(content)],
|
||||
}));
|
||||
cx.notify();
|
||||
self.send_existing(cx)
|
||||
}
|
||||
|
||||
pub fn interrupt_for_summary(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
|
||||
let context = self
|
||||
.subagent_context
|
||||
.as_ref()
|
||||
.context("Not a subagent thread")?;
|
||||
let prompt = context.context_low_prompt.clone();
|
||||
self.cancel(cx).detach();
|
||||
self.submit_user_message(prompt, cx)
|
||||
}
|
||||
|
||||
pub fn request_final_summary(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
|
||||
let context = self
|
||||
.subagent_context
|
||||
.as_ref()
|
||||
.context("Not a subagent thread")?;
|
||||
let prompt = context.summary_prompt.clone();
|
||||
self.submit_user_message(prompt, cx)
|
||||
}
|
||||
|
||||
fn build_request_messages(
|
||||
&self,
|
||||
available_tools: Vec<SharedString>,
|
||||
|
|
@ -2546,10 +2717,7 @@ impl ThreadEventStream {
|
|||
acp::ToolCall::new(id.to_string(), title)
|
||||
.kind(kind)
|
||||
.raw_input(input)
|
||||
.meta(acp::Meta::from_iter([(
|
||||
"tool_name".into(),
|
||||
tool_name.into(),
|
||||
)]))
|
||||
.meta(acp_thread::meta_with_tool_name(tool_name))
|
||||
}
|
||||
|
||||
fn update_tool_call_fields(
|
||||
|
|
@ -2645,6 +2813,10 @@ impl ToolCallEventStream {
|
|||
*self.cancellation_rx.clone().borrow()
|
||||
}
|
||||
|
||||
pub fn tool_use_id(&self) -> &LanguageModelToolUseId {
|
||||
&self.tool_use_id
|
||||
}
|
||||
|
||||
pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
|
||||
self.stream
|
||||
.update_tool_call_fields(&self.tool_use_id, fields);
|
||||
|
|
@ -2663,6 +2835,19 @@ impl ToolCallEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
pub fn update_subagent_thread(&self, thread: Entity<acp_thread::AcpThread>) {
|
||||
self.stream
|
||||
.0
|
||||
.unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
|
||||
acp_thread::ToolCallUpdateSubagentThread {
|
||||
id: acp::ToolCallId::new(self.tool_use_id.to_string()),
|
||||
thread,
|
||||
}
|
||||
.into(),
|
||||
)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
|
||||
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
|
||||
return Task::ready(Ok(()));
|
||||
|
|
|
|||
|
|
@ -1,11 +1,31 @@
|
|||
use acp_thread::{AcpThread, AgentConnection, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, SharedString, Task};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashSet;
|
||||
use futures::channel::mpsc;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::any::Any;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use util::ResultExt;
|
||||
use watch;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
use crate::{
|
||||
AgentTool, ContextServerRegistry, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH, SubagentContext,
|
||||
Templates, Thread, ThreadEvent, ToolCallAuthorization, ToolCallEventStream,
|
||||
};
|
||||
|
||||
/// When a subagent's remaining context window falls below this fraction (25%),
|
||||
/// the "context running out" prompt is sent to encourage the subagent to wrap up.
|
||||
const CONTEXT_LOW_THRESHOLD: f32 = 0.25;
|
||||
|
||||
/// Spawns a subagent with its own context window to perform a delegated task.
|
||||
///
|
||||
|
|
@ -63,11 +83,50 @@ pub struct SubagentToolInput {
|
|||
pub allowed_tools: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
pub struct SubagentTool;
|
||||
pub struct SubagentTool {
|
||||
parent_thread: WeakEntity<Thread>,
|
||||
project: Entity<Project>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
templates: Arc<Templates>,
|
||||
current_depth: u8,
|
||||
parent_tool_names: HashSet<SharedString>,
|
||||
}
|
||||
|
||||
impl SubagentTool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
pub fn new(
|
||||
parent_thread: WeakEntity<Thread>,
|
||||
project: Entity<Project>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
templates: Arc<Templates>,
|
||||
current_depth: u8,
|
||||
parent_tool_names: Vec<SharedString>,
|
||||
) -> Self {
|
||||
Self {
|
||||
parent_thread,
|
||||
project,
|
||||
project_context,
|
||||
context_server_registry,
|
||||
templates,
|
||||
current_depth,
|
||||
parent_tool_names: parent_tool_names.into_iter().collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_allowed_tools(&self, allowed_tools: &Option<Vec<String>>) -> Result<()> {
|
||||
if let Some(tools) = allowed_tools {
|
||||
for tool in tools {
|
||||
if !self.parent_tool_names.contains(tool.as_str()) {
|
||||
return Err(anyhow!(
|
||||
"Tool '{}' is not available to the parent agent. Available tools: {:?}",
|
||||
tool,
|
||||
self.parent_tool_names.iter().collect::<Vec<_>>()
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -76,7 +135,7 @@ impl AgentTool for SubagentTool {
|
|||
type Output = String;
|
||||
|
||||
fn name() -> &'static str {
|
||||
"subagent"
|
||||
acp_thread::SUBAGENT_TOOL_NAME
|
||||
}
|
||||
|
||||
fn kind() -> acp::ToolKind {
|
||||
|
|
@ -88,22 +147,405 @@ impl AgentTool for SubagentTool {
|
|||
input: Result<Self::Input, serde_json::Value>,
|
||||
_cx: &mut App,
|
||||
) -> SharedString {
|
||||
match input {
|
||||
Ok(input) => format!("Subagent: {}", input.label).into(),
|
||||
Err(_) => "Subagent".into(),
|
||||
}
|
||||
input
|
||||
.map(|i| i.label.into())
|
||||
.unwrap_or_else(|_| "Subagent".into())
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
event_stream.update_fields(
|
||||
acp::ToolCallUpdateFields::new()
|
||||
.content(vec![format!("Starting subagent: {}", input.label).into()]),
|
||||
);
|
||||
Task::ready(Ok("Subagent tool not yet implemented.".to_string()))
|
||||
if self.current_depth >= MAX_SUBAGENT_DEPTH {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Maximum subagent depth ({}) reached",
|
||||
MAX_SUBAGENT_DEPTH
|
||||
)));
|
||||
}
|
||||
|
||||
if let Err(e) = self.validate_allowed_tools(&input.allowed_tools) {
|
||||
return Task::ready(Err(e));
|
||||
}
|
||||
|
||||
let Some(parent_thread) = self.parent_thread.upgrade() else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Parent thread no longer exists (subagent depth={})",
|
||||
self.current_depth + 1
|
||||
)));
|
||||
};
|
||||
|
||||
let running_count = parent_thread.read(cx).running_subagent_count();
|
||||
if running_count >= MAX_PARALLEL_SUBAGENTS {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
|
||||
MAX_PARALLEL_SUBAGENTS
|
||||
)));
|
||||
}
|
||||
|
||||
let parent_thread_id = parent_thread.read(cx).id().clone();
|
||||
let parent_model = parent_thread.read(cx).model().cloned();
|
||||
let tool_use_id = event_stream.tool_use_id().clone();
|
||||
|
||||
let Some(model) = parent_model else {
|
||||
return Task::ready(Err(anyhow!("No model configured")));
|
||||
};
|
||||
|
||||
let subagent_context = SubagentContext {
|
||||
parent_thread_id,
|
||||
tool_use_id,
|
||||
depth: self.current_depth + 1,
|
||||
summary_prompt: input.summary_prompt.clone(),
|
||||
context_low_prompt: input.context_low_prompt.clone(),
|
||||
};
|
||||
|
||||
let project = self.project.clone();
|
||||
let project_context = self.project_context.clone();
|
||||
let context_server_registry = self.context_server_registry.clone();
|
||||
let templates = self.templates.clone();
|
||||
let task_prompt = input.task_prompt;
|
||||
let timeout_ms = input.timeout_ms;
|
||||
let allowed_tools: Option<HashSet<SharedString>> = input
|
||||
.allowed_tools
|
||||
.map(|tools| tools.into_iter().map(SharedString::from).collect());
|
||||
|
||||
let parent_thread = self.parent_thread.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let subagent_thread: Entity<Thread> = cx.new(|cx| {
|
||||
Thread::new_subagent(
|
||||
project.clone(),
|
||||
project_context.clone(),
|
||||
context_server_registry.clone(),
|
||||
templates.clone(),
|
||||
model,
|
||||
subagent_context,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let subagent_weak = subagent_thread.downgrade();
|
||||
|
||||
let acp_thread: Entity<AcpThread> = cx.new(|cx| {
|
||||
let session_id = subagent_thread.read(cx).id().clone();
|
||||
let action_log: Entity<ActionLog> = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let connection: Rc<dyn AgentConnection> = Rc::new(SubagentDisplayConnection);
|
||||
AcpThread::new(
|
||||
"Subagent",
|
||||
connection,
|
||||
project.clone(),
|
||||
action_log,
|
||||
session_id,
|
||||
watch::Receiver::constant(acp::PromptCapabilities::new()),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
event_stream.update_subagent_thread(acp_thread.clone());
|
||||
|
||||
if let Some(parent) = parent_thread.upgrade() {
|
||||
parent.update(cx, |thread, _cx| {
|
||||
thread.register_running_subagent(subagent_weak.clone());
|
||||
});
|
||||
}
|
||||
|
||||
let result = run_subagent(
|
||||
&subagent_thread,
|
||||
&acp_thread,
|
||||
allowed_tools,
|
||||
task_prompt,
|
||||
timeout_ms,
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(parent) = parent_thread.upgrade() {
|
||||
let _ = parent.update(cx, |thread, _cx| {
|
||||
thread.unregister_running_subagent(&subagent_weak);
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_subagent(
|
||||
subagent_thread: &Entity<Thread>,
|
||||
acp_thread: &Entity<AcpThread>,
|
||||
allowed_tools: Option<HashSet<SharedString>>,
|
||||
task_prompt: String,
|
||||
timeout_ms: Option<u64>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<String> {
|
||||
if let Some(ref allowed) = allowed_tools {
|
||||
subagent_thread.update(cx, |thread, _cx| {
|
||||
thread.restrict_tools(allowed);
|
||||
});
|
||||
}
|
||||
|
||||
let mut events_rx =
|
||||
subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?;
|
||||
|
||||
let acp_thread_weak = acp_thread.downgrade();
|
||||
|
||||
let timed_out = if let Some(timeout) = timeout_ms {
|
||||
forward_events_with_timeout(
|
||||
&mut events_rx,
|
||||
&acp_thread_weak,
|
||||
Duration::from_millis(timeout),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await;
|
||||
false
|
||||
};
|
||||
|
||||
let should_interrupt =
|
||||
timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx);
|
||||
|
||||
if should_interrupt {
|
||||
let mut summary_rx =
|
||||
subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?;
|
||||
forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
|
||||
} else {
|
||||
let mut summary_rx =
|
||||
subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?;
|
||||
forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
|
||||
}
|
||||
|
||||
Ok(extract_last_message(subagent_thread, cx))
|
||||
}
|
||||
|
||||
async fn forward_events_until_stop(
|
||||
events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
||||
acp_thread: &WeakEntity<AcpThread>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
while let Some(event) = events_rx.next().await {
|
||||
match event {
|
||||
Ok(ThreadEvent::Stop(_)) => break,
|
||||
Ok(event) => {
|
||||
forward_event_to_acp_thread(event, acp_thread, cx);
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn forward_events_with_timeout(
|
||||
events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
||||
acp_thread: &WeakEntity<AcpThread>,
|
||||
timeout: Duration,
|
||||
cx: &mut AsyncApp,
|
||||
) -> bool {
|
||||
use futures::future::{self, Either};
|
||||
|
||||
let deadline = std::time::Instant::now() + timeout;
|
||||
|
||||
loop {
|
||||
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
|
||||
if remaining.is_zero() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let timeout_future = cx.background_executor().timer(remaining);
|
||||
let event_future = events_rx.next();
|
||||
|
||||
match future::select(event_future, timeout_future).await {
|
||||
Either::Left((event, _)) => match event {
|
||||
Some(Ok(ThreadEvent::Stop(_))) => return false,
|
||||
Some(Ok(event)) => {
|
||||
forward_event_to_acp_thread(event, acp_thread, cx);
|
||||
}
|
||||
Some(Err(_)) => return false,
|
||||
None => return false,
|
||||
},
|
||||
Either::Right((_, _)) => return true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_event_to_acp_thread(
|
||||
event: ThreadEvent,
|
||||
acp_thread: &WeakEntity<AcpThread>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
match event {
|
||||
ThreadEvent::UserMessage(message) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| {
|
||||
for content in message.content {
|
||||
thread.push_user_content_block(
|
||||
Some(message.id.clone()),
|
||||
content.into(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::AgentText(text) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(text.into(), false, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::AgentThinking(text) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(text.into(), true, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
options,
|
||||
response,
|
||||
}) => {
|
||||
let outcome_task = acp_thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, options, true, cx)
|
||||
});
|
||||
if let Ok(Ok(task)) = outcome_task {
|
||||
cx.background_spawn(async move {
|
||||
if let acp::RequestPermissionOutcome::Selected(
|
||||
acp::SelectedPermissionOutcome { option_id, .. },
|
||||
) = task.await
|
||||
{
|
||||
response.send(option_id).ok();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
ThreadEvent::ToolCall(tool_call) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx))
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::ToolCallUpdate(update) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| thread.update_tool_call(update, cx))
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::Retry(status) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| thread.update_retry_status(status, cx))
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::Stop(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_context_low(thread: &Entity<Thread>, threshold: f32, cx: &mut AsyncApp) -> bool {
|
||||
thread.read_with(cx, |thread, _| {
|
||||
if let Some(usage) = thread.latest_token_usage() {
|
||||
let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
|
||||
remaining_ratio <= threshold
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_last_message(thread: &Entity<Thread>, cx: &mut AsyncApp) -> String {
|
||||
thread.read_with(cx, |thread, _| {
|
||||
thread
|
||||
.last_message()
|
||||
.map(|m| m.to_markdown())
|
||||
.unwrap_or_else(|| "No response from subagent".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
|
||||
#[test]
|
||||
fn test_subagent_tool_input_json_schema_is_valid() {
|
||||
let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema);
|
||||
let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON");
|
||||
|
||||
assert!(
|
||||
schema_json.get("properties").is_some(),
|
||||
"schema should have properties"
|
||||
);
|
||||
let properties = schema_json.get("properties").unwrap();
|
||||
|
||||
assert!(properties.get("label").is_some(), "should have label field");
|
||||
assert!(
|
||||
properties.get("task_prompt").is_some(),
|
||||
"should have task_prompt field"
|
||||
);
|
||||
assert!(
|
||||
properties.get("summary_prompt").is_some(),
|
||||
"should have summary_prompt field"
|
||||
);
|
||||
assert!(
|
||||
properties.get("context_low_prompt").is_some(),
|
||||
"should have context_low_prompt field"
|
||||
);
|
||||
assert!(
|
||||
properties.get("timeout_ms").is_some(),
|
||||
"should have timeout_ms field"
|
||||
);
|
||||
assert!(
|
||||
properties.get("allowed_tools").is_some(),
|
||||
"should have allowed_tools field"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subagent_tool_name() {
|
||||
assert_eq!(SubagentTool::name(), "subagent");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subagent_tool_kind() {
|
||||
assert_eq!(SubagentTool::kind(), acp::ToolKind::Other);
|
||||
}
|
||||
}
|
||||
|
||||
struct SubagentDisplayConnection;
|
||||
|
||||
impl AgentConnection for SubagentDisplayConnection {
|
||||
fn telemetry_id(&self) -> SharedString {
|
||||
"subagent".into()
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
_project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
unimplemented!("SubagentDisplayConnection does not support new_thread")
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
unimplemented!("SubagentDisplayConnection does not support authenticate")
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<UserMessageId>,
|
||||
_params: acp::PromptRequest,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
unimplemented!("SubagentDisplayConnection does not support prompt")
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3309,6 +3309,12 @@ impl AcpThreadView {
|
|||
ToolCallContent::Terminal(terminal) => {
|
||||
self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx)
|
||||
}
|
||||
ToolCallContent::SubagentThread(_thread) => {
|
||||
// The subagent's AcpThread entity stores the subagent's conversation
|
||||
// (messages, tool calls, etc.) but we don't render it here. The entity
|
||||
// is used for serialization (e.g., to_markdown) and data storage, not display.
|
||||
Empty.into_any_element()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ impl ExampleContext {
|
|||
ThreadEvent::ToolCall(tool_call) => {
|
||||
let meta = tool_call.meta.expect("Missing meta field in tool_call");
|
||||
let tool_name = meta
|
||||
.get("tool_name")
|
||||
.get(acp_thread::TOOL_NAME_META_KEY)
|
||||
.expect("Missing tool_name field in meta")
|
||||
.as_str()
|
||||
.expect("Unknown tool_name content in meta");
|
||||
|
|
|
|||
Loading…
Reference in a new issue