mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - agent: Improve reliability when LLM edits file
7443 lines
243 KiB
Rust
7443 lines
243 KiB
Rust
use super::*;
|
|
use acp_thread::{
|
|
AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
|
|
UserMessageId,
|
|
};
|
|
use agent_client_protocol::schema as acp;
|
|
use agent_settings::AgentProfileId;
|
|
use anyhow::Result;
|
|
use client::{Client, RefreshLlmTokenListener, UserStore};
|
|
use collections::IndexMap;
|
|
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
|
use feature_flags::FeatureFlagAppExt as _;
|
|
use fs::{FakeFs, Fs};
|
|
use futures::{
|
|
FutureExt as _, StreamExt,
|
|
channel::{
|
|
mpsc::{self, UnboundedReceiver},
|
|
oneshot,
|
|
},
|
|
future::{Fuse, Shared},
|
|
};
|
|
use gpui::{
|
|
App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
|
|
http_client::FakeHttpClient,
|
|
};
|
|
use indoc::indoc;
|
|
use language_model::{
|
|
CompletionIntent, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
|
LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
|
|
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
|
LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
|
|
fake_provider::FakeLanguageModel,
|
|
};
|
|
use pretty_assertions::assert_eq;
|
|
use project::{
|
|
Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
|
|
};
|
|
use prompt_store::ProjectContext;
|
|
use reqwest_client::ReqwestClient;
|
|
use schemars::JsonSchema;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::json;
|
|
use settings::{Settings, SettingsStore};
|
|
use std::{
|
|
path::Path,
|
|
pin::Pin,
|
|
rc::Rc,
|
|
sync::{
|
|
Arc,
|
|
atomic::{AtomicBool, AtomicUsize, Ordering},
|
|
},
|
|
time::Duration,
|
|
};
|
|
use util::path;
|
|
|
|
mod test_tools;
|
|
use test_tools::*;
|
|
|
|
pub(crate) fn init_test(cx: &mut TestAppContext) {
|
|
cx.update(|cx| {
|
|
let settings_store = SettingsStore::test(cx);
|
|
cx.set_global(settings_store);
|
|
});
|
|
}
|
|
|
|
pub(crate) struct FakeTerminalHandle {
|
|
killed: Arc<AtomicBool>,
|
|
stopped_by_user: Arc<AtomicBool>,
|
|
exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
|
|
wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
|
|
output: acp::TerminalOutputResponse,
|
|
id: acp::TerminalId,
|
|
}
|
|
|
|
impl FakeTerminalHandle {
|
|
pub(crate) fn new_never_exits(cx: &mut App) -> Self {
|
|
let killed = Arc::new(AtomicBool::new(false));
|
|
let stopped_by_user = Arc::new(AtomicBool::new(false));
|
|
|
|
let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
|
|
|
|
let wait_for_exit = cx
|
|
.spawn(async move |_cx| {
|
|
// Wait for the exit signal (sent when kill() is called)
|
|
let _ = exit_receiver.await;
|
|
acp::TerminalExitStatus::new()
|
|
})
|
|
.shared();
|
|
|
|
Self {
|
|
killed,
|
|
stopped_by_user,
|
|
exit_sender: std::cell::RefCell::new(Some(exit_sender)),
|
|
wait_for_exit,
|
|
output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
|
|
id: acp::TerminalId::new("fake_terminal".to_string()),
|
|
}
|
|
}
|
|
|
|
pub(crate) fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
|
|
let killed = Arc::new(AtomicBool::new(false));
|
|
let stopped_by_user = Arc::new(AtomicBool::new(false));
|
|
let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
|
|
|
|
let wait_for_exit = cx
|
|
.spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
|
|
.shared();
|
|
|
|
Self {
|
|
killed,
|
|
stopped_by_user,
|
|
exit_sender: std::cell::RefCell::new(Some(exit_sender)),
|
|
wait_for_exit,
|
|
output: acp::TerminalOutputResponse::new("command output".to_string(), false),
|
|
id: acp::TerminalId::new("fake_terminal".to_string()),
|
|
}
|
|
}
|
|
|
|
pub(crate) fn was_killed(&self) -> bool {
|
|
self.killed.load(Ordering::SeqCst)
|
|
}
|
|
|
|
pub(crate) fn set_stopped_by_user(&self, stopped: bool) {
|
|
self.stopped_by_user.store(stopped, Ordering::SeqCst);
|
|
}
|
|
|
|
pub(crate) fn signal_exit(&self) {
|
|
if let Some(sender) = self.exit_sender.borrow_mut().take() {
|
|
let _ = sender.send(());
|
|
}
|
|
}
|
|
}
|
|
|
|
impl crate::TerminalHandle for FakeTerminalHandle {
|
|
fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
|
|
Ok(self.id.clone())
|
|
}
|
|
|
|
fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
|
|
Ok(self.output.clone())
|
|
}
|
|
|
|
fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
|
|
Ok(self.wait_for_exit.clone())
|
|
}
|
|
|
|
fn kill(&self, _cx: &AsyncApp) -> Result<()> {
|
|
self.killed.store(true, Ordering::SeqCst);
|
|
self.signal_exit();
|
|
Ok(())
|
|
}
|
|
|
|
fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
|
|
Ok(self.stopped_by_user.load(Ordering::SeqCst))
|
|
}
|
|
}
|
|
|
|
struct FakeSubagentHandle {
|
|
session_id: acp::SessionId,
|
|
send_task: Shared<Task<String>>,
|
|
}
|
|
|
|
impl SubagentHandle for FakeSubagentHandle {
|
|
fn id(&self) -> acp::SessionId {
|
|
self.session_id.clone()
|
|
}
|
|
|
|
fn num_entries(&self, _cx: &App) -> usize {
|
|
unimplemented!()
|
|
}
|
|
|
|
fn send(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
|
|
let task = self.send_task.clone();
|
|
cx.background_spawn(async move { Ok(task.await) })
|
|
}
|
|
}
|
|
|
|
#[derive(Default)]
|
|
pub(crate) struct FakeThreadEnvironment {
|
|
terminal_handle: Option<Rc<FakeTerminalHandle>>,
|
|
subagent_handle: Option<Rc<FakeSubagentHandle>>,
|
|
terminal_creations: Arc<AtomicUsize>,
|
|
}
|
|
|
|
impl FakeThreadEnvironment {
|
|
pub(crate) fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
|
|
Self {
|
|
terminal_handle: Some(terminal_handle.into()),
|
|
..self
|
|
}
|
|
}
|
|
|
|
pub(crate) fn terminal_creation_count(&self) -> usize {
|
|
self.terminal_creations.load(Ordering::SeqCst)
|
|
}
|
|
}
|
|
|
|
impl crate::ThreadEnvironment for FakeThreadEnvironment {
|
|
fn create_terminal(
|
|
&self,
|
|
_command: String,
|
|
_cwd: Option<std::path::PathBuf>,
|
|
_output_byte_limit: Option<u64>,
|
|
_cx: &mut AsyncApp,
|
|
) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
|
|
self.terminal_creations.fetch_add(1, Ordering::SeqCst);
|
|
let handle = self
|
|
.terminal_handle
|
|
.clone()
|
|
.expect("Terminal handle not available on FakeThreadEnvironment");
|
|
Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
|
|
}
|
|
|
|
fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
|
|
Ok(self
|
|
.subagent_handle
|
|
.clone()
|
|
.expect("Subagent handle not available on FakeThreadEnvironment")
|
|
as Rc<dyn SubagentHandle>)
|
|
}
|
|
}
|
|
|
|
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
|
|
struct MultiTerminalEnvironment {
|
|
handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
|
|
}
|
|
|
|
impl MultiTerminalEnvironment {
|
|
fn new() -> Self {
|
|
Self {
|
|
handles: std::cell::RefCell::new(Vec::new()),
|
|
}
|
|
}
|
|
|
|
fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
|
|
self.handles.borrow().clone()
|
|
}
|
|
}
|
|
|
|
impl crate::ThreadEnvironment for MultiTerminalEnvironment {
|
|
fn create_terminal(
|
|
&self,
|
|
_command: String,
|
|
_cwd: Option<std::path::PathBuf>,
|
|
_output_byte_limit: Option<u64>,
|
|
cx: &mut AsyncApp,
|
|
) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
|
|
let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
|
|
self.handles.borrow_mut().push(handle.clone());
|
|
Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
|
|
}
|
|
|
|
fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
|
|
unimplemented!()
|
|
}
|
|
}
|
|
|
|
fn always_allow_tools(cx: &mut TestAppContext) {
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_echo(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Hello");
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let events = events.collect().await;
|
|
thread.update(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.last_received_or_pending_message().unwrap().role(),
|
|
Role::Assistant
|
|
);
|
|
assert_eq!(
|
|
thread
|
|
.last_received_or_pending_message()
|
|
.unwrap()
|
|
.to_markdown(),
|
|
"Hello\n"
|
|
)
|
|
});
|
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
always_allow_tools(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
let project = Project::test(fs, [], cx).await;
|
|
|
|
let environment = Rc::new(cx.update(|cx| {
|
|
FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
|
|
}));
|
|
let handle = environment.terminal_handle.clone().unwrap();
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::TerminalTool::new(project, environment));
|
|
let (event_stream, mut rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::TerminalToolInput {
|
|
command: "sleep 1000".to_string(),
|
|
cd: ".".to_string(),
|
|
timeout_ms: Some(5),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let update = rx.expect_update_fields().await;
|
|
assert!(
|
|
update.content.iter().any(|blocks| {
|
|
blocks
|
|
.iter()
|
|
.any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
|
|
}),
|
|
"expected tool call update to include terminal content"
|
|
);
|
|
|
|
let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());
|
|
|
|
let deadline = std::time::Instant::now() + Duration::from_millis(500);
|
|
loop {
|
|
if let Some(result) = task_future.as_mut().now_or_never() {
|
|
let result = result.expect("terminal tool task should complete");
|
|
|
|
assert!(
|
|
handle.was_killed(),
|
|
"expected terminal handle to be killed on timeout"
|
|
);
|
|
assert!(
|
|
result.contains("partial output"),
|
|
"expected result to include terminal output, got: {result}"
|
|
);
|
|
return;
|
|
}
|
|
|
|
if std::time::Instant::now() >= deadline {
|
|
panic!("timed out waiting for terminal tool task to complete");
|
|
}
|
|
|
|
cx.run_until_parked();
|
|
cx.background_executor.timer(Duration::from_millis(1)).await;
|
|
}
|
|
}
|
|
|
|
#[gpui::test]
|
|
#[ignore]
|
|
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
always_allow_tools(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
let project = Project::test(fs, [], cx).await;
|
|
|
|
let environment = Rc::new(cx.update(|cx| {
|
|
FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
|
|
}));
|
|
let handle = environment.terminal_handle.clone().unwrap();
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::TerminalTool::new(project, environment));
|
|
let (event_stream, mut rx) = crate::ToolCallEventStream::test();
|
|
|
|
let _task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::TerminalToolInput {
|
|
command: "sleep 1000".to_string(),
|
|
cd: ".".to_string(),
|
|
timeout_ms: None,
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let update = rx.expect_update_fields().await;
|
|
assert!(
|
|
update.content.iter().any(|blocks| {
|
|
blocks
|
|
.iter()
|
|
.any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
|
|
}),
|
|
"expected tool call update to include terminal content"
|
|
);
|
|
|
|
cx.background_executor
|
|
.timer(Duration::from_millis(25))
|
|
.await;
|
|
|
|
assert!(
|
|
!handle.was_killed(),
|
|
"did not expect terminal handle to be killed without a timeout"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_thinking(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
[indoc! {"
|
|
Testing:
|
|
|
|
Generate a thinking step where you just think the word 'Think',
|
|
and have your final answer be 'Hello'
|
|
"}],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
|
|
text: "Think".to_string(),
|
|
signature: None,
|
|
});
|
|
fake_model.send_last_completion_stream_text_chunk("Hello");
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let events = events.collect().await;
|
|
thread.update(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.last_received_or_pending_message().unwrap().role(),
|
|
Role::Assistant
|
|
);
|
|
assert_eq!(
|
|
thread
|
|
.last_received_or_pending_message()
|
|
.unwrap()
|
|
.to_markdown(),
|
|
indoc! {"
|
|
<think>Think</think>
|
|
Hello
|
|
"}
|
|
)
|
|
});
|
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_system_prompt(cx: &mut TestAppContext) {
|
|
let ThreadTest {
|
|
model,
|
|
thread,
|
|
project_context,
|
|
..
|
|
} = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
project_context.update(cx, |project_context, _cx| {
|
|
project_context.shell = "test-shell".into()
|
|
});
|
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
let mut pending_completions = fake_model.pending_completions();
|
|
assert_eq!(
|
|
pending_completions.len(),
|
|
1,
|
|
"unexpected pending completions: {:?}",
|
|
pending_completions
|
|
);
|
|
|
|
let pending_completion = pending_completions.pop().unwrap();
|
|
assert_eq!(pending_completion.messages[0].role, Role::System);
|
|
|
|
let system_message = &pending_completion.messages[0];
|
|
let MessageContent::Text(system_prompt) = &system_message.content[0] else {
|
|
panic!("Expected text content");
|
|
};
|
|
assert!(
|
|
system_prompt.contains("test-shell"),
|
|
"unexpected system message: {:?}",
|
|
system_message
|
|
);
|
|
assert!(
|
|
system_prompt.contains("## Fixing Diagnostics"),
|
|
"unexpected system message: {:?}",
|
|
system_message
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
let mut pending_completions = fake_model.pending_completions();
|
|
assert_eq!(
|
|
pending_completions.len(),
|
|
1,
|
|
"unexpected pending completions: {:?}",
|
|
pending_completions
|
|
);
|
|
|
|
let pending_completion = pending_completions.pop().unwrap();
|
|
assert_eq!(pending_completion.messages[0].role, Role::System);
|
|
|
|
let system_message = &pending_completion.messages[0];
|
|
let MessageContent::Text(system_prompt) = &system_message.content[0] else {
|
|
panic!("Expected text content");
|
|
};
|
|
assert!(
|
|
!system_prompt.contains("## Tool Use"),
|
|
"unexpected system message: {:?}",
|
|
system_message
|
|
);
|
|
assert!(
|
|
!system_prompt.contains("## Fixing Diagnostics"),
|
|
"unexpected system message: {:?}",
|
|
system_message
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_prompt_caching(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
// Send initial user message and verify it's cached
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(
|
|
completion.messages[1..],
|
|
vec![LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Message 1".into()],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
}]
|
|
);
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
|
|
"Response to Message 1".into(),
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
// Send another user message and verify only the latest is cached
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Message 2"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(
|
|
completion.messages[1..],
|
|
vec![
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Message 1".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec!["Response to Message 1".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Message 2".into()],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
}
|
|
]
|
|
);
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
|
|
"Response to Message 2".into(),
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
// Simulate a tool call and verify that the latest tool result is cached
|
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let tool_use = LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: EchoTool::NAME.into(),
|
|
raw_input: json!({"text": "test"}).to_string(),
|
|
input: json!({"text": "test"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let tool_result = LanguageModelToolResult {
|
|
tool_use_id: "tool_1".into(),
|
|
tool_name: EchoTool::NAME.into(),
|
|
is_error: false,
|
|
content: vec!["test".into()],
|
|
output: Some("test".into()),
|
|
};
|
|
assert_eq!(
|
|
completion.messages[1..],
|
|
vec![
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Message 1".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec!["Response to Message 1".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Message 2".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec!["Response to Message 2".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Use the echo tool".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec![MessageContent::ToolUse(tool_use)],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec![MessageContent::ToolResult(tool_result)],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
}
|
|
]
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
#[cfg_attr(not(feature = "e2e"), ignore)]
|
|
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
|
|
|
// Test a tool call that's likely to complete *before* streaming stops.
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(EchoTool);
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap()
|
|
.collect()
|
|
.await;
|
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
|
|
|
// Test a tool calls that's likely to complete *after* streaming stops.
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.remove_tool(&EchoTool::NAME);
|
|
thread.add_tool(DelayTool);
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
[
|
|
"Now call the delay tool with 200ms.",
|
|
"When the timer goes off, then you echo the output of the tool.",
|
|
],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap()
|
|
.collect()
|
|
.await;
|
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
|
thread.update(cx, |thread, _cx| {
|
|
assert!(
|
|
thread
|
|
.last_received_or_pending_message()
|
|
.unwrap()
|
|
.as_agent_message()
|
|
.unwrap()
|
|
.content
|
|
.iter()
|
|
.any(|content| {
|
|
if let AgentMessageContent::Text(text) = content {
|
|
text.contains("Ding")
|
|
} else {
|
|
false
|
|
}
|
|
}),
|
|
"{}",
|
|
thread.to_markdown()
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
#[cfg_attr(not(feature = "e2e"), ignore)]
|
|
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
|
|
|
// Test a tool call that's likely to complete *before* streaming stops.
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(WordListTool);
|
|
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
|
})
|
|
.unwrap();
|
|
|
|
let mut saw_partial_tool_use = false;
|
|
while let Some(event) = events.next().await {
|
|
if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
|
|
thread.update(cx, |thread, _cx| {
|
|
// Look for a tool use in the thread's last message
|
|
let message = thread.last_received_or_pending_message().unwrap();
|
|
let agent_message = message.as_agent_message().unwrap();
|
|
let last_content = agent_message.content.last().unwrap();
|
|
if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
|
|
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
|
if tool_call.status == acp::ToolCallStatus::Pending {
|
|
if !last_tool_use.is_input_complete
|
|
&& last_tool_use.input.get("g").is_none()
|
|
{
|
|
saw_partial_tool_use = true;
|
|
}
|
|
} else {
|
|
last_tool_use
|
|
.input
|
|
.get("a")
|
|
.expect("'a' has streamed because input is now complete");
|
|
last_tool_use
|
|
.input
|
|
.get("g")
|
|
.expect("'g' has streamed because input is now complete");
|
|
}
|
|
} else {
|
|
panic!("last content should be a tool use");
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
assert!(
|
|
saw_partial_tool_use,
|
|
"should see at least one partially streamed tool use in the history"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_tool_authorization(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_1".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_2".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
|
|
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
|
|
|
|
// Approve the first - send "allow" option_id (UI transforms "once" to "allow")
|
|
tool_call_auth_1
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("allow"),
|
|
acp::PermissionOptionKind::AllowOnce,
|
|
))
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Reject the second - send "deny" option_id directly since Deny is now a button
|
|
tool_call_auth_2
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("deny"),
|
|
acp::PermissionOptionKind::RejectOnce,
|
|
))
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
assert_eq!(
|
|
message.content,
|
|
vec![
|
|
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
|
tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
|
|
tool_name: ToolRequiringPermission::NAME.into(),
|
|
is_error: false,
|
|
content: vec!["Allowed".into()],
|
|
output: Some("Allowed".into())
|
|
}),
|
|
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
|
tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
|
|
tool_name: ToolRequiringPermission::NAME.into(),
|
|
is_error: true,
|
|
content: vec!["Permission to run tool denied by user".into()],
|
|
output: Some("Permission to run tool denied by user".into())
|
|
})
|
|
]
|
|
);
|
|
|
|
// Simulate yet another tool call.
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_3".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Respond by always allowing tools - send transformed option_id
|
|
// (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
|
|
let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
|
|
tool_call_auth_3
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("always_allow:tool_requiring_permission"),
|
|
acp::PermissionOptionKind::AllowAlways,
|
|
))
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
assert_eq!(
|
|
message.content,
|
|
vec![language_model::MessageContent::ToolResult(
|
|
LanguageModelToolResult {
|
|
tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
|
|
tool_name: ToolRequiringPermission::NAME.into(),
|
|
is_error: false,
|
|
content: vec!["Allowed".into()],
|
|
output: Some("Allowed".into())
|
|
}
|
|
)]
|
|
);
|
|
|
|
// Simulate a final tool call, ensuring we don't trigger authorization.
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_4".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
assert_eq!(
|
|
message.content,
|
|
vec![language_model::MessageContent::ToolResult(
|
|
LanguageModelToolResult {
|
|
tool_use_id: "tool_id_4".into(),
|
|
tool_name: ToolRequiringPermission::NAME.into(),
|
|
is_error: false,
|
|
content: vec!["Allowed".into()],
|
|
output: Some("Allowed".into())
|
|
}
|
|
)]
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_1".into(),
|
|
name: "nonexistent_tool".into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let tool_call = expect_tool_call(&mut events).await;
|
|
assert_eq!(tool_call.title, "nonexistent_tool");
|
|
assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
|
|
let update = expect_tool_call_update_fields(&mut events).await;
|
|
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
|
}
|
|
|
|
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
|
|
let event = events
|
|
.next()
|
|
.await
|
|
.expect("no tool call authorization event received")
|
|
.unwrap();
|
|
match event {
|
|
ThreadEvent::ToolCall(tool_call) => tool_call,
|
|
event => {
|
|
panic!("Unexpected event {event:?}");
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn expect_tool_call_update_fields(
|
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
|
) -> acp::ToolCallUpdate {
|
|
let event = events
|
|
.next()
|
|
.await
|
|
.expect("no tool call authorization event received")
|
|
.unwrap();
|
|
match event {
|
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
|
|
event => {
|
|
panic!("Unexpected event {event:?}");
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn expect_plan(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::Plan {
|
|
let event = events
|
|
.next()
|
|
.await
|
|
.expect("no plan event received")
|
|
.unwrap();
|
|
match event {
|
|
ThreadEvent::Plan(plan) => plan,
|
|
event => {
|
|
panic!("Unexpected event {event:?}");
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn next_tool_call_authorization(
|
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
|
) -> ToolCallAuthorization {
|
|
loop {
|
|
let event = events
|
|
.next()
|
|
.await
|
|
.expect("no tool call authorization event received")
|
|
.unwrap();
|
|
if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
|
let permission_kinds = tool_call_authorization
|
|
.options
|
|
.first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
|
|
.map(|option| option.kind);
|
|
let allow_once = tool_call_authorization
|
|
.options
|
|
.first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
|
|
.map(|option| option.kind);
|
|
|
|
assert_eq!(
|
|
permission_kinds,
|
|
Some(acp::PermissionOptionKind::AllowAlways)
|
|
);
|
|
assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
|
|
return tool_call_authorization;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A terminal command with a recognizable subcommand yields a three-way dropdown:
/// always for the tool, always for the command pattern, and once.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in [
        "Always for terminal",
        "Always for `cargo build` commands",
        "Only this time",
    ] {
        assert!(labels.contains(&expected));
    }
}
|
|
|
|
/// When the second token is a flag, the pattern option covers only the bare command.
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in [
        "Always for terminal",
        "Always for `ls` commands",
        "Only this time",
    ] {
        assert!(labels.contains(&expected));
    }
}
|
|
|
|
/// A single-word command still gets its own "always for ... commands" option.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in [
        "Always for terminal",
        "Always for `whoami` commands",
        "Only this time",
    ] {
        assert!(labels.contains(&expected));
    }
}
|
|
|
|
/// Edit-file permissions offer an option scoped to the file's parent directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context =
        ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(labels.contains(&expected));
    }
}
|
|
|
|
/// Fetch permissions offer an option scoped to the URL's domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let context =
        ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(labels.contains(&expected));
    }
}
|
|
|
|
/// A command with no derivable pattern (a script path) yields only the tool-wide
/// and one-time options, with no "commands" option.
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 2);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(labels.contains(&expected));
    }
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
|
|
|
|
/// Authorizing a symlink target offers only flat one-time allow/deny options.
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    let context =
        ToolPermissionContext::symlink_target(EditFileTool::NAME, vec!["/outside/file.txt".into()]);
    let PermissionOptions::Flat(options) = context.build_permission_options() else {
        panic!("Expected flat permission options for symlink target authorization");
    };

    assert_eq!(options.len(), 2);
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
|
|
|
|
/// Checks the option IDs and sub-patterns attached to each terminal dropdown choice.
#[test]
fn test_permission_option_ids_for_terminal() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    // Expect 3 choices: always-tool, always-pattern, once
    assert_eq!(choices.len(), 3);
    let [always_tool, always_pattern, once] = choices.as_slice() else {
        unreachable!();
    };

    // Both "always" choices share the tool-level option IDs; only the
    // pattern choice narrows them with a sub-pattern.
    for choice in [always_tool, always_pattern] {
        assert_eq!(choice.allow.option_id.0.as_ref(), "always_allow:terminal");
        assert_eq!(choice.deny.option_id.0.as_ref(), "always_deny:terminal");
    }
    assert!(always_tool.sub_patterns.is_empty());
    assert_eq!(always_pattern.sub_patterns, vec!["^cargo\\s+build(\\s|$)"]);

    // The last choice is the one-time allow/deny.
    assert_eq!(once.allow.option_id.0.as_ref(), "allow");
    assert_eq!(once.deny.option_id.0.as_ref(), "deny");
    assert!(once.sub_patterns.is_empty());
}
|
|
|
|
/// A piped command yields a dropdown plus one selectable pattern per pipeline stage.
#[test]
fn test_permission_options_terminal_pipeline_produces_dropdown_with_patterns() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo test 2>&1 | tail".to_string()],
    );
    let PermissionOptions::DropdownWithPatterns {
        choices,
        patterns,
        tool_name,
    } = context.build_permission_options()
    else {
        panic!("Expected DropdownWithPatterns permission options for pipeline command");
    };

    assert_eq!(tool_name, TerminalTool::NAME);

    // Only the tool-wide and one-time choices remain in the dropdown itself.
    assert_eq!(choices.len(), 2);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(labels.contains(&expected));
    }

    // One pattern per command: "cargo test" and "tail".
    assert_eq!(patterns.len(), 2);
    let display_names = patterns
        .iter()
        .map(|command_pattern| command_pattern.display_name.as_str())
        .collect::<Vec<&str>>();
    for expected in ["cargo test", "tail"] {
        assert!(display_names.contains(&expected));
    }

    // The patterns themselves are regular expressions.
    let regexes = patterns
        .iter()
        .map(|command_pattern| command_pattern.pattern.as_str())
        .collect::<Vec<&str>>();
    for expected in ["^cargo\\s+test(\\s|$)", "^tail\\b"] {
        assert!(regexes.contains(&expected));
    }
}
|
|
|
|
/// Chained and piped commands each produce their own pattern; subcommands of the
/// same program ("npm install" vs "npm test") stay distinct.
#[test]
fn test_permission_options_terminal_pipeline_with_chaining() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["npm install && npm test | tail".to_string()],
    );
    let PermissionOptions::DropdownWithPatterns { patterns, .. } =
        context.build_permission_options()
    else {
        panic!("Expected DropdownWithPatterns for chained pipeline command");
    };

    assert_eq!(patterns.len(), 3);
    let display_names = patterns
        .iter()
        .map(|command_pattern| command_pattern.display_name.as_str())
        .collect::<Vec<&str>>();
    for expected in ["npm install", "npm test", "tail"] {
        assert!(display_names.contains(&expected));
    }
}
|
|
|
|
#[gpui::test]
|
|
#[cfg_attr(not(feature = "e2e"), ignore)]
|
|
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
|
|
|
// Test concurrent tool calls with different delay times
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(DelayTool);
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
[
|
|
"Call the delay tool twice in the same message.",
|
|
"Once with 100ms. Once with 300ms.",
|
|
"When both timers are complete, describe the outputs.",
|
|
],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap()
|
|
.collect()
|
|
.await;
|
|
|
|
let stop_reasons = stop_events(events);
|
|
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
let last_message = thread.last_received_or_pending_message().unwrap();
|
|
let agent_message = last_message.as_agent_message().unwrap();
|
|
let text = agent_message
|
|
.content
|
|
.iter()
|
|
.filter_map(|content| {
|
|
if let AgentMessageContent::Text(text) = content {
|
|
Some(text.as_str())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect::<String>();
|
|
|
|
assert!(text.contains("Ding"));
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_profiles(cx: &mut TestAppContext) {
|
|
let ThreadTest {
|
|
model, thread, fs, ..
|
|
} = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.add_tool(DelayTool);
|
|
thread.add_tool(EchoTool);
|
|
thread.add_tool(InfiniteTool);
|
|
});
|
|
|
|
// Override profiles and wait for settings to be loaded.
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"profiles": {
|
|
"test-1": {
|
|
"name": "Test Profile 1",
|
|
"tools": {
|
|
EchoTool::NAME: true,
|
|
DelayTool::NAME: true,
|
|
}
|
|
},
|
|
"test-2": {
|
|
"name": "Test Profile 2",
|
|
"tools": {
|
|
InfiniteTool::NAME: true,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
cx.run_until_parked();
|
|
|
|
// Test that test-1 profile (default) has echo and delay tools
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("test-1".into()), cx);
|
|
thread.send(UserMessageId::new(), ["test"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let mut pending_completions = fake_model.pending_completions();
|
|
assert_eq!(pending_completions.len(), 1);
|
|
let completion = pending_completions.pop().unwrap();
|
|
let tool_names: Vec<String> = completion
|
|
.tools
|
|
.iter()
|
|
.map(|tool| tool.name.clone())
|
|
.collect();
|
|
assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("test-2".into()), cx);
|
|
thread.send(UserMessageId::new(), ["test2"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
let mut pending_completions = fake_model.pending_completions();
|
|
assert_eq!(pending_completions.len(), 1);
|
|
let completion = pending_completions.pop().unwrap();
|
|
let tool_names: Vec<String> = completion
|
|
.tools
|
|
.iter()
|
|
.map(|tool| tool.name.clone())
|
|
.collect();
|
|
assert_eq!(tool_names, vec![InfiniteTool::NAME]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_mcp_tools(cx: &mut TestAppContext) {
|
|
let ThreadTest {
|
|
model,
|
|
thread,
|
|
context_server_store,
|
|
fs,
|
|
..
|
|
} = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
// Override profiles and wait for settings to be loaded.
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"tool_permissions": { "default": "allow" },
|
|
"profiles": {
|
|
"test": {
|
|
"name": "Test Profile",
|
|
"enable_all_context_servers": true,
|
|
"tools": {
|
|
EchoTool::NAME: true,
|
|
}
|
|
},
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
cx.run_until_parked();
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("test".into()), cx)
|
|
});
|
|
|
|
let mut mcp_tool_calls = setup_context_server(
|
|
"test_server",
|
|
vec![context_server::types::Tool {
|
|
name: "echo".into(),
|
|
title: None,
|
|
description: None,
|
|
input_schema: serde_json::to_value(EchoTool::input_schema(
|
|
LanguageModelToolSchemaFormat::JsonSchema,
|
|
))
|
|
.unwrap(),
|
|
output_schema: None,
|
|
annotations: None,
|
|
}],
|
|
&context_server_store,
|
|
cx,
|
|
);
|
|
|
|
let events = thread.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model calling the MCP tool.
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: "echo".into(),
|
|
raw_input: json!({"text": "test"}).to_string(),
|
|
input: json!({"text": "test"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
|
assert_eq!(tool_call_params.name, "echo");
|
|
assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
|
|
tool_call_response
|
|
.send(context_server::types::CallToolResponse {
|
|
content: vec![context_server::types::ToolResponseContent::Text {
|
|
text: "test".into(),
|
|
}],
|
|
is_error: None,
|
|
meta: None,
|
|
structured_content: None,
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
|
|
fake_model.send_last_completion_stream_text_chunk("Done!");
|
|
fake_model.end_last_completion_stream();
|
|
events.collect::<Vec<_>>().await;
|
|
|
|
// Send again after adding the echo tool, ensuring the name collision is resolved.
|
|
let events = thread.update(cx, |thread, cx| {
|
|
thread.add_tool(EchoTool);
|
|
thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
|
|
});
|
|
cx.run_until_parked();
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(
|
|
tool_names_for_completion(&completion),
|
|
vec!["echo", "test_server_echo"]
|
|
);
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_2".into(),
|
|
name: "test_server_echo".into(),
|
|
raw_input: json!({"text": "mcp"}).to_string(),
|
|
input: json!({"text": "mcp"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_3".into(),
|
|
name: "echo".into(),
|
|
raw_input: json!({"text": "native"}).to_string(),
|
|
input: json!({"text": "native"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
|
assert_eq!(tool_call_params.name, "echo");
|
|
assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
|
|
tool_call_response
|
|
.send(context_server::types::CallToolResponse {
|
|
content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
|
|
is_error: None,
|
|
meta: None,
|
|
structured_content: None,
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Ensure the tool results were inserted with the correct names.
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(
|
|
completion.messages.last().unwrap().content,
|
|
vec![
|
|
MessageContent::ToolResult(LanguageModelToolResult {
|
|
tool_use_id: "tool_3".into(),
|
|
tool_name: "echo".into(),
|
|
is_error: false,
|
|
content: vec!["native".into()],
|
|
output: Some("native".into()),
|
|
},),
|
|
MessageContent::ToolResult(LanguageModelToolResult {
|
|
tool_use_id: "tool_2".into(),
|
|
tool_name: "test_server_echo".into(),
|
|
is_error: false,
|
|
content: vec!["mcp".into()],
|
|
output: Some("mcp".into()),
|
|
},),
|
|
]
|
|
);
|
|
fake_model.end_last_completion_stream();
|
|
events.collect::<Vec<_>>().await;
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_mcp_tool_multi_content_response(cx: &mut TestAppContext) {
|
|
let ThreadTest {
|
|
model,
|
|
thread,
|
|
context_server_store,
|
|
fs,
|
|
..
|
|
} = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
fake_model.set_supports_images(true);
|
|
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"tool_permissions": { "default": "allow" },
|
|
"profiles": {
|
|
"test": {
|
|
"name": "Test Profile",
|
|
"enable_all_context_servers": true,
|
|
"tools": {}
|
|
},
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
cx.run_until_parked();
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("test".into()), cx)
|
|
});
|
|
|
|
let mut mcp_tool_calls = setup_context_server(
|
|
"screenshot_server",
|
|
vec![context_server::types::Tool {
|
|
name: "screenshot".into(),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
}],
|
|
&context_server_store,
|
|
cx,
|
|
);
|
|
|
|
let events = thread.update(cx, |thread, cx| {
|
|
thread
|
|
.send(UserMessageId::new(), ["Take a screenshot"], cx)
|
|
.unwrap()
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: "screenshot".into(),
|
|
raw_input: json!({}).to_string(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
let _ = completion;
|
|
|
|
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
|
assert_eq!(tool_call_params.name, "screenshot");
|
|
tool_call_response
|
|
.send(context_server::types::CallToolResponse {
|
|
content: vec![
|
|
context_server::types::ToolResponseContent::Text {
|
|
text: "Some text".into(),
|
|
},
|
|
context_server::types::ToolResponseContent::Image {
|
|
data: "aGVsbG8=".into(),
|
|
mime_type: "image/png".into(),
|
|
},
|
|
context_server::types::ToolResponseContent::Text {
|
|
text: "Some more text".into(),
|
|
},
|
|
],
|
|
is_error: None,
|
|
meta: None,
|
|
structured_content: None,
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Verify the tool result round-trips back to the model as a multi-part Vec.
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let tool_result = completion
|
|
.messages
|
|
.last()
|
|
.unwrap()
|
|
.content
|
|
.iter()
|
|
.find_map(|c| match c {
|
|
MessageContent::ToolResult(r) => Some(r.clone()),
|
|
_ => None,
|
|
})
|
|
.expect("expected a tool result");
|
|
assert_eq!(tool_result.tool_use_id, "tool_1".into());
|
|
assert_eq!(tool_result.content.len(), 2);
|
|
assert_eq!(
|
|
tool_result.content[0],
|
|
language_model::LanguageModelToolResultContent::Text(Arc::from("Some text"))
|
|
);
|
|
assert_eq!(
|
|
tool_result.content[1],
|
|
language_model::LanguageModelToolResultContent::Text(Arc::from("Some more text"))
|
|
);
|
|
fake_model.end_last_completion_stream();
|
|
events.collect::<Vec<_>>().await;
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
|
|
let ThreadTest {
|
|
model,
|
|
thread,
|
|
context_server_store,
|
|
fs,
|
|
..
|
|
} = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
// Setup settings to allow MCP tools
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"always_allow_tool_actions": true,
|
|
"profiles": {
|
|
"test": {
|
|
"name": "Test Profile",
|
|
"enable_all_context_servers": true,
|
|
"tools": {}
|
|
},
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
cx.run_until_parked();
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("test".into()), cx)
|
|
});
|
|
|
|
// Setup a context server with a tool
|
|
let mut mcp_tool_calls = setup_context_server(
|
|
"github_server",
|
|
vec![context_server::types::Tool {
|
|
name: "issue_read".into(),
|
|
title: None,
|
|
description: Some("Read a GitHub issue".into()),
|
|
input_schema: json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"issue_url": { "type": "string" }
|
|
}
|
|
}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
}],
|
|
&context_server_store,
|
|
cx,
|
|
);
|
|
|
|
// Send a message and have the model call the MCP tool
|
|
let events = thread.update(cx, |thread, cx| {
|
|
thread
|
|
.send(UserMessageId::new(), ["Read issue #47404"], cx)
|
|
.unwrap()
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// Verify the MCP tool is available to the model
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(
|
|
tool_names_for_completion(&completion),
|
|
vec!["issue_read"],
|
|
"MCP tool should be available"
|
|
);
|
|
|
|
// Simulate the model calling the MCP tool
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: "issue_read".into(),
|
|
raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
|
|
.to_string(),
|
|
input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
// The MCP server receives the tool call and responds with content
|
|
let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
|
|
let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
|
|
assert_eq!(tool_call_params.name, "issue_read");
|
|
tool_call_response
|
|
.send(context_server::types::CallToolResponse {
|
|
content: vec![context_server::types::ToolResponseContent::Text {
|
|
text: expected_tool_output.into(),
|
|
}],
|
|
is_error: None,
|
|
meta: None,
|
|
structured_content: None,
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// After tool completes, the model continues with a new completion request
|
|
// that includes the tool results. We need to respond to this.
|
|
let _completion = fake_model.pending_completions().pop().unwrap();
|
|
fake_model.send_last_completion_stream_text_chunk("I found the issue!");
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
events.collect::<Vec<_>>().await;
|
|
|
|
// Verify the tool result is stored in the thread by checking the markdown output.
|
|
// The tool result is in the first assistant message (not the last one, which is
|
|
// the model's response after the tool completed).
|
|
thread.update(cx, |thread, _cx| {
|
|
let markdown = thread.to_markdown();
|
|
assert!(
|
|
markdown.contains("**Tool Result**: issue_read"),
|
|
"Thread should contain tool result header"
|
|
);
|
|
assert!(
|
|
markdown.contains(expected_tool_output),
|
|
"Thread should contain tool output: {}",
|
|
expected_tool_output
|
|
);
|
|
});
|
|
|
|
// Simulate app restart: disconnect the MCP server.
|
|
// After restart, the MCP server won't be connected yet when the thread is replayed.
|
|
context_server_store.update(cx, |store, cx| {
|
|
let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// Replay the thread (this is what happens when loading a saved thread)
|
|
let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
|
|
|
|
let mut found_tool_call = None;
|
|
let mut found_tool_call_update_with_output = None;
|
|
|
|
while let Some(event) = replay_events.next().await {
|
|
let event = event.unwrap();
|
|
match &event {
|
|
ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
|
|
found_tool_call = Some(tc.clone());
|
|
}
|
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
|
|
if update.tool_call_id.to_string() == "tool_1" =>
|
|
{
|
|
if update.fields.raw_output.is_some() {
|
|
found_tool_call_update_with_output = Some(update.clone());
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
// The tool call should be found
|
|
assert!(
|
|
found_tool_call.is_some(),
|
|
"Tool call should be emitted during replay"
|
|
);
|
|
|
|
assert!(
|
|
found_tool_call_update_with_output.is_some(),
|
|
"ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
|
|
);
|
|
|
|
let update = found_tool_call_update_with_output.unwrap();
|
|
assert_eq!(
|
|
update.fields.raw_output,
|
|
Some(expected_tool_output.into()),
|
|
"raw_output should contain the saved tool result"
|
|
);
|
|
|
|
// Also verify the status is correct (completed, not failed)
|
|
assert_eq!(
|
|
update.fields.status,
|
|
Some(acp::ToolCallStatus::Completed),
|
|
"Tool call status should reflect the original completion status"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
|
|
let ThreadTest {
|
|
model,
|
|
thread,
|
|
context_server_store,
|
|
fs,
|
|
..
|
|
} = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
// Set up a profile with all tools enabled
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"profiles": {
|
|
"test": {
|
|
"name": "Test Profile",
|
|
"enable_all_context_servers": true,
|
|
"tools": {
|
|
EchoTool::NAME: true,
|
|
DelayTool::NAME: true,
|
|
WordListTool::NAME: true,
|
|
ToolRequiringPermission::NAME: true,
|
|
InfiniteTool::NAME: true,
|
|
}
|
|
},
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
cx.run_until_parked();
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("test".into()), cx);
|
|
thread.add_tool(EchoTool);
|
|
thread.add_tool(DelayTool);
|
|
thread.add_tool(WordListTool);
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.add_tool(InfiniteTool);
|
|
});
|
|
|
|
// Set up multiple context servers with some overlapping tool names
|
|
let _server1_calls = setup_context_server(
|
|
"xxx",
|
|
vec![
|
|
context_server::types::Tool {
|
|
name: "echo".into(), // Conflicts with native EchoTool
|
|
title: None,
|
|
description: None,
|
|
input_schema: serde_json::to_value(EchoTool::input_schema(
|
|
LanguageModelToolSchemaFormat::JsonSchema,
|
|
))
|
|
.unwrap(),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
context_server::types::Tool {
|
|
name: "unique_tool_1".into(),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
],
|
|
&context_server_store,
|
|
cx,
|
|
);
|
|
|
|
let _server2_calls = setup_context_server(
|
|
"yyy",
|
|
vec![
|
|
context_server::types::Tool {
|
|
name: "echo".into(), // Also conflicts with native EchoTool
|
|
title: None,
|
|
description: None,
|
|
input_schema: serde_json::to_value(EchoTool::input_schema(
|
|
LanguageModelToolSchemaFormat::JsonSchema,
|
|
))
|
|
.unwrap(),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
context_server::types::Tool {
|
|
name: "unique_tool_2".into(),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
context_server::types::Tool {
|
|
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
context_server::types::Tool {
|
|
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
],
|
|
&context_server_store,
|
|
cx,
|
|
);
|
|
let _server3_calls = setup_context_server(
|
|
"zzz",
|
|
vec![
|
|
context_server::types::Tool {
|
|
name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
context_server::types::Tool {
|
|
name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
context_server::types::Tool {
|
|
name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
|
|
title: None,
|
|
description: None,
|
|
input_schema: json!({"type": "object", "properties": {}}),
|
|
output_schema: None,
|
|
annotations: None,
|
|
},
|
|
],
|
|
&context_server_store,
|
|
cx,
|
|
);
|
|
|
|
// Server with spaces in name - tests snake_case conversion for API compatibility
|
|
let _server4_calls = setup_context_server(
|
|
"Azure DevOps",
|
|
vec![context_server::types::Tool {
|
|
name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
|
|
title: None,
|
|
description: None,
|
|
input_schema: serde_json::to_value(EchoTool::input_schema(
|
|
LanguageModelToolSchemaFormat::JsonSchema,
|
|
))
|
|
.unwrap(),
|
|
output_schema: None,
|
|
annotations: None,
|
|
}],
|
|
&context_server_store,
|
|
cx,
|
|
);
|
|
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Go"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(
|
|
tool_names_for_completion(&completion),
|
|
vec![
|
|
"azure_dev_ops_echo",
|
|
"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
|
|
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
|
|
"delay",
|
|
"echo",
|
|
"infinite",
|
|
"tool_requiring_permission",
|
|
"unique_tool_1",
|
|
"unique_tool_2",
|
|
"word_list",
|
|
"xxx_echo",
|
|
"y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
|
"yyy_echo",
|
|
"z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
|
]
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
#[cfg_attr(not(feature = "e2e"), ignore)]
|
|
async fn test_cancellation(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(InfiniteTool);
|
|
thread.add_tool(EchoTool);
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["Call the echo tool, then call the infinite tool, then explain their output"],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap();
|
|
|
|
// Wait until both tools are called.
|
|
let mut expected_tools = vec!["Echo", "Infinite Tool"];
|
|
let mut echo_id = None;
|
|
let mut echo_completed = false;
|
|
while let Some(event) = events.next().await {
|
|
match event.unwrap() {
|
|
ThreadEvent::ToolCall(tool_call) => {
|
|
assert_eq!(tool_call.title, expected_tools.remove(0));
|
|
if tool_call.title == "Echo" {
|
|
echo_id = Some(tool_call.tool_call_id);
|
|
}
|
|
}
|
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
|
acp::ToolCallUpdate {
|
|
tool_call_id,
|
|
fields:
|
|
acp::ToolCallUpdateFields {
|
|
status: Some(acp::ToolCallStatus::Completed),
|
|
..
|
|
},
|
|
..
|
|
},
|
|
)) if Some(&tool_call_id) == echo_id.as_ref() => {
|
|
echo_completed = true;
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
if expected_tools.is_empty() && echo_completed {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Cancel the current send and ensure that the event stream is closed, even
|
|
// if one of the tools is still running.
|
|
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
|
let events = events.collect::<Vec<_>>().await;
|
|
let last_event = events.last();
|
|
assert!(
|
|
matches!(
|
|
last_event,
|
|
Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
|
|
),
|
|
"unexpected event {last_event:?}"
|
|
);
|
|
|
|
// Ensure we can still send a new message after cancellation.
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["Testing: reply with 'Hello' then stop."],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap()
|
|
.collect::<Vec<_>>()
|
|
.await;
|
|
thread.update(cx, |thread, _cx| {
|
|
let message = thread.last_received_or_pending_message().unwrap();
|
|
let agent_message = message.as_agent_message().unwrap();
|
|
assert_eq!(
|
|
agent_message.content,
|
|
vec![AgentMessageContent::Text("Hello".to_string())]
|
|
);
|
|
});
|
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
always_allow_tools(cx);
|
|
let fake_model = model.as_fake();
|
|
|
|
let environment = Rc::new(cx.update(|cx| {
|
|
FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
|
|
}));
|
|
let handle = environment.terminal_handle.clone().unwrap();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(crate::TerminalTool::new(
|
|
thread.project().clone(),
|
|
environment,
|
|
));
|
|
thread.send(UserMessageId::new(), ["run a command"], cx)
|
|
})
|
|
.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model calling the terminal tool
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "terminal_tool_1".into(),
|
|
name: TerminalTool::NAME.into(),
|
|
raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
|
|
input: json!({"command": "sleep 1000", "cd": "."}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Wait for the terminal tool to start running
|
|
wait_for_terminal_tool_started(&mut events, cx).await;
|
|
|
|
// Cancel the thread while the terminal is running
|
|
thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
|
|
|
|
// Collect remaining events, driving the executor to let cancellation complete
|
|
let remaining_events = collect_events_until_stop(&mut events, cx).await;
|
|
|
|
// Verify the terminal was killed
|
|
assert!(
|
|
handle.was_killed(),
|
|
"expected terminal handle to be killed on cancellation"
|
|
);
|
|
|
|
// Verify we got a cancellation stop event
|
|
assert_eq!(
|
|
stop_events(remaining_events),
|
|
vec![acp::StopReason::Cancelled],
|
|
);
|
|
|
|
// Verify the tool result contains the terminal output, not just "Tool canceled by user"
|
|
thread.update(cx, |thread, _cx| {
|
|
let message = thread.last_received_or_pending_message().unwrap();
|
|
let agent_message = message.as_agent_message().unwrap();
|
|
|
|
let tool_use = agent_message
|
|
.content
|
|
.iter()
|
|
.find_map(|content| match content {
|
|
AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
|
|
_ => None,
|
|
})
|
|
.expect("expected tool use in agent message");
|
|
|
|
let tool_result = agent_message
|
|
.tool_results
|
|
.get(&tool_use.id)
|
|
.expect("expected tool result");
|
|
|
|
let result_text = tool_result.text_contents();
|
|
|
|
// "partial output" comes from FakeTerminalHandle's output field
|
|
assert!(
|
|
result_text.contains("partial output"),
|
|
"expected tool result to contain terminal output, got: {result_text}"
|
|
);
|
|
// Match the actual format from process_content in terminal_tool.rs
|
|
assert!(
|
|
result_text.contains("The user stopped this command"),
|
|
"expected tool result to indicate user stopped, got: {result_text}"
|
|
);
|
|
});
|
|
|
|
// Verify we can send a new message after cancellation
|
|
verify_thread_recovery(&thread, &fake_model, cx).await;
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
|
|
// This test verifies that tools which properly handle cancellation via
|
|
// `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
|
|
// to cancellation and report that they were cancelled.
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
always_allow_tools(cx);
|
|
let fake_model = model.as_fake();
|
|
|
|
let (tool, was_cancelled) = CancellationAwareTool::new();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(tool);
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["call the cancellation aware tool"],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model calling the cancellation-aware tool
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "cancellation_aware_1".into(),
|
|
name: "cancellation_aware".into(),
|
|
raw_input: r#"{}"#.into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Wait for the tool call to be reported
|
|
let mut tool_started = false;
|
|
let deadline = cx.executor().num_cpus() * 100;
|
|
for _ in 0..deadline {
|
|
cx.run_until_parked();
|
|
|
|
while let Some(Some(event)) = events.next().now_or_never() {
|
|
if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
|
|
if tool_call.title == "Cancellation Aware Tool" {
|
|
tool_started = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if tool_started {
|
|
break;
|
|
}
|
|
|
|
cx.background_executor
|
|
.timer(Duration::from_millis(10))
|
|
.await;
|
|
}
|
|
assert!(tool_started, "expected cancellation aware tool to start");
|
|
|
|
// Cancel the thread and wait for it to complete
|
|
let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
|
|
|
|
// The cancel task should complete promptly because the tool handles cancellation
|
|
let timeout = cx.background_executor.timer(Duration::from_secs(5));
|
|
futures::select! {
|
|
_ = cancel_task.fuse() => {}
|
|
_ = timeout.fuse() => {
|
|
panic!("cancel task timed out - tool did not respond to cancellation");
|
|
}
|
|
}
|
|
|
|
// Verify the tool detected cancellation via its flag
|
|
assert!(
|
|
was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
|
|
"tool should have detected cancellation via event_stream.cancelled_by_user()"
|
|
);
|
|
|
|
// Collect remaining events
|
|
let remaining_events = collect_events_until_stop(&mut events, cx).await;
|
|
|
|
// Verify we got a cancellation stop event
|
|
assert_eq!(
|
|
stop_events(remaining_events),
|
|
vec![acp::StopReason::Cancelled],
|
|
);
|
|
|
|
// Verify we can send a new message after cancellation
|
|
verify_thread_recovery(&thread, &fake_model, cx).await;
|
|
}
|
|
|
|
/// Helper to verify thread can recover after cancellation by sending a simple message.
|
|
async fn verify_thread_recovery(
|
|
thread: &Entity<Thread>,
|
|
fake_model: &FakeLanguageModel,
|
|
cx: &mut TestAppContext,
|
|
) {
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["Testing: reply with 'Hello' then stop."],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Hello");
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let events = events.collect::<Vec<_>>().await;
|
|
thread.update(cx, |thread, _cx| {
|
|
let message = thread.last_received_or_pending_message().unwrap();
|
|
let agent_message = message.as_agent_message().unwrap();
|
|
assert_eq!(
|
|
agent_message.content,
|
|
vec![AgentMessageContent::Text("Hello".to_string())]
|
|
);
|
|
});
|
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
|
}
|
|
|
|
/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
|
|
async fn wait_for_terminal_tool_started(
|
|
events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
|
cx: &mut TestAppContext,
|
|
) {
|
|
let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
|
|
for _ in 0..deadline {
|
|
cx.run_until_parked();
|
|
|
|
while let Some(Some(event)) = events.next().now_or_never() {
|
|
if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
|
update,
|
|
))) = &event
|
|
{
|
|
if update.fields.content.as_ref().is_some_and(|content| {
|
|
content
|
|
.iter()
|
|
.any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
|
|
}) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
cx.background_executor
|
|
.timer(Duration::from_millis(10))
|
|
.await;
|
|
}
|
|
panic!("terminal tool did not start within the expected time");
|
|
}
|
|
|
|
/// Collects events until a Stop event is received, driving the executor to completion.
|
|
async fn collect_events_until_stop(
|
|
events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
|
cx: &mut TestAppContext,
|
|
) -> Vec<Result<ThreadEvent>> {
|
|
let mut collected = Vec::new();
|
|
let deadline = cx.executor().num_cpus() * 200;
|
|
|
|
for _ in 0..deadline {
|
|
cx.executor().advance_clock(Duration::from_millis(10));
|
|
cx.run_until_parked();
|
|
|
|
while let Some(Some(event)) = events.next().now_or_never() {
|
|
let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
|
|
collected.push(event);
|
|
if is_stop {
|
|
return collected;
|
|
}
|
|
}
|
|
}
|
|
panic!(
|
|
"did not receive Stop event within the expected time; collected {} events",
|
|
collected.len()
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
always_allow_tools(cx);
|
|
let fake_model = model.as_fake();
|
|
|
|
let environment = Rc::new(cx.update(|cx| {
|
|
FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
|
|
}));
|
|
let handle = environment.terminal_handle.clone().unwrap();
|
|
|
|
let message_id = UserMessageId::new();
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(crate::TerminalTool::new(
|
|
thread.project().clone(),
|
|
environment,
|
|
));
|
|
thread.send(message_id.clone(), ["run a command"], cx)
|
|
})
|
|
.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model calling the terminal tool
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "terminal_tool_1".into(),
|
|
name: TerminalTool::NAME.into(),
|
|
raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
|
|
input: json!({"command": "sleep 1000", "cd": "."}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Wait for the terminal tool to start running
|
|
wait_for_terminal_tool_started(&mut events, cx).await;
|
|
|
|
// Truncate the thread while the terminal is running
|
|
thread
|
|
.update(cx, |thread, cx| thread.truncate(message_id, cx))
|
|
.unwrap();
|
|
|
|
// Drive the executor to let cancellation complete
|
|
let _ = collect_events_until_stop(&mut events, cx).await;
|
|
|
|
// Verify the terminal was killed
|
|
assert!(
|
|
handle.was_killed(),
|
|
"expected terminal handle to be killed on truncate"
|
|
);
|
|
|
|
// Verify the thread is empty after truncation
|
|
thread.update(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
"",
|
|
"expected thread to be empty after truncating the only message"
|
|
);
|
|
});
|
|
|
|
// Verify we can send a new message after truncation
|
|
verify_thread_recovery(&thread, &fake_model, cx).await;
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
|
|
// Tests that cancellation properly kills all running terminal tools when multiple are active.
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
always_allow_tools(cx);
|
|
let fake_model = model.as_fake();
|
|
|
|
let environment = Rc::new(MultiTerminalEnvironment::new());
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(crate::TerminalTool::new(
|
|
thread.project().clone(),
|
|
environment.clone(),
|
|
));
|
|
thread.send(UserMessageId::new(), ["run multiple commands"], cx)
|
|
})
|
|
.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model calling two terminal tools
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "terminal_tool_1".into(),
|
|
name: TerminalTool::NAME.into(),
|
|
raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
|
|
input: json!({"command": "sleep 1000", "cd": "."}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "terminal_tool_2".into(),
|
|
name: TerminalTool::NAME.into(),
|
|
raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
|
|
input: json!({"command": "sleep 2000", "cd": "."}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Wait for both terminal tools to start by counting terminal content updates
|
|
let mut terminals_started = 0;
|
|
let deadline = cx.executor().num_cpus() * 100;
|
|
for _ in 0..deadline {
|
|
cx.run_until_parked();
|
|
|
|
while let Some(Some(event)) = events.next().now_or_never() {
|
|
if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
|
update,
|
|
))) = &event
|
|
{
|
|
if update.fields.content.as_ref().is_some_and(|content| {
|
|
content
|
|
.iter()
|
|
.any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
|
|
}) {
|
|
terminals_started += 1;
|
|
if terminals_started >= 2 {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if terminals_started >= 2 {
|
|
break;
|
|
}
|
|
|
|
cx.background_executor
|
|
.timer(Duration::from_millis(10))
|
|
.await;
|
|
}
|
|
assert!(
|
|
terminals_started >= 2,
|
|
"expected 2 terminal tools to start, got {terminals_started}"
|
|
);
|
|
|
|
// Cancel the thread while both terminals are running
|
|
thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
|
|
|
|
// Collect remaining events
|
|
let remaining_events = collect_events_until_stop(&mut events, cx).await;
|
|
|
|
// Verify both terminal handles were killed
|
|
let handles = environment.handles();
|
|
assert_eq!(
|
|
handles.len(),
|
|
2,
|
|
"expected 2 terminal handles to be created"
|
|
);
|
|
assert!(
|
|
handles[0].was_killed(),
|
|
"expected first terminal handle to be killed on cancellation"
|
|
);
|
|
assert!(
|
|
handles[1].was_killed(),
|
|
"expected second terminal handle to be killed on cancellation"
|
|
);
|
|
|
|
// Verify we got a cancellation stop event
|
|
assert_eq!(
|
|
stop_events(remaining_events),
|
|
vec![acp::StopReason::Cancelled],
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
|
|
// Tests that clicking the stop button on the terminal card (as opposed to the main
|
|
// cancel button) properly reports user stopped via the was_stopped_by_user path.
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
always_allow_tools(cx);
|
|
let fake_model = model.as_fake();
|
|
|
|
let environment = Rc::new(cx.update(|cx| {
|
|
FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
|
|
}));
|
|
let handle = environment.terminal_handle.clone().unwrap();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(crate::TerminalTool::new(
|
|
thread.project().clone(),
|
|
environment,
|
|
));
|
|
thread.send(UserMessageId::new(), ["run a command"], cx)
|
|
})
|
|
.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model calling the terminal tool
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "terminal_tool_1".into(),
|
|
name: TerminalTool::NAME.into(),
|
|
raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
|
|
input: json!({"command": "sleep 1000", "cd": "."}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Wait for the terminal tool to start running
|
|
wait_for_terminal_tool_started(&mut events, cx).await;
|
|
|
|
// Simulate user clicking stop on the terminal card itself.
|
|
// This sets the flag and signals exit (simulating what the real UI would do).
|
|
handle.set_stopped_by_user(true);
|
|
handle.killed.store(true, Ordering::SeqCst);
|
|
handle.signal_exit();
|
|
|
|
// Wait for the tool to complete
|
|
cx.run_until_parked();
|
|
|
|
// The thread continues after tool completion - simulate the model ending its turn
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Collect remaining events
|
|
let remaining_events = collect_events_until_stop(&mut events, cx).await;
|
|
|
|
// Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
|
|
assert_eq!(
|
|
stop_events(remaining_events),
|
|
vec![acp::StopReason::EndTurn],
|
|
);
|
|
|
|
// Verify the tool result indicates user stopped
|
|
thread.update(cx, |thread, _cx| {
|
|
let message = thread.last_received_or_pending_message().unwrap();
|
|
let agent_message = message.as_agent_message().unwrap();
|
|
|
|
let tool_use = agent_message
|
|
.content
|
|
.iter()
|
|
.find_map(|content| match content {
|
|
AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
|
|
_ => None,
|
|
})
|
|
.expect("expected tool use in agent message");
|
|
|
|
let tool_result = agent_message
|
|
.tool_results
|
|
.get(&tool_use.id)
|
|
.expect("expected tool result");
|
|
|
|
let result_text = tool_result.text_contents();
|
|
|
|
assert!(
|
|
result_text.contains("The user stopped this command"),
|
|
"expected tool result to indicate user stopped, got: {result_text}"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
|
|
// Tests that when a timeout is configured and expires, the tool result indicates timeout.
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
always_allow_tools(cx);
|
|
let fake_model = model.as_fake();
|
|
|
|
let environment = Rc::new(cx.update(|cx| {
|
|
FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
|
|
}));
|
|
let handle = environment.terminal_handle.clone().unwrap();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(crate::TerminalTool::new(
|
|
thread.project().clone(),
|
|
environment,
|
|
));
|
|
thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
|
|
})
|
|
.unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model calling the terminal tool with a short timeout
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "terminal_tool_1".into(),
|
|
name: TerminalTool::NAME.into(),
|
|
raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
|
|
input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Wait for the terminal tool to start running
|
|
wait_for_terminal_tool_started(&mut events, cx).await;
|
|
|
|
// Advance clock past the timeout
|
|
cx.executor().advance_clock(Duration::from_millis(200));
|
|
cx.run_until_parked();
|
|
|
|
// The thread continues after tool completion - simulate the model ending its turn
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Collect remaining events
|
|
let remaining_events = collect_events_until_stop(&mut events, cx).await;
|
|
|
|
// Verify the terminal was killed due to timeout
|
|
assert!(
|
|
handle.was_killed(),
|
|
"expected terminal handle to be killed on timeout"
|
|
);
|
|
|
|
// Verify we got an EndTurn (the tool completed, just with timeout)
|
|
assert_eq!(
|
|
stop_events(remaining_events),
|
|
vec![acp::StopReason::EndTurn],
|
|
);
|
|
|
|
// Verify the tool result indicates timeout, not user stopped
|
|
thread.update(cx, |thread, _cx| {
|
|
let message = thread.last_received_or_pending_message().unwrap();
|
|
let agent_message = message.as_agent_message().unwrap();
|
|
|
|
let tool_use = agent_message
|
|
.content
|
|
.iter()
|
|
.find_map(|content| match content {
|
|
AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
|
|
_ => None,
|
|
})
|
|
.expect("expected tool use in agent message");
|
|
|
|
let tool_result = agent_message
|
|
.tool_results
|
|
.get(&tool_use.id)
|
|
.expect("expected tool result");
|
|
|
|
let result_text = tool_result.text_contents();
|
|
|
|
assert!(
|
|
result_text.contains("timed out"),
|
|
"expected tool result to indicate timeout, got: {result_text}"
|
|
);
|
|
assert!(
|
|
!result_text.contains("The user stopped"),
|
|
"tool result should not mention user stopped when it timed out, got: {result_text}"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let events_1 = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
|
cx.run_until_parked();
|
|
|
|
let events_2 = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let events_1 = events_1.collect::<Vec<_>>().await;
|
|
assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
|
|
let events_2 = events_2.collect::<Vec<_>>().await;
|
|
assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) {
|
|
// Regression test: when a completion fails with a retryable error (e.g. upstream 500),
|
|
// the retry loop waits on a timer. If the user switches models and sends a new message
|
|
// during that delay, the old turn should exit immediately instead of retrying with the
|
|
// stale model.
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let model_a = model.as_fake();
|
|
|
|
// Start a turn with model_a.
|
|
let events_1 = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
assert_eq!(model_a.completion_count(), 1);
|
|
|
|
// Model returns a retryable upstream 500. The turn enters the retry delay.
|
|
model_a.send_last_completion_stream_error(
|
|
LanguageModelCompletionError::UpstreamProviderError {
|
|
message: "Internal server error".to_string(),
|
|
status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
|
|
retry_after: None,
|
|
},
|
|
);
|
|
model_a.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
// The old completion was consumed; model_a has no pending requests yet because the
|
|
// retry timer hasn't fired.
|
|
assert_eq!(model_a.completion_count(), 0);
|
|
|
|
// Switch to model_b and send a new message. This cancels the old turn.
|
|
let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"fake", "model-b", "Model B", false,
|
|
));
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model_b.clone(), cx);
|
|
});
|
|
let events_2 = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Continue"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// model_b should have received its completion request.
|
|
assert_eq!(model_b.as_fake().completion_count(), 1);
|
|
|
|
// Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s).
|
|
cx.executor().advance_clock(Duration::from_secs(10));
|
|
cx.run_until_parked();
|
|
|
|
// model_a must NOT have received another completion request — the cancelled turn
|
|
// should have exited during the retry delay rather than retrying with the old model.
|
|
assert_eq!(
|
|
model_a.completion_count(),
|
|
0,
|
|
"old model should not receive a retry request after cancellation"
|
|
);
|
|
|
|
// Complete model_b's turn.
|
|
model_b
|
|
.as_fake()
|
|
.send_last_completion_stream_text_chunk("Done!");
|
|
model_b
|
|
.as_fake()
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
model_b.as_fake().end_last_completion_stream();
|
|
|
|
let events_1 = events_1.collect::<Vec<_>>().await;
|
|
assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
|
|
|
|
let events_2 = events_2.collect::<Vec<_>>().await;
|
|
assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let events_1 = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
let events_1 = events_1.collect::<Vec<_>>().await;
|
|
|
|
let events_2 = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
|
fake_model.end_last_completion_stream();
|
|
let events_2 = events_2.collect::<Vec<_>>().await;
|
|
|
|
assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
|
|
assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_refusal(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hello
|
|
"}
|
|
);
|
|
});
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hello
|
|
|
|
## Assistant
|
|
|
|
Hey!
|
|
"}
|
|
);
|
|
});
|
|
|
|
// If the model refuses to continue, the thread should remove all the messages after the last user message.
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
|
|
let events = events.collect::<Vec<_>>().await;
|
|
assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.to_markdown(), "");
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let message_id = UserMessageId::new();
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(message_id.clone(), ["Hello"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hello
|
|
"}
|
|
);
|
|
assert_eq!(thread.latest_token_usage(), None);
|
|
});
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
language_model::TokenUsage {
|
|
input_tokens: 32_000,
|
|
output_tokens: 16_000,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
},
|
|
));
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hello
|
|
|
|
## Assistant
|
|
|
|
Hey!
|
|
"}
|
|
);
|
|
assert_eq!(
|
|
thread.latest_token_usage(),
|
|
Some(acp_thread::TokenUsage {
|
|
used_tokens: 32_000 + 16_000,
|
|
max_tokens: 1_000_000,
|
|
max_output_tokens: None,
|
|
input_tokens: 32_000,
|
|
output_tokens: 16_000,
|
|
})
|
|
);
|
|
});
|
|
|
|
thread
|
|
.update(cx, |thread, cx| thread.truncate(message_id, cx))
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.to_markdown(), "");
|
|
assert_eq!(thread.latest_token_usage(), None);
|
|
});
|
|
|
|
// Ensure we can still send a new message after truncation.
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hi"], cx)
|
|
})
|
|
.unwrap();
|
|
thread.update(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hi
|
|
"}
|
|
);
|
|
});
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
language_model::TokenUsage {
|
|
input_tokens: 40_000,
|
|
output_tokens: 20_000,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
},
|
|
));
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hi
|
|
|
|
## Assistant
|
|
|
|
Ahoy!
|
|
"}
|
|
);
|
|
|
|
assert_eq!(
|
|
thread.latest_token_usage(),
|
|
Some(acp_thread::TokenUsage {
|
|
used_tokens: 40_000 + 20_000,
|
|
max_tokens: 1_000_000,
|
|
max_output_tokens: None,
|
|
input_tokens: 40_000,
|
|
output_tokens: 20_000,
|
|
})
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
language_model::TokenUsage {
|
|
input_tokens: 32_000,
|
|
output_tokens: 16_000,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
let assert_first_message_state = |cx: &mut TestAppContext| {
|
|
thread.clone().read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Message 1
|
|
|
|
## Assistant
|
|
|
|
Message 1 response
|
|
"}
|
|
);
|
|
|
|
assert_eq!(
|
|
thread.latest_token_usage(),
|
|
Some(acp_thread::TokenUsage {
|
|
used_tokens: 32_000 + 16_000,
|
|
max_tokens: 1_000_000,
|
|
max_output_tokens: None,
|
|
input_tokens: 32_000,
|
|
output_tokens: 16_000,
|
|
})
|
|
);
|
|
});
|
|
};
|
|
|
|
assert_first_message_state(cx);
|
|
|
|
let second_message_id = UserMessageId::new();
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(second_message_id.clone(), ["Message 2"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
language_model::TokenUsage {
|
|
input_tokens: 40_000,
|
|
output_tokens: 20_000,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Message 1
|
|
|
|
## Assistant
|
|
|
|
Message 1 response
|
|
|
|
## User
|
|
|
|
Message 2
|
|
|
|
## Assistant
|
|
|
|
Message 2 response
|
|
"}
|
|
);
|
|
|
|
assert_eq!(
|
|
thread.latest_token_usage(),
|
|
Some(acp_thread::TokenUsage {
|
|
used_tokens: 40_000 + 20_000,
|
|
max_tokens: 1_000_000,
|
|
max_output_tokens: None,
|
|
input_tokens: 40_000,
|
|
output_tokens: 20_000,
|
|
})
|
|
);
|
|
});
|
|
|
|
thread
|
|
.update(cx, |thread, cx| thread.truncate(second_message_id, cx))
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
assert_first_message_state(cx);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_title_generation(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let summary_model = Arc::new(FakeLanguageModel::default());
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_summarization_model(Some(summary_model.clone()), cx)
|
|
});
|
|
|
|
let send = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), None));
|
|
|
|
// Ensure the summary model has been invoked to generate a title.
|
|
summary_model.send_last_completion_stream_text_chunk("Hello ");
|
|
summary_model.send_last_completion_stream_text_chunk("world\nG");
|
|
summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
|
|
summary_model.end_last_completion_stream();
|
|
send.collect::<Vec<_>>().await;
|
|
cx.run_until_parked();
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.title(), Some("Hello world".into()))
|
|
});
|
|
|
|
// Send another message, ensuring no title is generated this time.
|
|
let send = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello again"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
fake_model.send_last_completion_stream_text_chunk("Hey again!");
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
assert_eq!(summary_model.pending_completions(), Vec::new());
|
|
send.collect::<Vec<_>>().await;
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.title(), Some("Hello world".into()))
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_title_generation_failure_allows_retry(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let summary_model = Arc::new(FakeLanguageModel::default());
|
|
let fake_summary_model = summary_model.as_fake();
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_summarization_model(Some(summary_model.clone()), cx)
|
|
});
|
|
|
|
let send = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
fake_summary_model.send_last_completion_stream_error(
|
|
LanguageModelCompletionError::UpstreamProviderError {
|
|
message: "Internal server error".to_string(),
|
|
status: gpui::http_client::StatusCode::INTERNAL_SERVER_ERROR,
|
|
retry_after: None,
|
|
},
|
|
);
|
|
fake_summary_model.end_last_completion_stream();
|
|
send.collect::<Vec<_>>().await;
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.title(), None);
|
|
assert!(thread.has_failed_title_generation());
|
|
assert!(!thread.is_generating_title());
|
|
});
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.generate_title(cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _| {
|
|
assert!(!thread.has_failed_title_generation());
|
|
assert!(thread.is_generating_title());
|
|
});
|
|
|
|
fake_summary_model.send_last_completion_stream_text_chunk("Retried title");
|
|
fake_summary_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.title(), Some("Retried title".into()));
|
|
assert!(!thread.has_failed_title_generation());
|
|
assert!(!thread.is_generating_title());
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let _events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.add_tool(EchoTool);
|
|
thread.send(UserMessageId::new(), ["Hey!"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let permission_tool_use = LanguageModelToolUse {
|
|
id: "tool_id_1".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
let echo_tool_use = LanguageModelToolUse {
|
|
id: "tool_id_2".into(),
|
|
name: EchoTool::NAME.into(),
|
|
raw_input: json!({"text": "test"}).to_string(),
|
|
input: json!({"text": "test"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
fake_model.send_last_completion_stream_text_chunk("Hi!");
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
permission_tool_use,
|
|
));
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
echo_tool_use.clone(),
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
// Ensure pending tools are skipped when building a request.
|
|
let request = thread
|
|
.read_with(cx, |thread, cx| {
|
|
thread.build_completion_request(CompletionIntent::EditFile, cx)
|
|
})
|
|
.unwrap();
|
|
assert_eq!(
|
|
request.messages[1..],
|
|
vec![
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Hey!".into()],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec![
|
|
MessageContent::Text("Hi!".into()),
|
|
MessageContent::ToolUse(echo_tool_use.clone())
|
|
],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec![MessageContent::ToolResult(LanguageModelToolResult {
|
|
tool_use_id: echo_tool_use.id.clone(),
|
|
tool_name: echo_tool_use.name,
|
|
is_error: false,
|
|
content: vec!["test".into()],
|
|
output: Some("test".into())
|
|
})],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
],
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_agent_connection(cx: &mut TestAppContext) {
|
|
cx.update(settings::init);
|
|
let templates = Templates::new();
|
|
|
|
// Initialize language model system with test provider
|
|
cx.update(|cx| {
|
|
gpui_tokio::init(cx);
|
|
|
|
let http_client = FakeHttpClient::with_404_response();
|
|
let clock = Arc::new(clock::FakeSystemClock::new());
|
|
let client = Client::new(clock, http_client, cx);
|
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
|
language_model::init(cx);
|
|
RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
|
|
language_models::init(user_store, client.clone(), cx);
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.executor().forbid_parking();
|
|
|
|
// Create a project for new_thread
|
|
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
|
fake_fs.insert_tree(path!("/test"), json!({})).await;
|
|
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
|
|
let cwd = PathList::new(&[Path::new("/test")]);
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
|
|
// Create agent and connection
|
|
let agent = cx
|
|
.update(|cx| NativeAgent::new(thread_store, templates.clone(), None, fake_fs.clone(), cx));
|
|
let connection = NativeAgentConnection(agent.clone());
|
|
|
|
// Create a thread using new_thread
|
|
let connection_rc = Rc::new(connection.clone());
|
|
let acp_thread = cx
|
|
.update(|cx| connection_rc.new_session(project, cwd, cx))
|
|
.await
|
|
.expect("new_thread should succeed");
|
|
|
|
// Get the session_id from the AcpThread
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
|
|
// Test model_selector returns Some
|
|
let selector_opt = connection.model_selector(&session_id);
|
|
assert!(
|
|
selector_opt.is_some(),
|
|
"agent should always support ModelSelector"
|
|
);
|
|
let selector = selector_opt.unwrap();
|
|
|
|
// Test list_models
|
|
let listed_models = cx
|
|
.update(|cx| selector.list_models(cx))
|
|
.await
|
|
.expect("list_models should succeed");
|
|
let AgentModelList::Grouped(listed_models) = listed_models else {
|
|
panic!("Unexpected model list type");
|
|
};
|
|
assert!(!listed_models.is_empty(), "should have at least one model");
|
|
assert_eq!(
|
|
listed_models[&AgentModelGroupName("Fake".into())][0]
|
|
.id
|
|
.0
|
|
.as_ref(),
|
|
"fake/fake"
|
|
);
|
|
|
|
// Test selected_model returns the default
|
|
let model = cx
|
|
.update(|cx| selector.selected_model(cx))
|
|
.await
|
|
.expect("selected_model should succeed");
|
|
let model = cx
|
|
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
|
|
.unwrap();
|
|
let model = model.as_fake();
|
|
assert_eq!(model.id().0, "fake", "should return default model");
|
|
|
|
let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("def");
|
|
cx.run_until_parked();
|
|
acp_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
indoc! {"
|
|
## User
|
|
|
|
abc
|
|
|
|
## Assistant
|
|
|
|
def
|
|
|
|
"}
|
|
)
|
|
});
|
|
|
|
// Test cancel
|
|
cx.update(|cx| connection.cancel(&session_id, cx));
|
|
request.await.expect("prompt should fail gracefully");
|
|
|
|
// Explicitly close the session and drop the ACP thread.
|
|
cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
|
|
.await
|
|
.unwrap();
|
|
drop(acp_thread);
|
|
let result = cx
|
|
.update(|cx| {
|
|
connection.prompt(
|
|
acp_thread::UserMessageId::new(),
|
|
acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
|
|
cx,
|
|
)
|
|
})
|
|
.await;
|
|
assert_eq!(
|
|
result.as_ref().unwrap_err().to_string(),
|
|
"Session not found",
|
|
"unexpected result: {:?}",
|
|
result
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
|
thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Echo something"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Simulate streaming partial input.
|
|
let input = json!({});
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "1".into(),
|
|
name: EchoTool::NAME.into(),
|
|
raw_input: input.to_string(),
|
|
input,
|
|
is_input_complete: false,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
|
|
// Input streaming completed
|
|
let input = json!({ "text": "Hello!" });
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "1".into(),
|
|
name: "echo".into(),
|
|
raw_input: input.to_string(),
|
|
input,
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
let tool_call = expect_tool_call(&mut events).await;
|
|
assert_eq!(
|
|
tool_call,
|
|
acp::ToolCall::new("1", "Echo")
|
|
.raw_input(json!({}))
|
|
.meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
|
|
);
|
|
let update = expect_tool_call_update_fields(&mut events).await;
|
|
assert_eq!(
|
|
update,
|
|
acp::ToolCallUpdate::new(
|
|
"1",
|
|
acp::ToolCallUpdateFields::new()
|
|
.title("Echo")
|
|
.kind(acp::ToolKind::Other)
|
|
.raw_input(json!({ "text": "Hello!"}))
|
|
)
|
|
);
|
|
let update = expect_tool_call_update_fields(&mut events).await;
|
|
assert_eq!(
|
|
update,
|
|
acp::ToolCallUpdate::new(
|
|
"1",
|
|
acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
|
|
)
|
|
);
|
|
let update = expect_tool_call_update_fields(&mut events).await;
|
|
assert_eq!(
|
|
update,
|
|
acp::ToolCallUpdate::new(
|
|
"1",
|
|
acp::ToolCallUpdateFields::new()
|
|
.status(acp::ToolCallStatus::Completed)
|
|
.raw_output("Hello!")
|
|
)
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_update_plan_tool_updates_thread_events(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
|
thread.update(cx, |thread, _cx| thread.add_tool(UpdatePlanTool));
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Make a plan"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let input = json!({
|
|
"plan": [
|
|
{
|
|
"step": "Inspect the code",
|
|
"status": "completed",
|
|
},
|
|
{
|
|
"step": "Implement the tool",
|
|
"status": "in_progress"
|
|
},
|
|
{
|
|
"step": "Run tests",
|
|
"status": "pending",
|
|
}
|
|
]
|
|
});
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "plan_1".into(),
|
|
name: UpdatePlanTool::NAME.into(),
|
|
raw_input: input.to_string(),
|
|
input,
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
let tool_call = expect_tool_call(&mut events).await;
|
|
assert_eq!(
|
|
tool_call,
|
|
acp::ToolCall::new("plan_1", "Update plan")
|
|
.kind(acp::ToolKind::Think)
|
|
.raw_input(json!({
|
|
"plan": [
|
|
{
|
|
"step": "Inspect the code",
|
|
"status": "completed",
|
|
},
|
|
{
|
|
"step": "Implement the tool",
|
|
"status": "in_progress"
|
|
},
|
|
{
|
|
"step": "Run tests",
|
|
"status": "pending",
|
|
}
|
|
]
|
|
}))
|
|
.meta(acp::Meta::from_iter([(
|
|
"tool_name".into(),
|
|
"update_plan".into()
|
|
)]))
|
|
);
|
|
|
|
let update = expect_tool_call_update_fields(&mut events).await;
|
|
assert_eq!(
|
|
update,
|
|
acp::ToolCallUpdate::new(
|
|
"plan_1",
|
|
acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
|
|
)
|
|
);
|
|
|
|
let plan = expect_plan(&mut events).await;
|
|
assert_eq!(
|
|
plan,
|
|
acp::Plan::new(vec![
|
|
acp::PlanEntry::new(
|
|
"Inspect the code",
|
|
acp::PlanEntryPriority::Medium,
|
|
acp::PlanEntryStatus::Completed,
|
|
),
|
|
acp::PlanEntry::new(
|
|
"Implement the tool",
|
|
acp::PlanEntryPriority::Medium,
|
|
acp::PlanEntryStatus::InProgress,
|
|
),
|
|
acp::PlanEntry::new(
|
|
"Run tests",
|
|
acp::PlanEntryPriority::Medium,
|
|
acp::PlanEntryStatus::Pending,
|
|
),
|
|
])
|
|
);
|
|
|
|
let update = expect_tool_call_update_fields(&mut events).await;
|
|
assert_eq!(
|
|
update,
|
|
acp::ToolCallUpdate::new(
|
|
"plan_1",
|
|
acp::ToolCallUpdateFields::new()
|
|
.status(acp::ToolCallStatus::Completed)
|
|
.raw_output("Plan updated")
|
|
)
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let mut retry_events = Vec::new();
|
|
while let Some(Ok(event)) = events.next().await {
|
|
match event {
|
|
ThreadEvent::Retry(retry_status) => {
|
|
retry_events.push(retry_status);
|
|
}
|
|
ThreadEvent::Stop(..) => break,
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
assert_eq!(retry_events.len(), 0);
|
|
thread.read_with(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hello!
|
|
|
|
## Assistant
|
|
|
|
Hey!
|
|
"}
|
|
)
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Hey,");
|
|
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
|
provider: LanguageModelProviderName::new("Anthropic"),
|
|
retry_after: Some(Duration::from_secs(3)),
|
|
});
|
|
fake_model.end_last_completion_stream();
|
|
|
|
cx.executor().advance_clock(Duration::from_secs(3));
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("there!");
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
let mut retry_events = Vec::new();
|
|
while let Some(Ok(event)) = events.next().await {
|
|
match event {
|
|
ThreadEvent::Retry(retry_status) => {
|
|
retry_events.push(retry_status);
|
|
}
|
|
ThreadEvent::Stop(..) => break,
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
assert_eq!(retry_events.len(), 1);
|
|
assert!(matches!(
|
|
retry_events[0],
|
|
acp_thread::RetryStatus { attempt: 1, .. }
|
|
));
|
|
thread.read_with(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.to_markdown(),
|
|
indoc! {"
|
|
## User
|
|
|
|
Hello!
|
|
|
|
## Assistant
|
|
|
|
Hey,
|
|
|
|
[resume]
|
|
|
|
## Assistant
|
|
|
|
there!
|
|
"}
|
|
)
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(EchoTool);
|
|
thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let tool_use_1 = LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: EchoTool::NAME.into(),
|
|
raw_input: json!({"text": "test"}).to_string(),
|
|
input: json!({"text": "test"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
tool_use_1.clone(),
|
|
));
|
|
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
|
provider: LanguageModelProviderName::new("Anthropic"),
|
|
retry_after: Some(Duration::from_secs(3)),
|
|
});
|
|
fake_model.end_last_completion_stream();
|
|
|
|
cx.executor().advance_clock(Duration::from_secs(3));
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
assert_eq!(
|
|
completion.messages[1..],
|
|
vec![
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Call the echo tool!".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec![language_model::MessageContent::ToolResult(
|
|
LanguageModelToolResult {
|
|
tool_use_id: tool_use_1.id.clone(),
|
|
tool_name: tool_use_1.name.clone(),
|
|
is_error: false,
|
|
content: vec!["test".into()],
|
|
output: Some("test".into())
|
|
}
|
|
)],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
},
|
|
]
|
|
);
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Done");
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
events.collect::<Vec<_>>().await;
|
|
thread.read_with(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.last_received_or_pending_message(),
|
|
Some(Message::Agent(AgentMessage {
|
|
content: vec![AgentMessageContent::Text("Done".into())],
|
|
tool_results: IndexMap::default(),
|
|
reasoning_details: None,
|
|
}))
|
|
);
|
|
})
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
|
|
fake_model.send_last_completion_stream_error(
|
|
LanguageModelCompletionError::ServerOverloaded {
|
|
provider: LanguageModelProviderName::new("Anthropic"),
|
|
retry_after: Some(Duration::from_secs(3)),
|
|
},
|
|
);
|
|
fake_model.end_last_completion_stream();
|
|
cx.executor().advance_clock(Duration::from_secs(3));
|
|
cx.run_until_parked();
|
|
}
|
|
|
|
let mut errors = Vec::new();
|
|
let mut retry_events = Vec::new();
|
|
while let Some(event) = events.next().await {
|
|
match event {
|
|
Ok(ThreadEvent::Retry(retry_status)) => {
|
|
retry_events.push(retry_status);
|
|
}
|
|
Ok(ThreadEvent::Stop(..)) => break,
|
|
Err(error) => errors.push(error),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
assert_eq!(
|
|
retry_events.len(),
|
|
crate::thread::MAX_RETRY_ATTEMPTS as usize
|
|
);
|
|
for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
|
|
assert_eq!(retry_events[i].attempt, i + 1);
|
|
}
|
|
assert_eq!(errors.len(), 1);
|
|
let error = errors[0]
|
|
.downcast_ref::<LanguageModelCompletionError>()
|
|
.unwrap();
|
|
assert!(matches!(
|
|
error,
|
|
LanguageModelCompletionError::ServerOverloaded { .. }
|
|
));
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
|
|
cx: &mut TestAppContext,
|
|
) {
|
|
init_test(cx);
|
|
always_allow_tools(cx);
|
|
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.add_tool(StreamingEchoTool::new());
|
|
});
|
|
|
|
let _events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Use the streaming_echo tool"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Send a partial tool use (is_input_complete = false), simulating the LLM
|
|
// streaming input for a tool.
|
|
let tool_use = LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: "streaming_echo".into(),
|
|
raw_input: r#"{"text": "partial"}"#.into(),
|
|
input: json!({"text": "partial"}),
|
|
is_input_complete: false,
|
|
thought_signature: None,
|
|
};
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
|
cx.run_until_parked();
|
|
|
|
// Send a stream error WITHOUT ever sending is_input_complete = true.
|
|
// Before the fix, this would deadlock: the tool waits for more partials
|
|
// (or cancellation), run_turn_internal waits for the tool, and the sender
|
|
// keeping the channel open lives inside RunningTurn.
|
|
fake_model.send_last_completion_stream_error(
|
|
LanguageModelCompletionError::UpstreamProviderError {
|
|
message: "Internal server error".to_string(),
|
|
status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
|
|
retry_after: None,
|
|
},
|
|
);
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Advance past the retry delay so run_turn_internal retries.
|
|
cx.executor().advance_clock(Duration::from_secs(5));
|
|
cx.run_until_parked();
|
|
|
|
// The retry request should contain the streaming tool's error result,
|
|
// proving the tool terminated and its result was forwarded.
|
|
let completion = fake_model
|
|
.pending_completions()
|
|
.pop()
|
|
.expect("No running turn");
|
|
assert_eq!(
|
|
completion.messages[1..],
|
|
vec![
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Use the streaming_echo tool".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec![language_model::MessageContent::ToolResult(
|
|
LanguageModelToolResult {
|
|
tool_use_id: tool_use.id.clone(),
|
|
tool_name: tool_use.name,
|
|
is_error: true,
|
|
content: vec![
|
|
"Failed to receive tool input: tool input was not fully received"
|
|
.into(),
|
|
],
|
|
output: Some(
|
|
"Failed to receive tool input: tool input was not fully received"
|
|
.into()
|
|
),
|
|
}
|
|
)],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
},
|
|
]
|
|
);
|
|
|
|
// Finish the retry round so the turn completes cleanly.
|
|
fake_model.send_last_completion_stream_text_chunk("Done");
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _cx| {
|
|
assert!(
|
|
thread.is_turn_complete(),
|
|
"Thread should not be stuck; the turn should have completed",
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool(
|
|
cx: &mut TestAppContext,
|
|
) {
|
|
init_test(cx);
|
|
always_allow_tools(cx);
|
|
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.add_tool(StreamingJsonErrorContextTool);
|
|
});
|
|
|
|
let _events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["Use the streaming_json_error_context tool"],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let tool_use = LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: StreamingJsonErrorContextTool::NAME.into(),
|
|
raw_input: r#"{"text": "partial"#.into(),
|
|
input: json!({"text": "partial"}),
|
|
is_input_complete: false,
|
|
thought_signature: None,
|
|
};
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_event(
|
|
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
|
id: "tool_1".into(),
|
|
tool_name: StreamingJsonErrorContextTool::NAME.into(),
|
|
raw_input: r#"{"text": "partial"#.into(),
|
|
json_parse_error: "EOF while parsing a string at line 1 column 17".into(),
|
|
},
|
|
);
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
cx.executor().advance_clock(Duration::from_secs(5));
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model
|
|
.pending_completions()
|
|
.pop()
|
|
.expect("No running turn");
|
|
|
|
let tool_results: Vec<_> = completion
|
|
.messages
|
|
.iter()
|
|
.flat_map(|message| &message.content)
|
|
.filter_map(|content| match content {
|
|
MessageContent::ToolResult(result)
|
|
if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") =>
|
|
{
|
|
Some(result)
|
|
}
|
|
_ => None,
|
|
})
|
|
.collect();
|
|
|
|
assert_eq!(
|
|
tool_results.len(),
|
|
1,
|
|
"Expected exactly 1 tool result for tool_1, got {}: {:#?}",
|
|
tool_results.len(),
|
|
tool_results
|
|
);
|
|
|
|
let result = tool_results[0];
|
|
assert!(result.is_error);
|
|
let content_text = result.text_contents();
|
|
assert!(
|
|
content_text.contains("Saw partial text 'partial' before invalid JSON"),
|
|
"Expected tool-enriched partial context, got: {content_text}"
|
|
);
|
|
assert!(
|
|
content_text
|
|
.contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"),
|
|
"Expected forwarded JSON parse error, got: {content_text}"
|
|
);
|
|
assert!(
|
|
!content_text.contains("tool input was not fully received"),
|
|
"Should not contain orphaned sender error, got: {content_text}"
|
|
);
|
|
|
|
fake_model.send_last_completion_stream_text_chunk("Done");
|
|
fake_model.end_last_completion_stream();
|
|
cx.run_until_parked();
|
|
|
|
thread.read_with(cx, |thread, _cx| {
|
|
assert!(
|
|
thread.is_turn_complete(),
|
|
"Thread should not be stuck; the turn should have completed",
|
|
);
|
|
});
|
|
}
|
|
|
|
/// Filters out the stop events for asserting against in tests
|
|
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
|
|
result_events
|
|
.into_iter()
|
|
.filter_map(|event| match event.unwrap() {
|
|
ThreadEvent::Stop(stop_reason) => Some(stop_reason),
|
|
_ => None,
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
struct ThreadTest {
|
|
model: Arc<dyn LanguageModel>,
|
|
thread: Entity<Thread>,
|
|
project_context: Entity<ProjectContext>,
|
|
context_server_store: Entity<ContextServerStore>,
|
|
fs: Arc<FakeFs>,
|
|
}
|
|
|
|
/// Which language model `setup` wires into the thread.
enum TestModel {
    /// A real Claude Sonnet 4 model reached through the production client.
    Sonnet4,
    /// An in-process `FakeLanguageModel` driven manually by the test.
    Fake,
}
|
|
|
|
impl TestModel {
|
|
fn id(&self) -> LanguageModelId {
|
|
match self {
|
|
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
|
|
TestModel::Fake => unreachable!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Creates a `Thread` for tests on a fake filesystem.
///
/// Steps:
/// - Writes a user settings file whose default agent profile (`test-profile`)
///   enables the test tools listed below.
/// - Initializes settings. For `TestModel::Sonnet4` it also initializes the real
///   HTTP client, user store and language model providers.
/// - Starts watching the settings file and creates a project rooted at `/test`.
///
/// For `TestModel::Sonnet4` the matching model is looked up in the
/// `LanguageModelRegistry` and its provider is authenticated before the thread is
/// created. A missing model or a failed authentication panics.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably needed because the Sonnet4 path performs real
    // (non-simulated) I/O; confirm whether Fake-model tests depend on it too.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Every tool a test may call must be enabled in the active profile.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            ToolRequiringPermission2::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            StreamingEchoTool::NAME: true,
                            StreamingJsonErrorContextTool::NAME: true,
                            StreamingFailingEchoTool::NAME: true,
                            TerminalTool::NAME: true,
                            UpdatePlanTool::NAME: true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Wire up the production client stack so the real model can be reached.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(cx);
                RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Apply the settings file written above (and any later edits to it) to the
        // global `SettingsStore`.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model. The Fake model is ready immediately; the real model is
    // returned only after its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
|
|
|
|
/// Installs `env_logger` when the test binary loads, but only if `RUST_LOG` is
/// set to a valid Unicode value, so test runs stay quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}
|
|
|
|
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
|
|
let fs = fs.clone();
|
|
cx.spawn({
|
|
async move |cx| {
|
|
let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
|
|
cx.background_executor(),
|
|
fs,
|
|
paths::settings_file().clone(),
|
|
);
|
|
let _watcher_task = watcher_task;
|
|
|
|
while let Some(new_settings_content) = new_settings_content_rx.next().await {
|
|
cx.update(|cx| {
|
|
SettingsStore::update_global(cx, |settings, cx| {
|
|
settings.set_user_settings(&new_settings_content, cx)
|
|
})
|
|
})
|
|
.ok();
|
|
}
|
|
}
|
|
})
|
|
.detach();
|
|
}
|
|
|
|
fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
|
|
completion
|
|
.tools
|
|
.iter()
|
|
.map(|tool| tool.name.clone())
|
|
.collect()
|
|
}
|
|
|
|
/// Registers and starts a fake stdio MCP context server named `name`.
///
/// The server is added to the global `ProjectSettings` with command
/// `somebinary`, then started in `context_server_store` over a fake transport
/// that:
/// - answers `Initialize` with `LATEST_PROTOCOL_VERSION` and a tools capability
///   with `list_changed: Some(true)`,
/// - answers `ListTools` with `tools`,
/// - forwards every `CallTool` request to the returned receiver, together with a
///   oneshot sender the fake server then waits on for its response.
///
/// The test must reply through each sender. Dropping a sender makes the fake
/// `CallTool` handler panic (`response_rx.await.unwrap()`). Dropping the returned
/// receiver makes the handler panic on the next call (`unbounded_send(..).unwrap()`).
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings under `name`.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    title: None,
                    version: "1.0.0".to_string(),
                    description: None,
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Hand each tool call to the test and block until it replies.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the Initialize / ListTools handshake finish before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
|
|
|
|
/// Checks `Thread::tokens_before_message` across several turns.
///
/// The first user message never has a value (`None`). Each later message reports
/// the `input_tokens` of the usage update received for the request before it.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
|
|
|
|
/// Checks that truncating the thread at a message removes its token bookkeeping.
///
/// After truncation the removed message reports `None`, and the first message
/// still reports `None`.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    // NOTE(review): only two messages are actually sent below.
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
|
|
|
|
/// Exercises `TerminalTool` permission handling driven by
/// `AgentSettings::tool_permissions`.
///
/// Covers per-tool `always_deny`, `always_allow` and `always_confirm` patterns,
/// and how a per-tool `default` interacts with the global
/// `tool_permissions.default`. Each scenario overrides the global `AgentSettings`
/// in place, so settings changed by one scenario remain visible to the next one.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        // A terminal that never exits: if the deny rule failed, the run would
        // presumably never complete instead of returning an error.
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Confirm),
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        // `environment` is an `Rc`, so the tool is presumably neither `Send` nor
        // `Sync`; the lint is silenced because the tool is only used on this task.
        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // The `always_deny` match wins even though the tool default is Confirm.
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        let err_msg = result.unwrap_err().to_lowercase();
        assert!(
            err_msg.contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default: Deny)
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        // See Test 1: the tool captures an `Rc` environment.
        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // The first event must already carry terminal content, i.e. no
        // authorization request came first.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: global default: allow does NOT override always_confirm patterns
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Allow),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        // See Test 1: the tool captures an `Rc` environment.
        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        // Keep the run alive until authorization has been requested; it is
        // dropped explicitly below while still waiting for the user.
        let _task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // With global default: allow, confirm patterns are still respected
        // The expect_authorization() call will panic if no authorization is requested,
        // which validates that the confirm pattern still triggers confirmation
        let _auth = rx.expect_authorization().await;

        drop(_task);
    }

    // Test 4: tool-specific default: deny is respected even with global default: allow
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        // See Test 1: the tool captures an `Rc` environment.
        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // tool-specific default: deny is respected even with global default: allow
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by tool-specific deny default"
        );
        let err_msg = result.unwrap_err().to_lowercase();
        assert!(
            err_msg.contains("disabled"),
            "error should mention the tool is disabled, got: {err_msg}"
        );
    }
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
// Ensure empty threads are not saved, even if they get mutated.
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("spawning subagent");
|
|
let subagent_tool_input = SpawnAgentToolInput {
|
|
label: "label".to_string(),
|
|
message: "subagent task prompt".to_string(),
|
|
session_id: None,
|
|
};
|
|
let subagent_tool_use = LanguageModelToolUse {
|
|
id: "subagent_1".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
|
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
subagent_tool_use,
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
let subagent_session_id = thread.read_with(cx, |thread, cx| {
|
|
thread
|
|
.running_subagent_ids(cx)
|
|
.get(0)
|
|
.expect("subagent thread should be running")
|
|
.clone()
|
|
});
|
|
|
|
let subagent_thread = agent.read_with(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get(&subagent_session_id)
|
|
.expect("subagent session should exist")
|
|
.acp_thread
|
|
.clone()
|
|
});
|
|
|
|
model.send_last_completion_stream_text_chunk("subagent task response");
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
assert_eq!(
|
|
subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
|
|
indoc! {"
|
|
## User
|
|
|
|
subagent task prompt
|
|
|
|
## Assistant
|
|
|
|
subagent task response
|
|
|
|
"}
|
|
);
|
|
|
|
model.send_last_completion_stream_text_chunk("Response");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
|
|
assert_eq!(
|
|
acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
|
|
indoc! {r#"
|
|
## User
|
|
|
|
Prompt
|
|
|
|
## Assistant
|
|
|
|
spawning subagent
|
|
|
|
**Tool Call: label**
|
|
Status: Completed
|
|
|
|
subagent task response
|
|
|
|
## Assistant
|
|
|
|
Response
|
|
|
|
"#},
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
// Ensure empty threads are not saved, even if they get mutated.
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("spawning subagent");
|
|
let subagent_tool_input = SpawnAgentToolInput {
|
|
label: "label".to_string(),
|
|
message: "subagent task prompt".to_string(),
|
|
session_id: None,
|
|
};
|
|
let subagent_tool_use = LanguageModelToolUse {
|
|
id: "subagent_1".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
|
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
subagent_tool_use,
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
let subagent_session_id = thread.read_with(cx, |thread, cx| {
|
|
thread
|
|
.running_subagent_ids(cx)
|
|
.get(0)
|
|
.expect("subagent thread should be running")
|
|
.clone()
|
|
});
|
|
|
|
let subagent_thread = agent.read_with(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get(&subagent_session_id)
|
|
.expect("subagent session should exist")
|
|
.acp_thread
|
|
.clone()
|
|
});
|
|
|
|
model.send_last_completion_stream_text_chunk("subagent task response 1");
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
|
|
text: "thinking more about the subagent task".into(),
|
|
signature: None,
|
|
});
|
|
model.send_last_completion_stream_text_chunk("subagent task response 2");
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
assert_eq!(
|
|
subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
|
|
indoc! {"
|
|
## User
|
|
|
|
subagent task prompt
|
|
|
|
## Assistant
|
|
|
|
subagent task response 1
|
|
|
|
<thinking>
|
|
thinking more about the subagent task
|
|
</thinking>
|
|
|
|
subagent task response 2
|
|
|
|
"}
|
|
);
|
|
|
|
model.send_last_completion_stream_text_chunk("Response");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
|
|
assert_eq!(
|
|
acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
|
|
indoc! {r#"
|
|
## User
|
|
|
|
Prompt
|
|
|
|
## Assistant
|
|
|
|
spawning subagent
|
|
|
|
**Tool Call: label**
|
|
Status: Completed
|
|
|
|
subagent task response 1
|
|
|
|
subagent task response 2
|
|
|
|
## Assistant
|
|
|
|
Response
|
|
|
|
"#},
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
// Ensure empty threads are not saved, even if they get mutated.
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("spawning subagent");
|
|
let subagent_tool_input = SpawnAgentToolInput {
|
|
label: "label".to_string(),
|
|
message: "subagent task prompt".to_string(),
|
|
session_id: None,
|
|
};
|
|
let subagent_tool_use = LanguageModelToolUse {
|
|
id: "subagent_1".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
|
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
subagent_tool_use,
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
let subagent_session_id = thread.read_with(cx, |thread, cx| {
|
|
thread
|
|
.running_subagent_ids(cx)
|
|
.get(0)
|
|
.expect("subagent thread should be running")
|
|
.clone()
|
|
});
|
|
let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get(&subagent_session_id)
|
|
.expect("subagent session should exist")
|
|
.acp_thread
|
|
.clone()
|
|
});
|
|
|
|
// model.send_last_completion_stream_text_chunk("subagent task response");
|
|
// model.end_last_completion_stream();
|
|
|
|
// cx.run_until_parked();
|
|
|
|
acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
|
|
|
cx.run_until_parked();
|
|
|
|
send.await.unwrap();
|
|
|
|
acp_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(thread.status(), ThreadStatus::Idle);
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
indoc! {"
|
|
## User
|
|
|
|
Prompt
|
|
|
|
## Assistant
|
|
|
|
spawning subagent
|
|
|
|
**Tool Call: label**
|
|
Status: Canceled
|
|
|
|
"}
|
|
);
|
|
});
|
|
subagent_acp_thread.read_with(cx, |thread, cx| {
|
|
assert_eq!(thread.status(), ThreadStatus::Idle);
|
|
assert_eq!(
|
|
thread.to_markdown(cx),
|
|
indoc! {"
|
|
## User
|
|
|
|
subagent task prompt
|
|
|
|
"}
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// === First turn: create subagent ===
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("spawning subagent");
|
|
let subagent_tool_input = SpawnAgentToolInput {
|
|
label: "initial task".to_string(),
|
|
message: "do the first task".to_string(),
|
|
session_id: None,
|
|
};
|
|
let subagent_tool_use = LanguageModelToolUse {
|
|
id: "subagent_1".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
|
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
subagent_tool_use,
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
let subagent_session_id = thread.read_with(cx, |thread, cx| {
|
|
thread
|
|
.running_subagent_ids(cx)
|
|
.get(0)
|
|
.expect("subagent thread should be running")
|
|
.clone()
|
|
});
|
|
|
|
let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
|
|
agent
|
|
.sessions
|
|
.get(&subagent_session_id)
|
|
.expect("subagent session should exist")
|
|
.acp_thread
|
|
.clone()
|
|
});
|
|
|
|
// Subagent responds
|
|
model.send_last_completion_stream_text_chunk("first task response");
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Parent model responds to complete first turn
|
|
model.send_last_completion_stream_text_chunk("First response");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
|
|
// Verify subagent is no longer running
|
|
thread.read_with(cx, |thread, cx| {
|
|
assert!(
|
|
thread.running_subagent_ids(cx).is_empty(),
|
|
"subagent should not be running after completion"
|
|
);
|
|
});
|
|
|
|
// === Second turn: resume subagent with session_id ===
|
|
let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("resuming subagent");
|
|
let resume_tool_input = SpawnAgentToolInput {
|
|
label: "follow-up task".to_string(),
|
|
message: "do the follow-up task".to_string(),
|
|
session_id: Some(subagent_session_id.clone()),
|
|
};
|
|
let resume_tool_use = LanguageModelToolUse {
|
|
id: "subagent_2".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
|
|
input: serde_json::to_value(&resume_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Subagent should be running again with the same session
|
|
thread.read_with(cx, |thread, cx| {
|
|
let running = thread.running_subagent_ids(cx);
|
|
assert_eq!(running.len(), 1, "subagent should be running");
|
|
assert_eq!(running[0], subagent_session_id, "should be same session");
|
|
});
|
|
|
|
// Subagent responds to follow-up
|
|
model.send_last_completion_stream_text_chunk("follow-up task response");
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Parent model responds to complete second turn
|
|
model.send_last_completion_stream_text_chunk("Second response");
|
|
model.end_last_completion_stream();
|
|
|
|
send2.await.unwrap();
|
|
|
|
// Verify subagent is no longer running
|
|
thread.read_with(cx, |thread, cx| {
|
|
assert!(
|
|
thread.running_subagent_ids(cx).is_empty(),
|
|
"subagent should not be running after resume completion"
|
|
);
|
|
});
|
|
|
|
// Verify the subagent's acp thread has both conversation turns
|
|
assert_eq!(
|
|
subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
|
|
indoc! {"
|
|
## User
|
|
|
|
do the first task
|
|
|
|
## Assistant
|
|
|
|
first task response
|
|
|
|
## User
|
|
|
|
do the follow-up task
|
|
|
|
## Assistant
|
|
|
|
follow-up task response
|
|
|
|
"}
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(path!("/test"), json!({})).await;
|
|
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
|
let project_context = cx.new(|_cx| ProjectContext::default());
|
|
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
|
let context_server_registry =
|
|
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
let parent_thread = cx.new(|cx| {
|
|
Thread::new(
|
|
project.clone(),
|
|
project_context,
|
|
context_server_registry,
|
|
Templates::new(),
|
|
Some(model.clone()),
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
|
|
subagent_thread.read_with(cx, |subagent_thread, cx| {
|
|
assert!(subagent_thread.is_subagent());
|
|
assert_eq!(subagent_thread.depth(), 1);
|
|
assert_eq!(
|
|
subagent_thread.model().map(|model| model.id()),
|
|
Some(model.id())
|
|
);
|
|
assert_eq!(
|
|
subagent_thread.parent_thread_id(),
|
|
Some(parent_thread.read(cx).id().clone())
|
|
);
|
|
|
|
let request = subagent_thread
|
|
.build_completion_request(CompletionIntent::UserPrompt, cx)
|
|
.unwrap();
|
|
assert_eq!(request.intent, Some(CompletionIntent::Subagent));
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(path!("/test"), json!({})).await;
|
|
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
|
let project_context = cx.new(|_cx| ProjectContext::default());
|
|
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
|
let context_server_registry =
|
|
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
let environment = Rc::new(cx.update(|cx| {
|
|
FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
|
|
}));
|
|
|
|
let deep_parent_thread = cx.new(|cx| {
|
|
let mut thread = Thread::new(
|
|
project.clone(),
|
|
project_context,
|
|
context_server_registry,
|
|
Templates::new(),
|
|
Some(model.clone()),
|
|
cx,
|
|
);
|
|
thread.set_subagent_context(SubagentContext {
|
|
parent_thread_id: acp::SessionId::new("parent-id"),
|
|
depth: MAX_SUBAGENT_DEPTH - 1,
|
|
});
|
|
thread
|
|
});
|
|
let deep_subagent_thread = cx.new(|cx| {
|
|
let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
|
|
thread.add_default_tools(environment, cx);
|
|
thread
|
|
});
|
|
|
|
deep_subagent_thread.read_with(cx, |thread, _| {
|
|
assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
|
|
assert!(
|
|
!thread.has_registered_tool(SpawnAgentTool::NAME),
|
|
"subagent tool should not be present at max depth"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(path!("/test"), json!({})).await;
|
|
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
|
let project_context = cx.new(|_cx| ProjectContext::default());
|
|
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
|
let context_server_registry =
|
|
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
let parent = cx.new(|cx| {
|
|
Thread::new(
|
|
project.clone(),
|
|
project_context.clone(),
|
|
context_server_registry.clone(),
|
|
Templates::new(),
|
|
Some(model.clone()),
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
|
|
|
|
parent.update(cx, |thread, _cx| {
|
|
thread.register_running_subagent(subagent.downgrade());
|
|
});
|
|
|
|
subagent
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
subagent.read_with(cx, |thread, _| {
|
|
assert!(!thread.is_turn_complete(), "subagent should be running");
|
|
});
|
|
|
|
parent.update(cx, |thread, cx| {
|
|
thread.cancel(cx).detach();
|
|
});
|
|
|
|
subagent.read_with(cx, |thread, _| {
|
|
assert!(
|
|
thread.is_turn_complete(),
|
|
"subagent should be cancelled when parent cancels"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// Start the parent turn
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("spawning subagent");
|
|
let subagent_tool_input = SpawnAgentToolInput {
|
|
label: "label".to_string(),
|
|
message: "subagent task prompt".to_string(),
|
|
session_id: None,
|
|
};
|
|
let subagent_tool_use = LanguageModelToolUse {
|
|
id: "subagent_1".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
|
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
subagent_tool_use,
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Verify subagent is running
|
|
let subagent_session_id = thread.read_with(cx, |thread, cx| {
|
|
thread
|
|
.running_subagent_ids(cx)
|
|
.get(0)
|
|
.expect("subagent thread should be running")
|
|
.clone()
|
|
});
|
|
|
|
// Send a usage update that crosses the warning threshold (80% of 1,000,000)
|
|
model.send_last_completion_stream_text_chunk("partial work");
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
TokenUsage {
|
|
input_tokens: 850_000,
|
|
output_tokens: 0,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
},
|
|
));
|
|
|
|
cx.run_until_parked();
|
|
|
|
// The subagent should no longer be running
|
|
thread.read_with(cx, |thread, cx| {
|
|
assert!(
|
|
thread.running_subagent_ids(cx).is_empty(),
|
|
"subagent should be stopped after context window warning"
|
|
);
|
|
});
|
|
|
|
// The parent model should get a new completion request to respond to the tool error
|
|
model.send_last_completion_stream_text_chunk("Response after warning");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
|
|
// Verify the parent thread shows the warning error in the tool call
|
|
let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
|
|
assert!(
|
|
markdown.contains("nearing the end of its context window"),
|
|
"tool output should contain context window warning message, got:\n{markdown}"
|
|
);
|
|
assert!(
|
|
markdown.contains("Status: Failed"),
|
|
"tool call should have Failed status, got:\n{markdown}"
|
|
);
|
|
|
|
// Verify the subagent session still exists (can be resumed)
|
|
agent.read_with(cx, |agent, _cx| {
|
|
assert!(
|
|
agent.sessions.contains_key(&subagent_session_id),
|
|
"subagent session should still exist for potential resume"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// === First turn: create subagent, trigger context window warning ===
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("spawning subagent");
|
|
let subagent_tool_input = SpawnAgentToolInput {
|
|
label: "initial task".to_string(),
|
|
message: "do the first task".to_string(),
|
|
session_id: None,
|
|
};
|
|
let subagent_tool_use = LanguageModelToolUse {
|
|
id: "subagent_1".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
|
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
subagent_tool_use,
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
let subagent_session_id = thread.read_with(cx, |thread, cx| {
|
|
thread
|
|
.running_subagent_ids(cx)
|
|
.get(0)
|
|
.expect("subagent thread should be running")
|
|
.clone()
|
|
});
|
|
|
|
// Subagent sends a usage update that crosses the warning threshold.
|
|
// This triggers Normal→Warning, stopping the subagent.
|
|
model.send_last_completion_stream_text_chunk("partial work");
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
TokenUsage {
|
|
input_tokens: 850_000,
|
|
output_tokens: 0,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
},
|
|
));
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Verify the first turn was stopped with a context window warning
|
|
thread.read_with(cx, |thread, cx| {
|
|
assert!(
|
|
thread.running_subagent_ids(cx).is_empty(),
|
|
"subagent should be stopped after context window warning"
|
|
);
|
|
});
|
|
|
|
// Parent model responds to complete first turn
|
|
model.send_last_completion_stream_text_chunk("First response");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
|
|
let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
|
|
assert!(
|
|
markdown.contains("nearing the end of its context window"),
|
|
"first turn should have context window warning, got:\n{markdown}"
|
|
);
|
|
|
|
// === Second turn: resume the same subagent (now at Warning level) ===
|
|
let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("resuming subagent");
|
|
let resume_tool_input = SpawnAgentToolInput {
|
|
label: "follow-up task".to_string(),
|
|
message: "do the follow-up task".to_string(),
|
|
session_id: Some(subagent_session_id.clone()),
|
|
};
|
|
let resume_tool_use = LanguageModelToolUse {
|
|
id: "subagent_2".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
|
|
input: serde_json::to_value(&resume_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Subagent responds with tokens still at warning level (no worse).
|
|
// Since ratio_before_prompt was already Warning, this should NOT
|
|
// trigger the context window warning again.
|
|
model.send_last_completion_stream_text_chunk("follow-up task response");
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
|
TokenUsage {
|
|
input_tokens: 870_000,
|
|
output_tokens: 0,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
},
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Parent model responds to complete second turn
|
|
model.send_last_completion_stream_text_chunk("Second response");
|
|
model.end_last_completion_stream();
|
|
|
|
send2.await.unwrap();
|
|
|
|
// The resumed subagent should have completed normally since the ratio
|
|
// didn't transition (it was Warning before and stayed at Warning)
|
|
let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
|
|
assert!(
|
|
markdown.contains("follow-up task response"),
|
|
"resumed subagent should complete normally when already at warning, got:\n{markdown}"
|
|
);
|
|
// The second tool call should NOT have a context window warning
|
|
let second_tool_pos = markdown
|
|
.find("follow-up task")
|
|
.expect("should find follow-up tool call");
|
|
let after_second_tool = &markdown[second_tool_pos..];
|
|
assert!(
|
|
!after_second_tool.contains("nearing the end of its context window"),
|
|
"should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
cx.update(|cx| {
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
cx.update(|cx| {
|
|
cx.update_flags(true, vec!["subagents".to_string()]);
|
|
});
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/",
|
|
json!({
|
|
"a": {
|
|
"b.md": "Lorem"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
|
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
|
let agent = cx.update(|cx| {
|
|
NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
|
|
});
|
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
|
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
connection
|
|
.clone()
|
|
.new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
|
|
})
|
|
.await
|
|
.unwrap();
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
let thread = agent.read_with(cx, |agent, _| {
|
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
|
});
|
|
let model = Arc::new(FakeLanguageModel::default());
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_model(model.clone(), cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// Start the parent turn
|
|
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
|
|
cx.run_until_parked();
|
|
model.send_last_completion_stream_text_chunk("spawning subagent");
|
|
let subagent_tool_input = SpawnAgentToolInput {
|
|
label: "label".to_string(),
|
|
message: "subagent task prompt".to_string(),
|
|
session_id: None,
|
|
};
|
|
let subagent_tool_use = LanguageModelToolUse {
|
|
id: "subagent_1".into(),
|
|
name: SpawnAgentTool::NAME.into(),
|
|
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
|
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
subagent_tool_use,
|
|
));
|
|
model.end_last_completion_stream();
|
|
|
|
cx.run_until_parked();
|
|
|
|
// Verify subagent is running
|
|
thread.read_with(cx, |thread, cx| {
|
|
assert!(
|
|
!thread.running_subagent_ids(cx).is_empty(),
|
|
"subagent should be running"
|
|
);
|
|
});
|
|
|
|
// The subagent's model returns a non-retryable error
|
|
model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
|
|
tokens: None,
|
|
});
|
|
|
|
cx.run_until_parked();
|
|
|
|
// The subagent should no longer be running
|
|
thread.read_with(cx, |thread, cx| {
|
|
assert!(
|
|
thread.running_subagent_ids(cx).is_empty(),
|
|
"subagent should not be running after error"
|
|
);
|
|
});
|
|
|
|
// The parent model should get a new completion request to respond to the tool error
|
|
model.send_last_completion_stream_text_chunk("Response after error");
|
|
model.end_last_completion_stream();
|
|
|
|
send.await.unwrap();
|
|
|
|
// Verify the parent thread shows the error in the tool call
|
|
let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
|
|
assert!(
|
|
markdown.contains("Status: Failed"),
|
|
"tool call should have Failed status after model error, got:\n{markdown}"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
EditFileTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
let context_server_registry =
|
|
cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
|
let templates = crate::Templates::new();
|
|
let thread = cx.new(|cx| {
|
|
crate::Thread::new(
|
|
project.clone(),
|
|
cx.new(|_cx| prompt_store::ProjectContext::default()),
|
|
context_server_registry,
|
|
templates.clone(),
|
|
None,
|
|
cx,
|
|
)
|
|
});
|
|
let action_log = cx.update(|cx| thread.read(cx).action_log.clone());
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::EditFileTool::new(
|
|
project.clone(),
|
|
thread.downgrade(),
|
|
action_log,
|
|
language_registry,
|
|
));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::EditFileToolInput {
|
|
display_description: "Edit sensitive file".to_string(),
|
|
path: "root/sensitive_config.txt".into(),
|
|
mode: crate::EditFileMode::Edit,
|
|
content: None,
|
|
edits: Some(vec![]),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let result = task.await;
|
|
assert!(result.is_err(), "expected edit to be blocked");
|
|
assert!(
|
|
result.unwrap_err().to_string().contains("blocked"),
|
|
"error should mention the edit was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
DeletePathTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::DeletePathToolInput {
|
|
path: "root/important_data.txt".to_string(),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let result = task.await;
|
|
assert!(result.is_err(), "expected deletion to be blocked");
|
|
assert!(
|
|
result.unwrap_err().contains("blocked"),
|
|
"error should mention the deletion was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/root",
|
|
json!({
|
|
"safe.txt": "content",
|
|
"protected": {}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
MovePathTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::MovePathTool::new(project));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::MovePathToolInput {
|
|
source_path: "root/safe.txt".to_string(),
|
|
destination_path: "root/protected/safe.txt".to_string(),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let result = task.await;
|
|
assert!(
|
|
result.is_err(),
|
|
"expected move to be blocked due to destination path"
|
|
);
|
|
assert!(
|
|
result.unwrap_err().contains("blocked"),
|
|
"error should mention the move was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/root",
|
|
json!({
|
|
"secret.txt": "secret content",
|
|
"public": {}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
MovePathTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::MovePathTool::new(project));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::MovePathToolInput {
|
|
source_path: "root/secret.txt".to_string(),
|
|
destination_path: "root/public/not_secret.txt".to_string(),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let result = task.await;
|
|
assert!(
|
|
result.is_err(),
|
|
"expected move to be blocked due to source path"
|
|
);
|
|
assert!(
|
|
result.unwrap_err().contains("blocked"),
|
|
"error should mention the move was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/root",
|
|
json!({
|
|
"confidential.txt": "confidential data",
|
|
"dest": {}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
CopyPathTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![
|
|
agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
|
|
],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::CopyPathTool::new(project));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::CopyPathToolInput {
|
|
source_path: "root/confidential.txt".to_string(),
|
|
destination_path: "root/dest/copy.txt".to_string(),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let result = task.await;
|
|
assert!(result.is_err(), "expected copy to be blocked");
|
|
assert!(
|
|
result.unwrap_err().contains("blocked"),
|
|
"error should mention the copy was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/root",
|
|
json!({
|
|
"normal.txt": "normal content",
|
|
"readonly": {
|
|
"config.txt": "readonly content"
|
|
}
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
SaveFileTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::SaveFileTool::new(project));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::SaveFileToolInput {
|
|
paths: vec![
|
|
std::path::PathBuf::from("root/normal.txt"),
|
|
std::path::PathBuf::from("root/readonly/config.txt"),
|
|
],
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let result = task.await;
|
|
assert!(
|
|
result.is_err(),
|
|
"expected save to be blocked due to denied path"
|
|
);
|
|
assert!(
|
|
result.unwrap_err().contains("blocked"),
|
|
"error should mention the save was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/root", json!({"config.secret": "secret config"}))
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
SaveFileTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::SaveFileTool::new(project));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::SaveFileToolInput {
|
|
paths: vec![std::path::PathBuf::from("root/config.secret")],
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let result = task.await;
|
|
assert!(result.is_err(), "expected save to be blocked");
|
|
assert!(
|
|
result.unwrap_err().contains("blocked"),
|
|
"error should mention the save was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
WebSearchTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![
|
|
agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
|
|
],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::WebSearchTool);
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let input: crate::WebSearchToolInput =
|
|
serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
|
|
|
|
let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
|
|
|
|
let result = task.await;
|
|
assert!(result.is_err(), "expected search to be blocked");
|
|
match result.unwrap_err() {
|
|
crate::WebSearchToolOutput::Error { error } => {
|
|
assert!(
|
|
error.contains("blocked"),
|
|
"error should mention the search was blocked"
|
|
);
|
|
}
|
|
other => panic!("expected Error variant, got: {other:?}"),
|
|
}
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree("/root", json!({"README.md": "# Hello"}))
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
EditFileTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Confirm),
|
|
always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
|
|
always_deny: vec![],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
let context_server_registry =
|
|
cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
|
let templates = crate::Templates::new();
|
|
let thread = cx.new(|cx| {
|
|
crate::Thread::new(
|
|
project.clone(),
|
|
cx.new(|_cx| prompt_store::ProjectContext::default()),
|
|
context_server_registry,
|
|
templates.clone(),
|
|
None,
|
|
cx,
|
|
)
|
|
});
|
|
let action_log = thread.read_with(cx, |thread, _cx| thread.action_log().clone());
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::EditFileTool::new(
|
|
project,
|
|
thread.downgrade(),
|
|
action_log,
|
|
language_registry,
|
|
));
|
|
let (event_stream, mut rx) = crate::ToolCallEventStream::test();
|
|
|
|
let _task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::EditFileToolInput {
|
|
display_description: "Edit README".to_string(),
|
|
path: "root/README.md".into(),
|
|
mode: crate::EditFileMode::Edit,
|
|
content: None,
|
|
edits: Some(vec![]),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
cx.run_until_parked();
|
|
|
|
let event = rx.try_recv();
|
|
assert!(
|
|
!matches!(event, Ok(Ok(ThreadEvent::ToolCallAuthorization(_)))),
|
|
"expected no authorization request for allowed .md file"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
let fs = FakeFs::new(cx.executor());
|
|
fs.insert_tree(
|
|
"/root",
|
|
json!({
|
|
".zed": {
|
|
"settings.json": "{}"
|
|
},
|
|
"README.md": "# Hello"
|
|
}),
|
|
)
|
|
.await;
|
|
let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
let context_server_registry =
|
|
cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
|
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
|
let templates = crate::Templates::new();
|
|
let thread = cx.new(|cx| {
|
|
crate::Thread::new(
|
|
project.clone(),
|
|
cx.new(|_cx| prompt_store::ProjectContext::default()),
|
|
context_server_registry,
|
|
templates.clone(),
|
|
None,
|
|
cx,
|
|
)
|
|
});
|
|
let action_log = thread.read_with(cx, |thread, _cx| thread.action_log().clone());
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::EditFileTool::new(
|
|
project,
|
|
thread.downgrade(),
|
|
action_log,
|
|
language_registry,
|
|
));
|
|
|
|
// Editing a file inside .zed/ should still prompt even with global default: allow,
|
|
// because local settings paths are sensitive and require confirmation regardless.
|
|
let (event_stream, mut rx) = crate::ToolCallEventStream::test();
|
|
let _task = cx.update(|cx| {
|
|
tool.run(
|
|
ToolInput::resolved(crate::EditFileToolInput {
|
|
display_description: "Edit local settings".to_string(),
|
|
path: "root/.zed/settings.json".into(),
|
|
mode: crate::EditFileMode::Edit,
|
|
content: None,
|
|
edits: Some(vec![]),
|
|
}),
|
|
event_stream,
|
|
cx,
|
|
)
|
|
});
|
|
|
|
let _update = rx.expect_update_fields().await;
|
|
let _auth = rx.expect_authorization().await;
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
FetchTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![
|
|
agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
|
|
],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
let http_client = gpui::http_client::FakeHttpClient::with_200_response();
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::FetchTool::new(http_client));
|
|
let (event_stream, _rx) = crate::ToolCallEventStream::test();
|
|
|
|
let input: crate::FetchToolInput =
|
|
serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
|
|
|
|
let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
|
|
|
|
let result = task.await;
|
|
assert!(result.is_err(), "expected fetch to be blocked");
|
|
assert!(
|
|
result.unwrap_err().contains("blocked"),
|
|
"error should mention the fetch was blocked"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
FetchTool::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Confirm),
|
|
always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
|
|
always_deny: vec![],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
|
|
let http_client = gpui::http_client::FakeHttpClient::with_200_response();
|
|
|
|
#[allow(clippy::arc_with_non_send_sync)]
|
|
let tool = Arc::new(crate::FetchTool::new(http_client));
|
|
let (event_stream, mut rx) = crate::ToolCallEventStream::test();
|
|
|
|
let input: crate::FetchToolInput =
|
|
serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
|
|
|
|
let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
|
|
|
|
cx.run_until_parked();
|
|
|
|
let event = rx.try_recv();
|
|
assert!(
|
|
!matches!(event, Ok(Ok(ThreadEvent::ToolCallAuthorization(_)))),
|
|
"expected no authorization request for allowed docs.rs URL"
|
|
);
|
|
}
|
|
|
|
/// Approving one pending tool call with "Always for <tool>" auto-resolves
|
|
/// sibling pending authorizations for the same tool in the same turn.
|
|
#[gpui::test]
|
|
async fn test_always_allow_resolves_pending_authorizations(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Two parallel tool calls, both require permission.
|
|
for id in ["tool_id_1", "tool_id_2"] {
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: id.into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
}
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
|
|
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
|
|
|
|
// Approve the first with "always allow" — this persists a setting that
|
|
// makes the tool unconditionally allowed. The second pending
|
|
// authorization should resolve without user interaction.
|
|
tool_call_auth_1
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("always_allow:tool_requiring_permission"),
|
|
acp::PermissionOptionKind::AllowAlways,
|
|
))
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// The second tool's receiver was dropped by the auto-resolve path, so
|
|
// sending a late response should fail.
|
|
let late_send = tool_call_auth_2
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("allow"),
|
|
acp::PermissionOptionKind::AllowOnce,
|
|
));
|
|
assert!(
|
|
late_send.is_err(),
|
|
"expected tool 2's response receiver to be dropped after auto-resolve"
|
|
);
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
let results: Vec<_> = message
|
|
.content
|
|
.iter()
|
|
.filter_map(|c| match c {
|
|
language_model::MessageContent::ToolResult(r) => Some(r),
|
|
_ => None,
|
|
})
|
|
.collect();
|
|
assert_eq!(
|
|
results.len(),
|
|
2,
|
|
"both tool calls should have produced results"
|
|
);
|
|
assert!(
|
|
results.iter().all(|r| !r.is_error),
|
|
"both results should be successful after auto-resolve, got: {:?}",
|
|
results
|
|
);
|
|
}
|
|
|
|
/// Externally editing settings (e.g. the user opening settings.json and
|
|
/// adding an `always_allow` rule) resolves pending authorization prompts
|
|
/// for tool calls that match the new rule.
|
|
#[gpui::test]
|
|
async fn test_external_settings_edit_resolves_pending_authorization(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_1".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let tool_call_auth = next_tool_call_authorization(&mut events).await;
|
|
|
|
// Simulate the user editing settings.json to globally allow the tool.
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
ToolRequiringPermission::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Allow),
|
|
always_allow: vec![],
|
|
always_deny: vec![],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// The pending prompt auto-resolves without the user clicking anything.
|
|
let late_send = tool_call_auth
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("allow"),
|
|
acp::PermissionOptionKind::AllowOnce,
|
|
));
|
|
assert!(
|
|
late_send.is_err(),
|
|
"response receiver should have been dropped after settings-driven auto-resolve"
|
|
);
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
let result = message
|
|
.content
|
|
.iter()
|
|
.find_map(|c| match c {
|
|
language_model::MessageContent::ToolResult(r) => Some(r),
|
|
_ => None,
|
|
})
|
|
.expect("expected a tool result");
|
|
assert!(!result.is_error, "tool should have been auto-allowed");
|
|
}
|
|
|
|
/// Externally adding a deny rule to settings dismisses a pending
|
|
/// authorization prompt and returns the tool call as denied.
|
|
#[gpui::test]
|
|
async fn test_external_deny_rule_resolves_pending_authorization(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_1".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let tool_call_auth = next_tool_call_authorization(&mut events).await;
|
|
|
|
// Simulate the user adding a deny default for the tool.
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.tool_permissions.tools.insert(
|
|
ToolRequiringPermission::NAME.into(),
|
|
agent_settings::ToolRules {
|
|
default: Some(settings::ToolPermissionMode::Deny),
|
|
always_allow: vec![],
|
|
always_deny: vec![],
|
|
always_confirm: vec![],
|
|
invalid_patterns: vec![],
|
|
},
|
|
);
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
let late_send = tool_call_auth
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("allow"),
|
|
acp::PermissionOptionKind::AllowOnce,
|
|
));
|
|
assert!(
|
|
late_send.is_err(),
|
|
"response receiver should have been dropped after deny auto-resolve"
|
|
);
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
let result = message
|
|
.content
|
|
.iter()
|
|
.find_map(|c| match c {
|
|
language_model::MessageContent::ToolResult(r) => Some(r),
|
|
_ => None,
|
|
})
|
|
.expect("expected a tool result");
|
|
assert!(
|
|
result.is_error,
|
|
"tool should have been auto-denied by the new rule"
|
|
);
|
|
}
|
|
|
|
/// Unrelated settings changes must not spuriously resolve pending
|
|
/// authorizations: if the re-check still returns `Confirm`, the prompt
|
|
/// stays visible and waits for the user.
|
|
#[gpui::test]
|
|
async fn test_unrelated_settings_change_does_not_resolve_pending_authorization(
|
|
cx: &mut TestAppContext,
|
|
) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_id_1".into(),
|
|
name: ToolRequiringPermission::NAME.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let tool_call_auth = next_tool_call_authorization(&mut events).await;
|
|
|
|
// Touch SettingsStore with a change that doesn't affect tool
|
|
// permissions; the pending authorization should remain pending.
|
|
cx.update(|cx| {
|
|
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
|
|
settings.single_file_review = !settings.single_file_review;
|
|
agent_settings::AgentSettings::override_global(settings, cx);
|
|
});
|
|
cx.run_until_parked();
|
|
|
|
// The user still has to act — resolve with an Allow Once.
|
|
tool_call_auth
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("allow"),
|
|
acp::PermissionOptionKind::AllowOnce,
|
|
))
|
|
.expect("response receiver should still be alive");
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
let result = message
|
|
.content
|
|
.iter()
|
|
.find_map(|c| match c {
|
|
language_model::MessageContent::ToolResult(r) => Some(r),
|
|
_ => None,
|
|
})
|
|
.expect("expected a tool result");
|
|
assert!(!result.is_error);
|
|
}
|
|
|
|
/// Approving one pending tool call with "Always for <tool A>" must not
|
|
/// dismiss a sibling pending authorization for a *different* tool: the
|
|
/// persisted rule is scoped to tool A, so tool B's prompt stays visible
|
|
/// and waits for the user.
|
|
#[gpui::test]
|
|
async fn test_always_allow_does_not_resolve_unrelated_tool_authorization(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(ToolRequiringPermission);
|
|
thread.add_tool(ToolRequiringPermission2);
|
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Two parallel tool calls, each for a distinct tool with its own
|
|
// permission scope.
|
|
for (id, name) in [
|
|
("tool_id_1", ToolRequiringPermission::NAME),
|
|
("tool_id_2", ToolRequiringPermission2::NAME),
|
|
] {
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: id.into(),
|
|
name: name.into(),
|
|
raw_input: "{}".into(),
|
|
input: json!({}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
}
|
|
fake_model.end_last_completion_stream();
|
|
|
|
let auth_a = next_tool_call_authorization(&mut events).await;
|
|
let auth_b = next_tool_call_authorization(&mut events).await;
|
|
|
|
// Match prompts back to their originating tools via the authorization
|
|
// context so the test doesn't depend on scheduling order.
|
|
let (auth_for_tool_1, auth_for_tool_2) = {
|
|
let a_name = auth_a
|
|
.context
|
|
.as_ref()
|
|
.expect("settings-driven authorization must carry a context")
|
|
.tool_name
|
|
.clone();
|
|
if a_name == ToolRequiringPermission::NAME {
|
|
(auth_a, auth_b)
|
|
} else {
|
|
(auth_b, auth_a)
|
|
}
|
|
};
|
|
|
|
// Approve tool 1 with "always allow". Only tool 1's rule is persisted.
|
|
auth_for_tool_1
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("always_allow:tool_requiring_permission"),
|
|
acp::PermissionOptionKind::AllowAlways,
|
|
))
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Tool 2's receiver must still be alive: its permission is unrelated
|
|
// to the rule that was just added, so its prompt stays pending.
|
|
auth_for_tool_2
|
|
.response
|
|
.send(acp_thread::SelectedPermissionOutcome::new(
|
|
acp::PermissionOptionId::new("allow"),
|
|
acp::PermissionOptionKind::AllowOnce,
|
|
))
|
|
.expect("tool 2's response receiver should still be alive");
|
|
cx.run_until_parked();
|
|
|
|
let completion = fake_model.pending_completions().pop().unwrap();
|
|
let message = completion.messages.last().unwrap();
|
|
let results: Vec<_> = message
|
|
.content
|
|
.iter()
|
|
.filter_map(|c| match c {
|
|
language_model::MessageContent::ToolResult(r) => Some(r),
|
|
_ => None,
|
|
})
|
|
.collect();
|
|
assert_eq!(
|
|
results.len(),
|
|
2,
|
|
"both tool calls should have produced results"
|
|
);
|
|
assert!(
|
|
results.iter().all(|r| !r.is_error),
|
|
"both results should be successful, got: {:?}",
|
|
results
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
always_allow_tools(cx);
|
|
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
// Add a tool so we can simulate tool calls
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.add_tool(EchoTool);
|
|
});
|
|
|
|
// Start a turn by sending a message
|
|
let mut events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Simulate the model making a tool call
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: "echo".into(),
|
|
raw_input: r#"{"text": "hello"}"#.into(),
|
|
input: json!({"text": "hello"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
|
|
|
|
// Signal that a message is queued before ending the stream
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.set_has_queued_message(true);
|
|
});
|
|
|
|
// Now end the stream - tool will run, and the boundary check should see the queue
|
|
fake_model.end_last_completion_stream();
|
|
|
|
// Collect all events until the turn stops
|
|
let all_events = collect_events_until_stop(&mut events, cx).await;
|
|
|
|
// Verify we received the tool call event
|
|
let tool_call_ids: Vec<_> = all_events
|
|
.iter()
|
|
.filter_map(|e| match e {
|
|
Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
|
|
_ => None,
|
|
})
|
|
.collect();
|
|
assert_eq!(
|
|
tool_call_ids,
|
|
vec!["tool_1"],
|
|
"Should have received a tool call event for our echo tool"
|
|
);
|
|
|
|
// The turn should have stopped with EndTurn
|
|
let stop_reasons = stop_events(all_events);
|
|
assert_eq!(
|
|
stop_reasons,
|
|
vec![acp::StopReason::EndTurn],
|
|
"Turn should have ended after tool completion due to queued message"
|
|
);
|
|
|
|
// Verify the queued message flag is still set
|
|
thread.update(cx, |thread, _cx| {
|
|
assert!(
|
|
thread.has_queued_message(),
|
|
"Should still have queued message flag set"
|
|
);
|
|
});
|
|
|
|
// Thread should be idle now
|
|
thread.update(cx, |thread, _cx| {
|
|
assert!(
|
|
thread.is_turn_complete(),
|
|
"Thread should not be running after turn ends"
|
|
);
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_streaming_tool_error_breaks_stream_loop_immediately(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
always_allow_tools(cx);
|
|
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.add_tool(StreamingFailingEchoTool {
|
|
receive_chunks_until_failure: 1,
|
|
});
|
|
});
|
|
|
|
let _events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["Use the streaming_failing_echo tool"],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
let tool_use = LanguageModelToolUse {
|
|
id: "call_1".into(),
|
|
name: StreamingFailingEchoTool::NAME.into(),
|
|
raw_input: "hello".into(),
|
|
input: json!({}),
|
|
is_input_complete: false,
|
|
thought_signature: None,
|
|
};
|
|
|
|
fake_model
|
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
|
|
|
cx.run_until_parked();
|
|
|
|
let completions = fake_model.pending_completions();
|
|
let last_completion = completions.last().unwrap();
|
|
|
|
assert_eq!(
|
|
last_completion.messages[1..],
|
|
vec![
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec!["Use the streaming_failing_echo tool".into()],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec![language_model::MessageContent::ToolResult(
|
|
LanguageModelToolResult {
|
|
tool_use_id: tool_use.id.clone(),
|
|
tool_name: tool_use.name,
|
|
is_error: true,
|
|
content: vec!["failed".into()],
|
|
output: Some("failed".into()),
|
|
}
|
|
)],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
},
|
|
]
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
always_allow_tools(cx);
|
|
|
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
|
let fake_model = model.as_fake();
|
|
|
|
let (complete_streaming_echo_tool_call_tx, complete_streaming_echo_tool_call_rx) =
|
|
oneshot::channel();
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.add_tool(
|
|
StreamingEchoTool::new().with_wait_until_complete(complete_streaming_echo_tool_call_rx),
|
|
);
|
|
thread.add_tool(StreamingFailingEchoTool {
|
|
receive_chunks_until_failure: 1,
|
|
});
|
|
});
|
|
|
|
let _events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(
|
|
UserMessageId::new(),
|
|
["Use the streaming_echo tool and the streaming_failing_echo tool"],
|
|
cx,
|
|
)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "call_1".into(),
|
|
name: StreamingEchoTool::NAME.into(),
|
|
raw_input: "hello".into(),
|
|
input: json!({ "text": "hello" }),
|
|
is_input_complete: false,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
let first_tool_use = LanguageModelToolUse {
|
|
id: "call_1".into(),
|
|
name: StreamingEchoTool::NAME.into(),
|
|
raw_input: "hello world".into(),
|
|
input: json!({ "text": "hello world" }),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
};
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
first_tool_use.clone(),
|
|
));
|
|
let second_tool_use = LanguageModelToolUse {
|
|
name: StreamingFailingEchoTool::NAME.into(),
|
|
raw_input: "hello".into(),
|
|
input: json!({ "text": "hello" }),
|
|
is_input_complete: false,
|
|
thought_signature: None,
|
|
id: "call_2".into(),
|
|
};
|
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
second_tool_use.clone(),
|
|
));
|
|
|
|
cx.run_until_parked();
|
|
|
|
complete_streaming_echo_tool_call_tx.send(()).unwrap();
|
|
|
|
cx.run_until_parked();
|
|
|
|
let completions = fake_model.pending_completions();
|
|
let last_completion = completions.last().unwrap();
|
|
|
|
assert_eq!(
|
|
last_completion.messages[1..],
|
|
vec![
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec![
|
|
"Use the streaming_echo tool and the streaming_failing_echo tool".into()
|
|
],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::Assistant,
|
|
content: vec![
|
|
language_model::MessageContent::ToolUse(first_tool_use.clone()),
|
|
language_model::MessageContent::ToolUse(second_tool_use.clone())
|
|
],
|
|
cache: false,
|
|
reasoning_details: None,
|
|
},
|
|
LanguageModelRequestMessage {
|
|
role: Role::User,
|
|
content: vec![
|
|
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
|
tool_use_id: second_tool_use.id.clone(),
|
|
tool_name: second_tool_use.name,
|
|
is_error: true,
|
|
content: vec!["failed".into()],
|
|
output: Some("failed".into()),
|
|
}),
|
|
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
|
tool_use_id: first_tool_use.id.clone(),
|
|
tool_name: first_tool_use.name,
|
|
is_error: false,
|
|
content: vec!["hello world".into()],
|
|
output: Some("hello world".into()),
|
|
}),
|
|
],
|
|
cache: true,
|
|
reasoning_details: None,
|
|
},
|
|
]
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_mid_turn_model_and_settings_refresh(cx: &mut TestAppContext) {
|
|
let ThreadTest {
|
|
model, thread, fs, ..
|
|
} = setup(cx, TestModel::Fake).await;
|
|
let fake_model_a = model.as_fake();
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
thread.add_tool(EchoTool);
|
|
thread.add_tool(DelayTool);
|
|
});
|
|
|
|
// Set up two profiles: profile-a has both tools, profile-b has only DelayTool.
|
|
fs.insert_file(
|
|
paths::settings_file(),
|
|
json!({
|
|
"agent": {
|
|
"profiles": {
|
|
"profile-a": {
|
|
"name": "Profile A",
|
|
"tools": {
|
|
EchoTool::NAME: true,
|
|
DelayTool::NAME: true,
|
|
}
|
|
},
|
|
"profile-b": {
|
|
"name": "Profile B",
|
|
"tools": {
|
|
DelayTool::NAME: true,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
.to_string()
|
|
.into_bytes(),
|
|
)
|
|
.await;
|
|
cx.run_until_parked();
|
|
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("profile-a".into()), cx);
|
|
thread.set_thinking_enabled(false, cx);
|
|
});
|
|
|
|
// Send a message — first iteration starts with model A, profile-a, thinking off.
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(UserMessageId::new(), ["test mid-turn refresh"], cx)
|
|
})
|
|
.unwrap();
|
|
cx.run_until_parked();
|
|
|
|
// Verify first request has both tools and thinking disabled.
|
|
let completions = fake_model_a.pending_completions();
|
|
assert_eq!(completions.len(), 1);
|
|
let first_tools = tool_names_for_completion(&completions[0]);
|
|
assert_eq!(first_tools, vec![DelayTool::NAME, EchoTool::NAME]);
|
|
assert!(!completions[0].thinking_allowed);
|
|
|
|
// Model A responds with an echo tool call.
|
|
fake_model_a.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: "tool_1".into(),
|
|
name: "echo".into(),
|
|
raw_input: r#"{"text":"hello"}"#.into(),
|
|
input: json!({"text": "hello"}),
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
fake_model_a.end_last_completion_stream();
|
|
|
|
// Before the next iteration runs, switch to profile-b (only DelayTool),
|
|
// swap in a new model, and enable thinking.
|
|
let fake_model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
|
|
"test-provider",
|
|
"model-b",
|
|
"Model B",
|
|
true,
|
|
));
|
|
thread.update(cx, |thread, cx| {
|
|
thread.set_profile(AgentProfileId("profile-b".into()), cx);
|
|
thread.set_model(fake_model_b.clone() as Arc<dyn LanguageModel>, cx);
|
|
thread.set_thinking_enabled(true, cx);
|
|
});
|
|
|
|
// Run until parked — processes the echo tool call, loops back, picks up
|
|
// the new model/profile/thinking, and makes a second request to model B.
|
|
cx.run_until_parked();
|
|
|
|
// The second request should have gone to model B.
|
|
let model_b_completions = fake_model_b.pending_completions();
|
|
assert_eq!(
|
|
model_b_completions.len(),
|
|
1,
|
|
"second request should go to model B"
|
|
);
|
|
|
|
// Profile-b only has DelayTool, so echo should be gone.
|
|
let second_tools = tool_names_for_completion(&model_b_completions[0]);
|
|
assert_eq!(second_tools, vec![DelayTool::NAME]);
|
|
|
|
// Thinking should now be enabled.
|
|
assert!(model_b_completions[0].thinking_allowed);
|
|
}
|