Add a streaming edit file tool (#47244)

Release Notes:

- N/A
This commit is contained in:
Michael Benfield 2026-01-21 09:41:03 -08:00 committed by GitHub
parent e1e7676b5a
commit 6b5c06e323
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 958 additions and 7 deletions

View file

@ -2,7 +2,7 @@ mod create_file_parser;
mod edit_parser;
#[cfg(test)]
mod evals;
mod streaming_fuzzy_matcher;
pub mod streaming_fuzzy_matcher;
use crate::{Template, Templates};
use action_log::ActionLog;

View file

@ -2,13 +2,13 @@ use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool,
RestoreFileFromDiskTool, SaveFileTool, SubagentTool, SystemPromptTemplate, Template, Templates,
TerminalTool, ThinkingTool, ToolPermissionDecision, WebSearchTool,
decide_permission_from_settings,
RestoreFileFromDiskTool, SaveFileTool, StreamingEditFileTool, SubagentTool,
SystemPromptTemplate, Template, Templates, TerminalTool, ThinkingTool, ToolPermissionDecision,
WebSearchTool, decide_permission_from_settings,
};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use feature_flags::{FeatureFlagAppExt as _, SubagentsFeatureFlag};
use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt as _, SubagentsFeatureFlag};
use agent_client_protocol as acp;
use agent_settings::{
@ -1189,6 +1189,12 @@ impl Thread {
));
self.add_tool(DiagnosticsTool::new(self.project.clone()));
self.add_tool(EditFileTool::new(
self.project.clone(),
cx.weak_entity(),
language_registry.clone(),
Templates::new(),
));
self.add_tool(StreamingEditFileTool::new(
self.project.clone(),
cx.weak_entity(),
language_registry,
@ -2310,14 +2316,31 @@ impl Thread {
}
}
let use_streaming_edit_tool =
cx.has_flag::<AgentV2FeatureFlag>() && model.supports_streaming_tools();
let mut tools = self
.tools
.iter()
.filter_map(|(tool_name, tool)| {
// For streaming_edit_file, check profile against "edit_file" since that's what users configure
let profile_tool_name = if tool_name == "streaming_edit_file" {
"edit_file"
} else {
tool_name.as_ref()
};
if tool.supports_provider(&model.provider_id())
&& profile.is_tool_enabled(tool_name)
&& profile.is_tool_enabled(profile_tool_name)
{
Some((truncate(tool_name), tool.clone()))
match (tool_name.as_ref(), use_streaming_edit_tool) {
("streaming_edit_file", true) => {
// Expose streaming tool as "edit_file"
Some((SharedString::from("edit_file"), tool.clone()))
}
("edit_file", true) => None,
_ => Some((truncate(tool_name), tool.clone())),
}
} else {
None
}

View file

@ -14,6 +14,7 @@ mod open_tool;
mod read_file_tool;
mod restore_file_from_disk_tool;
mod save_file_tool;
mod streaming_edit_file_tool;
mod subagent_tool;
mod terminal_tool;
mod thinking_tool;
@ -40,6 +41,7 @@ pub use open_tool::*;
pub use read_file_tool::*;
pub use restore_file_from_disk_tool::*;
pub use save_file_tool::*;
pub use streaming_edit_file_tool::*;
pub use subagent_tool::*;
pub use terminal_tool::*;
pub use thinking_tool::*;

View file

@ -0,0 +1,926 @@
use crate::{
AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
decide_permission_from_settings, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
};
use acp_thread::Diff;
use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
use anyhow::{Context as _, Result, anyhow};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use language::{Anchor, LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent;
use paths;
use project::{Project, ProjectPath};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::ffi::OsStr;
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use text::BufferSnapshot;
use ui::SharedString;
use util::rel_path::RelPath;
const DEFAULT_UI_TEXT: &str = "Editing file";
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
/// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
///
/// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
///
/// NEVER mention the file path in this description.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
///
/// Make sure to include this field before all the others in the input object so that we can display it immediately.
pub display_description: String,
/// The full path of the file to create or modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - /a/b/backend
/// - /c/d/frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// The mode of operation on the file. Possible values:
/// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
/// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
/// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
///
/// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
pub mode: StreamingEditFileMode,
/// The complete content for the new file (required for 'create' and 'overwrite' modes).
/// This field should contain the entire file content.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// List of edit operations to apply sequentially (required for 'edit' mode).
/// Each edit finds `old_text` in the file and replaces it with `new_text`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub edits: Option<Vec<EditOperation>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
/// Create a new file if it doesn't exist
Create,
/// Replace the entire contents of an existing file
Overwrite,
/// Make granular edits to an existing file
Edit,
}
/// A single edit operation that replaces old text with new text
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
/// The exact text to find in the file. This will be matched using fuzzy matching
/// to handle minor differences in whitespace or formatting.
pub old_text: String,
/// The text to replace it with
pub new_text: String,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct StreamingEditFileToolPartialInput {
#[serde(default)]
path: String,
#[serde(default)]
display_description: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
#[serde(alias = "original_path")]
input_path: PathBuf,
new_text: String,
old_text: Arc<String>,
#[serde(default)]
diff: String,
}
impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
fn from(output: StreamingEditFileToolOutput) -> Self {
if output.diff.is_empty() {
"No edits were made.".into()
} else {
format!(
"Edited {}:\n\n```diff\n{}\n```",
output.input_path.display(),
output.diff
)
.into()
}
}
}
pub struct StreamingEditFileTool {
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
project: Entity<Project>,
#[allow(dead_code)]
templates: Arc<Templates>,
}
impl StreamingEditFileTool {
pub fn new(
project: Entity<Project>,
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
templates: Arc<Templates>,
) -> Self {
Self {
project,
thread,
language_registry,
templates,
}
}
fn authorize(
&self,
input: &StreamingEditFileToolInput,
event_stream: &ToolCallEventStream,
cx: &mut App,
) -> Task<Result<()>> {
let path_str = input.path.to_string_lossy();
let settings = agent_settings::AgentSettings::get_global(cx);
let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
match decision {
ToolPermissionDecision::Allow => return Task::ready(Ok(())),
ToolPermissionDecision::Deny(reason) => {
return Task::ready(Err(anyhow!("{}", reason)));
}
ToolPermissionDecision::Confirm => {}
}
let local_settings_folder = paths::local_settings_folder_name();
let path = Path::new(&input.path);
if path.components().any(|component| {
component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
}) {
let context = crate::ToolPermissionContext {
tool_name: "edit_file".to_string(),
input_value: path_str.to_string(),
};
return event_stream.authorize(
format!("{} (local settings)", input.display_description),
context,
cx,
);
}
if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
&& canonical_path.starts_with(paths::config_dir())
{
let context = crate::ToolPermissionContext {
tool_name: "edit_file".to_string(),
input_value: path_str.to_string(),
};
return event_stream.authorize(
format!("{} (global settings)", input.display_description),
context,
cx,
);
}
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
thread.project().read(cx).find_project_path(&input.path, cx)
}) else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
if project_path.is_some() {
Task::ready(Ok(()))
} else {
let context = crate::ToolPermissionContext {
tool_name: "edit_file".to_string(),
input_value: path_str.to_string(),
};
event_stream.authorize(&input.display_description, context, cx)
}
}
}
impl AgentTool for StreamingEditFileTool {
type Input = StreamingEditFileToolInput;
type Output = StreamingEditFileToolOutput;
fn name() -> &'static str {
"streaming_edit_file"
}
fn kind() -> acp::ToolKind {
acp::ToolKind::Edit
}
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
cx: &mut App,
) -> SharedString {
match input {
Ok(input) => self
.project
.read(cx)
.find_project_path(&input.path, cx)
.and_then(|project_path| {
self.project
.read(cx)
.short_full_path_for_project_path(&project_path, cx)
})
.unwrap_or(input.path.to_string_lossy().into_owned())
.into(),
Err(raw_input) => {
if let Some(input) =
serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
{
let path = input.path.trim();
if !path.is_empty() {
return self
.project
.read(cx)
.find_project_path(&input.path, cx)
.and_then(|project_path| {
self.project
.read(cx)
.short_full_path_for_project_path(&project_path, cx)
})
.unwrap_or(input.path)
.into();
}
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string().into();
}
}
DEFAULT_UI_TEXT.into()
}
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let Ok(project) = self
.thread
.read_with(cx, |thread, _cx| thread.project().clone())
else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let abs_path = project.read(cx).absolute_path(&project_path, cx);
if let Some(abs_path) = abs_path.clone() {
event_stream.update_fields(
ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
);
}
let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?;
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})
.await?;
if let Some(abs_path) = abs_path.as_ref() {
let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
self.thread.update(cx, |thread, cx| {
let last_read = thread.file_read_times.get(abs_path).copied();
let current = buffer
.read(cx)
.file()
.and_then(|file| file.disk_state().mtime());
let dirty = buffer.read(cx).is_dirty();
let has_save = thread.has_tool("save_file");
let has_restore = thread.has_tool("restore_file_from_disk");
(last_read, current, dirty, has_save, has_restore)
})?;
if is_dirty {
let message = match (has_save_tool, has_restore_tool) {
(true, true) => {
"This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
}
(true, false) => {
"This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
}
(false, true) => {
"This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
}
(false, false) => {
"This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
then ask them to save or revert the file manually and inform you when it's ok to proceed."
}
};
anyhow::bail!("{}", message);
}
if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
if current != last_read {
anyhow::bail!(
"The file {} has been modified since you last read it. \
Please read the file again to get the current state before editing it.",
input.path.display()
);
}
}
}
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
event_stream.update_diff(diff.clone());
let _finalize_diff = util::defer({
let diff = diff.downgrade();
let mut cx = cx.clone();
move || {
diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
}
});
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let old_text = cx
.background_spawn({
let old_snapshot = old_snapshot.clone();
async move { Arc::new(old_snapshot.text()) }
})
.await;
match input.mode {
StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
let content = input.content.ok_or_else(|| {
anyhow!("'content' field is required for create and overwrite modes")
})?;
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
});
}
StreamingEditFileMode::Edit => {
let edits = input.edits.ok_or_else(|| {
anyhow!("'edits' field is required for edit mode")
})?;
apply_edits(&buffer, &edits, &diff, &event_stream, &abs_path, cx)?;
}
}
let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx);
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx);
});
if let Some(abs_path) = abs_path.as_ref() {
if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
buffer.file().and_then(|file| file.disk_state().mtime())
}) {
self.thread.update(cx, |thread, _| {
thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
})?;
}
}
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
let old_text = old_text.clone();
async move {
let new_text = new_snapshot.text();
let diff = language::unified_diff(&old_text, &new_text);
(new_text, diff)
}
})
.await;
let output = StreamingEditFileToolOutput {
input_path: input.path,
new_text,
old_text,
diff: unified_diff,
};
Ok(output)
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
event_stream.update_diff(cx.new(|cx| {
Diff::finalized(
output.input_path.to_string_lossy().into_owned(),
Some(output.old_text.to_string()),
output.new_text,
self.language_registry.clone(),
cx,
)
}));
Ok(())
}
}
fn apply_edits(
buffer: &Entity<language::Buffer>,
edits: &[EditOperation],
diff: &Entity<Diff>,
event_stream: &ToolCallEventStream,
abs_path: &Option<PathBuf>,
cx: &mut AsyncApp,
) -> Result<()> {
let mut emitted_location = false;
let mut failed_edits = Vec::new();
let mut ambiguous_edits = Vec::new();
for (index, edit) in edits.iter().enumerate() {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let result = apply_single_edit(buffer, &snapshot, edit, diff, cx);
match result {
Ok(Some(range)) => {
if !emitted_location {
let line = buffer.update(cx, |buffer, _cx| {
range.start.to_point(&buffer.snapshot()).row
});
if let Some(abs_path) = abs_path.clone() {
event_stream.update_fields(
ToolCallUpdateFields::new()
.locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
);
}
emitted_location = true;
}
}
Ok(None) => {
failed_edits.push(index);
}
Err(ranges) => {
ambiguous_edits.push((index, ranges));
}
}
}
if !failed_edits.is_empty() {
let indices = failed_edits
.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(", ");
anyhow::bail!(
"Could not find matching text for edit(s) at index(es): {}. \
The old_text did not match any content in the file. \
Please read the file again to get the current content.",
indices
);
}
if !ambiguous_edits.is_empty() {
let details: Vec<String> = ambiguous_edits
.iter()
.map(|(index, ranges)| {
let lines = ranges
.iter()
.map(|r| r.start.to_string())
.collect::<Vec<_>>()
.join(", ");
format!("edit {}: matches at lines {}", index, lines)
})
.collect();
anyhow::bail!(
"Some edits matched multiple locations in the file:\n{}. \
Please provide more context in old_text to uniquely identify the location.",
details.join("\n")
);
}
Ok(())
}
fn apply_single_edit(
buffer: &Entity<language::Buffer>,
snapshot: &BufferSnapshot,
edit: &EditOperation,
diff: &Entity<Diff>,
cx: &mut AsyncApp,
) -> std::result::Result<Option<Range<Anchor>>, Vec<Range<usize>>> {
let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
matcher.push(&edit.old_text, None);
let matches = matcher.finish();
if matches.is_empty() {
return Ok(None);
}
if matches.len() > 1 {
return Err(matches);
}
let match_range = matches.into_iter().next().expect("checked len above");
let start_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(match_range.start));
let end_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_after(match_range.end));
diff.update(cx, |card, cx| {
card.reveal_range(start_anchor..end_anchor, cx)
});
buffer.update(cx, |buffer, cx| {
buffer.edit([(match_range.clone(), edit.new_text.as_str())], None, cx);
});
let new_end = buffer.read_with(cx, |buffer, _cx| {
buffer.anchor_after(match_range.start + edit.new_text.len())
});
Ok(Some(start_anchor..new_end))
}
fn resolve_path(
input: &StreamingEditFileToolInput,
project: Entity<Project>,
cx: &mut App,
) -> Result<ProjectPath> {
let project = project.read(cx);
match input.mode {
StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
let path = project
.find_project_path(&input.path, cx)
.context("Can't edit file: path not found")?;
let entry = project
.entry_for_path(&path, cx)
.context("Can't edit file: path not found")?;
anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
Ok(path)
}
StreamingEditFileMode::Create => {
if let Some(path) = project.find_project_path(&input.path, cx) {
anyhow::ensure!(
project.entry_for_path(&path, cx).is_none(),
"Can't create file: file already exists"
);
}
let parent_path = input
.path
.parent()
.context("Can't create file: incorrect path")?;
let parent_project_path = project.find_project_path(&parent_path, cx);
let parent_entry = parent_project_path
.as_ref()
.and_then(|path| project.entry_for_path(path, cx))
.context("Can't create file: parent directory doesn't exist")?;
anyhow::ensure!(
parent_entry.is_dir(),
"Can't create file: parent is not a directory"
);
let file_name = input
.path
.file_name()
.and_then(|file_name| file_name.to_str())
.and_then(|file_name| RelPath::unix(file_name).ok())
.context("Can't create file: invalid filename")?;
let new_file_path = parent_project_path.map(|parent| ProjectPath {
path: parent.path.join(file_name),
..parent
});
new_file_path.context("Can't create file")
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{ContextServerRegistry, Templates};
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
use prompt_store::ProjectContext;
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"dir": {}})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
crate::Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let result = cx
.update(|cx| {
let input = StreamingEditFileToolInput {
display_description: "Create new file".into(),
path: "root/dir/new_file.txt".into(),
mode: StreamingEditFileMode::Create,
content: Some("Hello, World!".into()),
edits: None,
};
Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
Templates::new(),
))
.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.new_text, "Hello, World!");
assert!(!output.diff.is_empty());
}
#[gpui::test]
async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"file.txt": "old content"}))
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
crate::Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let result = cx
.update(|cx| {
let input = StreamingEditFileToolInput {
display_description: "Overwrite file".into(),
path: "root/file.txt".into(),
mode: StreamingEditFileMode::Overwrite,
content: Some("new content".into()),
edits: None,
};
Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
Templates::new(),
))
.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.new_text, "new content");
assert_eq!(*output.old_text, "old content");
}
#[gpui::test]
async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"file.txt": "line 1\nline 2\nline 3\n"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
crate::Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let result = cx
.update(|cx| {
let input = StreamingEditFileToolInput {
display_description: "Edit lines".into(),
path: "root/file.txt".into(),
mode: StreamingEditFileMode::Edit,
content: None,
edits: Some(vec![EditOperation {
old_text: "line 2".into(),
new_text: "modified line 2".into(),
}]),
};
Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
Templates::new(),
))
.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
}
#[gpui::test]
async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
crate::Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let result = cx
.update(|cx| {
let input = StreamingEditFileToolInput {
display_description: "Some edit".into(),
path: "root/nonexistent_file.txt".into(),
mode: StreamingEditFileMode::Edit,
content: None,
edits: Some(vec![EditOperation {
old_text: "foo".into(),
new_text: "bar".into(),
}]),
};
Arc::new(StreamingEditFileTool::new(
project,
thread.downgrade(),
language_registry,
Templates::new(),
))
.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(
result.unwrap_err().to_string(),
"Can't edit file: path not found"
);
}
#[gpui::test]
async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"file.txt": "hello world"}))
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
crate::Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let result = cx
.update(|cx| {
let input = StreamingEditFileToolInput {
display_description: "Edit file".into(),
path: "root/file.txt".into(),
mode: StreamingEditFileMode::Edit,
content: None,
edits: Some(vec![EditOperation {
old_text: "nonexistent text that is not in the file".into(),
new_text: "replacement".into(),
}]),
};
Arc::new(StreamingEditFileTool::new(
project,
thread.downgrade(),
language_registry,
Templates::new(),
))
.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Could not find matching text")
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
});
}
}