zed/crates/edit_prediction/src/edit_prediction_tests.rs
Ben Kunkle 4742e75bc9
ep: Send settled data to cloud (#56572)
Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- N/A or Added/Fixed/Improved ...
2026-05-14 11:55:18 +00:00

3897 lines
125 KiB
Rust

use super::*;
use crate::udiff::apply_diff_to_string;
use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
use clock::FakeSystemClock;
use clock::ReplicaId;
use cloud_api_types::{
CreateLlmTokenResponse, LlmToken, Organization, OrganizationConfiguration,
OrganizationEditPredictionConfiguration, OrganizationId, SubmitEditPredictionSettledBody,
SubmitEditPredictionSettledResponse,
};
use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
};
use db::AppDatabase;
use settings::EditPredictionDataCollectionChoice;
use futures::{
AsyncReadExt, FutureExt, StreamExt,
channel::{mpsc, oneshot},
};
use gpui::App;
use gpui::{
Entity, TestAppContext,
http_client::{FakeHttpClient, Response},
};
use indoc::indoc;
use language::{
Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
};
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use std::{ops::Range, path::Path, sync::Arc, time::Duration};
use util::{
path,
test::{TextRangeMarker, marked_text_ranges_by},
};
use uuid::Uuid;
use workspace::{AppState, CollaboratorId, MultiWorkspace};
use zeta_prompt::ZetaPromptInput;
use crate::{
BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
EditPredictionJumpsFeatureFlag, EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
};
/// Checks which kind of prediction `prediction_at` reports for the active buffer.
///
/// A prediction whose edit targets the active file (`1.txt`) is reported as `Local`.
/// After that prediction is rejected, publishing a diagnostic for `2.txt` is followed by a
/// new prediction request (awaited below). Its edit targets `2.txt`, so from `1.txt` it is
/// reported as a `Jump` to `2.txt`, and from an opened `2.txt` buffer it is reported as `Local`.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    // Open 1.txt and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at row 1, column 3: just after "How".
    let position = snapshot1.anchor_before(language::Point::new(1, 3));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });
    // Prediction for current file
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();
    cx.run_until_parked();
    // The edit lands in the buffer we asked about, so it is a local prediction.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
    // Discard it so the next prediction is not shadowed by this one.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });
    // Prediction for diagnostic in another file
    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };
    // Publish the diagnostic for 2.txt; no explicit refresh is issued afterwards, yet the
    // test expects a new prediction request to arrive.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                Hola!
                -Como
                +Como estas?
                Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();
    // Seen from 1.txt, an edit in 2.txt is offered as a jump to that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });
    // Once 2.txt is open, the same prediction is local to that buffer.
    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
#[gpui::test]
async fn test_diagnostics_refresh_suppressed_while_following(cx: &mut TestAppContext) {
let (ep_store, mut requests) = init_test_with_fake_client(cx);
cx.update(|cx| {
cx.update_flags(
false,
vec![EditPredictionJumpsFeatureFlag::NAME.to_string()],
);
});
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"1.txt": "Hello!\nHow\nBye\n",
"2.txt": "Hola!\nComo\nAdios\n"
}),
)
.await;
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
let app_state = cx.update(|cx| {
let app_state = AppState::test(cx);
AppState::set_global(app_state.clone(), cx);
app_state
});
let multi_workspace =
cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
let workspace = multi_workspace
.read_with(cx, |multi_workspace, _| multi_workspace.workspace().clone())
.unwrap();
cx.update(|cx| {
AppState::set_global(workspace.read(cx).app_state().clone(), cx);
});
let _ = app_state;
let buffer1 = project
.update(cx, |project, cx| {
let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
project.set_active_path(Some(path.clone()), cx);
project.open_buffer(path, cx)
})
.await
.unwrap();
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot1.anchor_before(language::Point::new(1, 3));
ep_store.update(cx, |ep_store, cx| {
ep_store.register_project(&project, cx);
ep_store.register_buffer(&buffer1, &project, cx);
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx);
});
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(
&request,
indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ ... @@
Hello!
-How
+How are you?
Bye
"},
))
.unwrap();
cx.run_until_parked();
ep_store.update(cx, |ep_store, cx| {
ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
});
let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
multi_workspace.workspace().update(cx, |workspace, cx| {
workspace.start_following(CollaboratorId::Agent, window, cx);
});
});
cx.run_until_parked();
let diagnostic = lsp::Diagnostic {
range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
severity: Some(lsp::DiagnosticSeverity::ERROR),
message: "Sentence is incomplete".to_string(),
..Default::default()
};
project.update(cx, |project, cx| {
project.lsp_store().update(cx, |lsp_store, cx| {
lsp_store
.update_diagnostics(
LanguageServerId(0),
lsp::PublishDiagnosticsParams {
uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
diagnostics: vec![diagnostic.clone()],
version: None,
},
None,
language::DiagnosticSourceKind::Pushed,
&[],
cx,
)
.unwrap();
});
});
cx.run_until_parked();
assert_no_predict_request_ready(&mut requests.predict);
let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
multi_workspace.workspace().update(cx, |workspace, cx| {
workspace.unfollow(CollaboratorId::Agent, window, cx);
});
});
cx.run_until_parked();
project.update(cx, |project, cx| {
project.lsp_store().update(cx, |lsp_store, cx| {
lsp_store
.update_diagnostics(
LanguageServerId(0),
lsp::PublishDiagnosticsParams {
uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
diagnostics: vec![diagnostic],
version: None,
},
None,
language::DiagnosticSourceKind::Pushed,
&[],
cx,
)
.unwrap();
});
});
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(
&request,
indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
@@ ... @@
Hola!
-Como
+Como estas?
Adios
"#},
))
.unwrap();
cx.run_until_parked();
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.prediction_at(&buffer1, None, &project, cx)
.unwrap();
assert_matches!(
prediction,
BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
);
});
}
/// Requests one prediction directly through `request_prediction` and checks that the
/// model's diff becomes a single edit inserting " are you?" right after "How".
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({ "foo.md": "Hello!\nHow\nBye\n" }))
        .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let open_task = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
        project.open_buffer(project_path, cx)
    });
    let buffer = open_task.await.unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let cursor = snapshot.anchor_before(language::Point::new(1, 3));

    let pending_prediction = ep_store.update(cx, |store, cx| {
        store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // TODO: once requests are structured again, assert that this request targets
    // `root/foo.md` with the cursor at line 1, column 3.
    let model_diff = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are you?
        Bye
    "};
    respond_tx.send(model_response(&request, model_diff)).unwrap();

    let prediction = pending_prediction
        .await
        .unwrap()
        .unwrap()
        .prediction
        .unwrap();
    assert_eq!(prediction.edits.len(), 1);
    let (edit_range, edit_text) = &prediction.edits[0];
    assert_eq!(
        edit_range.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(edit_text.as_ref(), " are you?");
}
/// Checks that a local edit made before a prediction request appears as an edit-history
/// diff in the prompt sent to the model, and that the model's reply becomes one edit.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({ "foo.md": "Hello!\n\nBye\n" }))
        .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let open_task = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
        project.open_buffer(project_path, cx)
    });
    let buffer = open_task.await.unwrap();
    ep_store.update(cx, |store, cx| {
        store.register_buffer(&buffer, &project, cx);
    });

    // Fill the empty line 1 (byte offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });
    let cursor = buffer
        .read_with(cx, |buffer, _| buffer.snapshot())
        .anchor_before(language::Point::new(1, 3));

    let pending_prediction = ep_store.update(cx, |store, cx| {
        store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let prompt = prompt_from_request(&request);
    let expected_history = indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
        Hello!
        -
        +How
        Bye
    "};
    assert!(prompt.contains(expected_history), "{prompt}");

    let model_diff = indoc! {r#"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are you?
        Bye
    "#};
    respond_tx.send(model_response(&request, model_diff)).unwrap();

    let prediction = pending_prediction
        .await
        .unwrap()
        .unwrap()
        .prediction
        .unwrap();
    assert_eq!(prediction.edits.len(), 1);
    let (_, inserted_text) = &prediction.edits[0];
    assert_eq!(inserted_text.as_ref(), " are you?");
}
/// Checks that `edit_history_for_project` returns two events when two bursts of edits on
/// the same line are separated by a pause longer than `LAST_CHANGE_GROUPING_TIME`.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });
    // First burst: insert "How" on the empty line 1 (byte offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });
    // Simulate a pause: advance the fake clock by twice `LAST_CHANGE_GROUPING_TIME`
    // (the threshold is defined elsewhere in the crate).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();
    // Second burst: append " are you?" immediately after "How" on the same line
    // (offset 10 = 7 + len("How")).
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });
    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    // Offset 19 = 10 + len(" are you?").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });
    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 0: the pre-pause burst covers exactly the inserted "How".
    let first_total_edit_range = buffer.read_with(cx, |buffer, _| {
        events[0].total_edit_range.to_point(&buffer.snapshot())
    });
    assert_eq!(first_total_edit_range, Point::new(1, 0)..Point::new(1, 3));
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}
    );
    // Event 1: the post-pause burst covers " are you?!" and diffs against the text as it
    // stood after the first burst.
    let second_total_edit_range = buffer.read_with(cx, |buffer, _| {
        events[1].total_edit_range.to_point(&buffer.snapshot())
    });
    assert_eq!(second_total_edit_range, Point::new(1, 3)..Point::new(1, 13));
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -How
            +How are you?!
            Bye
        "}
    );
}
/// Checks line-proximity coalescing of edit-history events.
///
/// Despite its name, every edit here is a plain (manual) buffer edit. Edits inserted a few
/// lines above or below the current event's range are merged into that event; an edit far
/// below it starts a new event. The per-step comments cite an 8-line window, which is not
/// visible in this file.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    // Create a file with 30 lines ("Line 1" .. "Line 30") to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });
    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After first edit"
    );
    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After inserting above (should coalesce)"
    );
    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17
        "},
        "After inserting below (should coalesce)"
    );
    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event, rendered after a `---` separator.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17
            ---
            @@ -23,6 +23,7 @@
            Line 22
            Line 23
            Line 24
            +Far below
            Line 25
            Line 26
            Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
/// Concatenates the buffer diff of every stored event, placing a `\n---\n` separator
/// between consecutive diffs. An empty slice renders as an empty string.
fn render_events(events: &[StoredEvent]) -> String {
    let mut rendered = String::new();
    for (index, stored) in events.iter().enumerate() {
        if index > 0 {
            rendered.push_str("\n---\n");
        }
        let zeta_prompt::Event::BufferChange { diff, .. } = stored.event.as_ref();
        rendered.push_str(diff.as_str());
    }
    rendered
}
/// Renders each stored event as its origin label (`predicted` or `manual`), a newline, and
/// its buffer diff, producing one string per event in the original order.
fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
    let mut rendered = Vec::with_capacity(events.len());
    for stored in events {
        let zeta_prompt::Event::BufferChange {
            diff, predicted, ..
        } = stored.event.as_ref();
        let origin = if *predicted { "predicted" } else { "manual" };
        rendered.push(format!("{origin}\n{diff}"));
    }
    rendered
}
/// Creates a second replica of `buffer` (replica id 1) from its serialized proto state,
/// standing in for a remote collaborator.
///
/// Returns the replica together with `buffer`'s version captured alongside that state;
/// callers use it as the starting point when serializing the replica's later operations.
fn make_collaborator_replica(
    buffer: &Entity<Buffer>,
    cx: &mut TestAppContext,
) -> (Entity<Buffer>, clock::Global) {
    // `to_proto` needs the app context, so the closure's context parameter is used and is
    // therefore named without a leading underscore.
    let (state, version) =
        buffer.read_with(cx, |buffer, cx| (buffer.to_proto(cx), buffer.version()));
    let collaborator = cx.new(|_| {
        Buffer::from_proto(ReplicaId::new(1), Capability::ReadWrite, state, None)
            .expect("a freshly serialized buffer state should deserialize")
    });
    (collaborator, version)
}
/// Edits the collaborator replica, then forwards the operations it produced since
/// `since_version` to `buffer`, as if they had arrived over the network.
///
/// `since_version` is advanced to the collaborator's current version afterwards, so each
/// call only forwards operations the previous call had not already sent.
async fn apply_collaborator_edit(
    collaborator: &Entity<Buffer>,
    buffer: &Entity<Buffer>,
    since_version: &mut clock::Global,
    edit_range: Range<usize>,
    new_text: &str,
    cx: &mut TestAppContext,
) {
    collaborator.update(cx, |replica, cx| {
        replica.edit([(edit_range, new_text)], None, cx);
    });
    let ops = collaborator
        .read_with(cx, |replica, cx| {
            replica.serialize_ops(Some(since_version.clone()), cx)
        })
        .await;
    *since_version = collaborator.read_with(cx, |replica, _| replica.version());
    let operations = ops
        .into_iter()
        .map(|op| language::proto::deserialize_operation(op).unwrap());
    buffer.update(cx, |buffer, cx| {
        buffer.apply_ops(operations, cx);
    });
}
/// Checks that a collaborator's edit on the line right below a local edit is kept in the
/// edit history, merged into the same manual event as the local change.
#[gpui::test]
async fn test_nearby_collaborator_edits_are_kept_in_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);

    // Fifteen lines, "line 0" through "line 14".
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": (0..15).map(|i| format!("line {i}\n")).collect::<String>()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let open_task = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
        project.set_active_path(Some(project_path.clone()), cx);
        project.open_buffer(project_path, cx)
    });
    let buffer = open_task.await.unwrap();

    // Register the buffer and query a prediction with the cursor at the start of line 1;
    // the query's result is ignored.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |store, cx| {
        store.register_buffer(&buffer, &project, cx);
        let _ = store.prediction_at(&buffer, Some(cursor), &project, cx);
    });

    // Local edit: replace "line 0".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
    });

    // Remote edit: the collaborator replaces all of line 1.
    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
    let line_one = collaborator.read_with(cx, |replica, _| {
        let start = Point::new(1, 0).to_offset(replica);
        start..start + replica.line_len(1) as usize
    });
    apply_collaborator_edit(
        &collaborator,
        &buffer,
        &mut collaborator_version,
        line_one,
        "REMOTE ONE",
        cx,
    )
    .await;

    let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
    let expected = indoc! {"
        manual
        @@ -1,5 +1,5 @@
        -line 0
        -line 1
        +LOCAL ZERO
        +REMOTE ONE
        line 2
        line 3
        line 4
    "};
    assert_eq!(render_events_with_predicted(&events), vec![expected]);
}
/// Checks that a collaborator's edit far from the local edit (line 900 vs. line 0) is left
/// out of the edit history, which then holds only the local change.
#[gpui::test]
async fn test_distant_collaborator_edits_are_omitted_from_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);

    // One thousand lines, "line 0" through "line 999".
    let contents: String = (0..1000).map(|i| format!("line {i}\n")).collect();
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({ "foo.rs": contents })).await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let open_task = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
        project.set_active_path(Some(project_path.clone()), cx);
        project.open_buffer(project_path, cx)
    });
    let buffer = open_task.await.unwrap();

    // Register the buffer and query a prediction with the cursor at the start of line 1;
    // the query's result is ignored.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |store, cx| {
        store.register_buffer(&buffer, &project, cx);
        let _ = store.prediction_at(&buffer, Some(cursor), &project, cx);
    });

    // Local edit: replace "line 0".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
    });

    // Remote edit: the collaborator overwrites the first 7 bytes of line 900.
    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
    let far_start = buffer.read_with(cx, |buffer, _| Point::new(900, 0).to_offset(buffer));
    apply_collaborator_edit(
        &collaborator,
        &buffer,
        &mut collaborator_version,
        far_start..far_start + 7,
        "REMOTE FAR",
        cx,
    )
    .await;

    let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
    let expected = indoc! {"
        manual
        @@ -1,4 +1,4 @@
        -line 0
        +LOCAL ZERO
        line 1
        line 2
        line 3
    "};
    assert_eq!(render_events_with_predicted(&events), vec![expected]);
}
/// Checks that a collaborator's edit in `bar.rs`, while the cursor is in `foo.rs` and no
/// local edits have been made, leaves the project's edit history empty.
#[gpui::test]
async fn test_irrelevant_collaborator_edits_in_different_files_are_omitted_from_history(
    cx: &mut TestAppContext,
) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);

    let contents = "line 0\nline 1\nline 2\nline 3\n";
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": contents,
            "bar.rs": contents
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // foo.rs is opened as the active file; bar.rs is only opened.
    let open_foo = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
        project.set_active_path(Some(project_path.clone()), cx);
        project.open_buffer(project_path, cx)
    });
    let foo_buffer = open_foo.await.unwrap();
    let open_bar = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/bar.rs"), cx).unwrap();
        project.open_buffer(project_path, cx)
    });
    let bar_buffer = open_bar.await.unwrap();

    // Register both buffers and query a prediction in foo.rs; the result is ignored.
    let foo_cursor = foo_buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |store, cx| {
        store.register_buffer(&foo_buffer, &project, cx);
        store.register_buffer(&bar_buffer, &project, cx);
        let _ = store.prediction_at(&foo_buffer, Some(foo_cursor), &project, cx);
    });

    // Remote edit in bar.rs: replace "line 0".
    let (bar_collaborator, mut bar_version) = make_collaborator_replica(&bar_buffer, cx);
    apply_collaborator_edit(
        &bar_collaborator,
        &bar_buffer,
        &mut bar_version,
        0..6,
        "REMOTE BAR",
        cx,
    )
    .await;

    let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
    assert!(events.is_empty());
}
/// Checks that a collaborator edit one byte longer than `EDIT_HISTORY_DIFF_SIZE_LIMIT`
/// does not appear in the edit history, while the local edits before and after it are
/// still recorded as two separate events.
#[gpui::test]
async fn test_large_edits_are_omitted_from_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);

    // Twenty lines, "line 0" through "line 19".
    let contents: String = (0..20).map(|i| format!("line {i}\n")).collect();
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({ "foo.rs": contents })).await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let open_task = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
        project.set_active_path(Some(project_path.clone()), cx);
        project.open_buffer(project_path, cx)
    });
    let buffer = open_task.await.unwrap();

    // Register the buffer and query a prediction with the cursor at the start of line 1;
    // the query's result is ignored.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |store, cx| {
        store.register_buffer(&buffer, &project, cx);
        let _ = store.prediction_at(&buffer, Some(cursor), &project, cx);
    });

    // Local edit: replace "line 0".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
    });

    // Remote edit: the collaborator replaces all of line 3 with an oversized run of 'X'.
    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
    let line_three = collaborator.read_with(cx, |replica, _| {
        let start = Point::new(3, 0).to_offset(replica);
        start..start + replica.line_len(3) as usize
    });
    let large_edit = "X".repeat(EDIT_HISTORY_DIFF_SIZE_LIMIT + 1);
    apply_collaborator_edit(
        &collaborator,
        &buffer,
        &mut collaborator_version,
        line_three,
        &large_edit,
        cx,
    )
    .await;

    // Local edit: replace "line 7".
    buffer.update(cx, |buffer, cx| {
        let line_seven = Point::new(7, 0).to_offset(buffer)..Point::new(7, 6).to_offset(buffer);
        buffer.edit(vec![(line_seven, "LOCAL SEVEN")], None, cx);
    });

    let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
    let rendered = render_events_with_predicted(&events);
    assert_eq!(rendered.len(), 2);
    for (event, local_marker) in rendered.iter().zip(["+LOCAL ZERO", "+LOCAL SEVEN"]) {
        assert!(event.contains(local_marker));
        assert!(!event.contains(&large_edit));
    }
}
/// Checks how the `predicted` flag of edit-history events interacts with coalescing.
///
/// Manual edits merge with nearby manual edits. Edits reported through
/// `report_changes_for_buffer` are labelled `predicted` and merge with nearby predicted
/// edits, but not with a preceding manual event, and a following manual edit is not merged
/// into them. Once an edit lands far away, the earlier nearby events merge into one event
/// labelled `manual`.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });
    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
            line 1
            line 2
            line 3
        "}]
    );
    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
            line 2
            line 3
            line 4
        "}]
    );
    // Case 3: Accepted predictions have `predicted` set to true, and are not merged into
    // the preceding manual event.
    // The edit is reported via `report_changes_for_buffer`; presumably one of its two `true`
    // arguments marks the change as predicted — confirm against the method's signature.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                LINE ZERO
                LINE ONE
                -line 2
                +LINE TWO
                line 3
                line 4
                line 5
            "}
        ]
    );
    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                LINE ZERO
                LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                line 4
                line 5
                line 6
            "}
        ]
    );
    // Case 5: A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                LINE ZERO
                LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                line 4
                line 5
                line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                LINE ONE
                LINE TWO
                LINE THREE
                -line 4
                +LINE FOUR
                line 5
                line 6
                line 7
            "}
        ]
    );
    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted, so the merged event below is labelled `manual`.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                line 5
                line 6
                line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                line 11
                line 12
                line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
/// Checks that an empty model response produces no prediction and is reported to the
/// reject endpoint with reason `Empty`, carrying the response's request id and model version.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({ "foo.md": "Hello!\nHow\nBye\n" }))
        .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let open_task = project.update(cx, |project, cx| {
        let project_path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
        project.open_buffer(project_path, cx)
    });
    let buffer = open_task.await.unwrap();

    let cursor = buffer
        .read_with(cx, |buffer, _| buffer.snapshot())
        .anchor_before(language::Point::new(1, 3));
    ep_store.update(cx, |store, cx| {
        store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
    });

    // Answer the request with an empty diff.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let model_version = "zeta2:test-empty".to_string();
    let mut response = model_response(&request, "");
    response.model_version = Some(model_version.clone());
    let request_id = response.request_id.clone();
    respond_tx.send(response).unwrap();
    cx.run_until_parked();

    let has_prediction = ep_store.update(cx, |store, cx| {
        store.prediction_at(&buffer, None, &project, cx).is_some()
    });
    assert!(!has_prediction);

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();
    let expected_rejections = [EditPredictionRejection {
        request_id,
        reason: EditPredictionRejectReason::Empty,
        was_shown: false,
        model_version: Some(model_version),
        e2e_latency_ms: Some(0),
    }];
    assert_eq!(&reject_request.rejections, &expected_rejections);
}
// If the user types the predicted text before the response arrives, the
// interpolated prediction is empty. Nothing is shown, and the prediction is
// reported as an `InterpolatedEmpty` rejection.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let foo_path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(foo_path, cx)
        })
        .await
        .unwrap();

    // Cursor sits right after "How" on the second row.
    let cursor = buffer.read_with(cx, |buffer, _cx| {
        buffer.snapshot().anchor_before(language::Point::new(1, 3))
    });
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // While the request is in flight, the user produces the text that
    // `SIMPLE_DIFF` would have suggested.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let model_version = "zeta2:test-interpolated-empty".to_string();
    let mut response = model_response(&request, SIMPLE_DIFF);
    response.model_version = Some(model_version.clone());
    let request_id = response.request_id.clone();
    respond_tx.send(response).unwrap();
    cx.run_until_parked();

    let has_prediction = ep_store.update(cx, |ep_store, cx| {
        ep_store
            .prediction_at(&buffer, None, &project, cx)
            .is_some()
    });
    assert!(!has_prediction);

    // The now-redundant prediction is reported as an unseen rejection.
    let (reject_request, _) = requests.reject.next().await.unwrap();
    let expected_rejection = EditPredictionRejection {
        request_id,
        reason: EditPredictionRejectReason::InterpolatedEmpty,
        was_shown: false,
        model_version: Some(model_version),
        e2e_latency_ms: Some(0),
    };
    assert_eq!(&reject_request.rejections, &[expected_rejection]);
}
// Unified diff shared by several tests in this file. It rewrites the second
// line of `root/foo.md` from "How" to "How are you?". After `indoc!` strips
// the common indentation, the context lines (" Hello!", " Bye") keep their
// single leading space, as unified-diff syntax requires.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
// A second prediction for the same position replaces the current one. The
// superseded prediction is reported as a `Replaced` rejection.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let foo_path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(foo_path, cx)
        })
        .await
        .unwrap();

    // Cursor sits right after "How" on the second row.
    let cursor = buffer.read_with(cx, |buffer, _cx| {
        buffer.snapshot().anchor_before(language::Point::new(1, 3))
    });

    // Round trip #1: becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        let current = ep_store.prediction_at(&buffer, None, &project, cx).unwrap();
        assert_eq!(current.id.0, first_id);
    });

    // Round trip #2: replaces the first prediction.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        let current = ep_store.prediction_at(&buffer, None, &project, cx).unwrap();
        assert_eq!(current.id.0, second_id);
    });

    // The superseded first prediction is reported as replaced.
    let (reject_request, _) = requests.reject.next().await.unwrap();
    let expected_rejection = EditPredictionRejection {
        request_id: first_id,
        reason: EditPredictionRejectReason::Replaced,
        was_shown: false,
        model_version: None,
        e2e_latency_ms: Some(0),
    };
    assert_eq!(&reject_request.rejections, &[expected_rejection]);
}
// When a newer prediction is worse than the one currently shown, the current
// prediction is kept. The newer one is reported as a `CurrentPreferred`
// rejection.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" on the second row.
    let position = snapshot.anchor_before(language::Point::new(1, 3));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    // First response: the full `SIMPLE_DIFF` edit, which becomes current.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });
    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    // (it inserts only "How are" instead of "How are you?")
    let mut second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    second_response.model_version = Some("zeta2:test-current-preferred".to_string());
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });
    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();
    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: Some("zeta2:test-current-preferred".to_string()),
            e2e_latency_ms: Some(0),
        }]
    );
}
// When a later request resolves before an earlier pending one, the earlier
// request is cancelled. Its late response does not change the current
// prediction, and it is reported as a `Canceled` rejection with no
// end-to-end latency.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let foo_path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(foo_path, cx)
        })
        .await
        .unwrap();

    // Cursor sits right after "How" on the second row.
    let cursor = buffer.read_with(cx, |buffer, _cx| {
        buffer.snapshot().anchor_before(language::Point::new(1, 3))
    });

    // Issue two refreshes back to back so both requests are pending.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
    });
    let (first_request, respond_first) = requests.predict.next().await.unwrap();
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
    });
    let (second_request, respond_second) = requests.predict.next().await.unwrap();

    // Let the throttle elapse.
    cx.run_until_parked();

    // The second request answers first and becomes current.
    let second_response = model_response(&second_request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        let current = ep_store.prediction_at(&buffer, None, &project, cx).unwrap();
        assert_eq!(current.id.0, second_id);
    });

    // The first request answers late. It was cancelled, so the second stays current.
    let mut first_response = model_response(&first_request, SIMPLE_DIFF);
    first_response.model_version = Some("zeta2:test-canceled".to_string());
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        let current = ep_store.prediction_at(&buffer, None, &project, cx).unwrap();
        assert_eq!(current.id.0, second_id);
    });

    // The cancelled request is reported without any end-to-end latency.
    let (reject_request, _) = requests.reject.next().await.unwrap();
    cx.run_until_parked();
    let expected_rejection = EditPredictionRejection {
        request_id: first_id,
        reason: EditPredictionRejectReason::Canceled,
        was_shown: false,
        model_version: Some("zeta2:test-canceled".to_string()),
        e2e_latency_ms: None,
    };
    assert_eq!(&reject_request.rejections, &[expected_rejection]);
}
// With two requests already pending, a third refresh marks the second one
// (index 1) as cancelled. The second's late response does not change the
// current prediction, the third replaces the first, and both the cancelled
// second and the replaced first are reported in a single rejection batch.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" on the second row.
    let position = snapshot.anchor_before(language::Point::new(1, 3));
    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (request1, respond_first) = requests.predict.next().await.unwrap();
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (request2, respond_second) = requests.predict.next().await.unwrap();
    // wait for throttle, so requests are sent
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });
    // wait for throttle
    cx.run_until_parked();
    let (request3, respond_third) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });
    // The cancelled second request now answers; its response must be dropped.
    let mut cancelled_response = model_response(&request2, SIMPLE_DIFF);
    cancelled_response.model_version = Some("zeta2:test-canceled-second".to_string());
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });
    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();
    cx.run_until_parked();
    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });
    // second is reported as rejected
    // (and the replaced first prediction arrives in the same batch)
    let (reject_request, _) = requests.reject.next().await.unwrap();
    cx.run_until_parked();
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: Some("zeta2:test-canceled-second".to_string()),
                e2e_latency_ms: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
                // 2 throttle waits (for 2nd and 3rd requests) elapsed
                // between this request's start and response.
                e2e_latency_ms: Some(2 * EditPredictionStore::THROTTLE_TIMEOUT.as_millis()),
            }
        ]
    );
}
// Buffer-triggered (edit) refreshes and diagnostics-triggered (jump) refreshes
// are throttled independently. The first of each is sent immediately, a
// second of each is held back, and both go out once `THROTTLE_TIMEOUT` has
// elapsed.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            // foo.md is made the active path before it is opened.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });
    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();
    // Error diagnostic pushed for bar.md (not the buffer opened above).
    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };
    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();
    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);
    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);
    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();
    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();
    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
// Two refreshes enqueued in one synchronous frame must collapse into a single
// prediction request.
#[gpui::test]
async fn test_same_frame_duplicate_requests_deduplicated(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let foo_path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(foo_path, cx)
        })
        .await
        .unwrap();

    // Cursor sits right after "How" on the second row.
    let cursor = buffer.read_with(cx, |buffer, _cx| {
        buffer.snapshot().anchor_before(language::Point::new(1, 3))
    });

    // Both refreshes happen inside one `update`, without yielding. Each one
    // spawns a task before either task runs, so both capture the same
    // `proceed_count_at_enqueue`. Only the first should pass the
    // deduplication gate.
    ep_store.update(cx, |ep_store, cx| {
        for _ in 0..2 {
            ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
        }
    });

    // Drive both spawned tasks to completion, throttle waits included.
    cx.run_until_parked();

    // Exactly one request goes out.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, SIMPLE_DIFF);
    respond_tx.send(response).unwrap();
    cx.run_until_parked();

    assert_no_predict_request_ready(&mut requests.predict);
}
// Covers how rejection reports are flushed to the server:
// - rejections recorded within the debounce window are batched;
// - a burst of rejections sends the first 50 immediately and debounces the
//   remainder into the next batch;
// - a failed report is retried together with the next batch.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            None,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            None,
            None,
            cx,
        );
    });
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();
    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false,
            model_version: None,
            e2e_latency_ms: None
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true,
            model_version: None,
            e2e_latency_ms: None
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{i}").into()),
                EditPredictionRejectReason::Discarded,
                false,
                None,
                None,
                cx,
            );
        }
    });
    // The first 50 items are sent immediately.
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();
    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();
    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            None,
            cx,
        );
    });
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();
    let (reject_request, failed_respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate a failed request: dropping the responder without replying makes
    // the fake HTTP handler return an error.
    drop(failed_respond_tx);
    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            None,
            cx,
        );
    });
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();
    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
// Exercises `zeta::active_buffer_diagnostics` on two buffers. It covers
// filtering by the search range, ordering by severity, and the snippet and
// row-range metadata attached to each returned diagnostic.
#[gpui::test]
fn test_active_buffer_diagnostics_fetching(cx: &mut TestAppContext) {
    // `«…»` marks diagnostic ranges; `[…]` marks the search range used by the
    // first query.
    let diagnostic_marker: TextRangeMarker = ('«', '»').into();
    let search_range_marker: TextRangeMarker = ('[', ']').into();
    let (text, mut ranges) = marked_text_ranges_by(
        indoc! {r#"
            fn alpha() {
                let «first_value» = 1;
            }

            [fn beta() {
                let «second_value» = 2;
                let third_value = second_value + missing_symbol;
            }ˇ]

            fn gamma() {
                let «fourth_value» = missing_other_symbol;
            }
        "#},
        vec![diagnostic_marker.clone(), search_range_marker.clone()],
    );
    let diagnostic_ranges = ranges.remove(&diagnostic_marker).unwrap_or_default();
    let search_ranges = ranges.remove(&search_range_marker).unwrap_or_default();
    let buffer = cx.new(|cx| Buffer::local(&text, cx));
    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        // Severity and message follow marker order: the 1st marked range is a
        // warning, the 2nd an error, and any later ones are hints.
        let diagnostics = DiagnosticSet::new(
            diagnostic_ranges
                .iter()
                .enumerate()
                .map(|(index, range)| DiagnosticEntry {
                    range: snapshot.offset_to_point_utf16(range.start)
                        ..snapshot.offset_to_point_utf16(range.end),
                    diagnostic: Diagnostic {
                        severity: match index {
                            0 => DiagnosticSeverity::WARNING,
                            1 => DiagnosticSeverity::ERROR,
                            _ => DiagnosticSeverity::HINT,
                        },
                        message: match index {
                            0 => "first warning".to_string(),
                            1 => "second error".to_string(),
                            _ => "third hint".to_string(),
                        },
                        group_id: index + 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                }),
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let search_range = snapshot.offset_to_point(search_ranges[0].start)
        ..snapshot.offset_to_point(search_ranges[0].end);
    // Only the error inside `fn beta` lies within the search range. Its snippet
    // is that row's text, and 8..20 locates `second_value` within it.
    let active_buffer_diagnostics = zeta::active_buffer_diagnostics(&snapshot, search_range, 5, 0);
    assert_eq!(
        active_buffer_diagnostics,
        vec![zeta_prompt::ActiveBufferDiagnostic {
            severity: Some(1),
            message: "second error".to_string(),
            snippet: "    let second_value = 2;".to_string(),
            snippet_buffer_row_range: 5..5,
            diagnostic_range_in_snippet: 8..20,
        }]
    );
    // Over the whole buffer, all three come back ordered error (1), warning (2),
    // hint (4). With these arguments every snippet is empty.
    let active_buffer_diagnostics =
        zeta::active_buffer_diagnostics(&snapshot, Point::new(0, 0)..snapshot.max_point(), 5, 100);
    assert_eq!(
        active_buffer_diagnostics,
        vec![
            zeta_prompt::ActiveBufferDiagnostic {
                severity: Some(1),
                message: "second error".to_string(),
                snippet: String::new(),
                snippet_buffer_row_range: 5..5,
                diagnostic_range_in_snippet: 0..0,
            },
            zeta_prompt::ActiveBufferDiagnostic {
                severity: Some(2),
                message: "first warning".to_string(),
                snippet: String::new(),
                snippet_buffer_row_range: 1..1,
                diagnostic_range_in_snippet: 0..0,
            },
            zeta_prompt::ActiveBufferDiagnostic {
                severity: Some(4),
                message: "third hint".to_string(),
                snippet: String::new(),
                snippet_buffer_row_range: 10..10,
                diagnostic_range_in_snippet: 0..0,
            },
        ]
    );
    // Second buffer: one diagnostic each on rows 0, 2 and 4.
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                one
                two
                three
                four
                five
            "},
            cx,
        )
    });
    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let diagnostics = DiagnosticSet::new(
            vec![
                DiagnosticEntry {
                    range: text::PointUtf16::new(0, 0)..text::PointUtf16::new(0, 3),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::ERROR,
                        message: "row zero".to_string(),
                        group_id: 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(2, 0)..text::PointUtf16::new(2, 5),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::WARNING,
                        message: "row two".to_string(),
                        group_id: 2,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(4, 0)..text::PointUtf16::new(4, 4),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::INFORMATION,
                        message: "row four".to_string(),
                        group_id: 3,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
            ],
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Searching rows 2..4 excludes the row-0 error and includes the diagnostic
    // that starts at (4, 0). Each snippet is the text of its row.
    let active_buffer_diagnostics =
        zeta::active_buffer_diagnostics(&snapshot, Point::new(2, 0)..Point::new(4, 0), 3, 0);
    assert_eq!(
        active_buffer_diagnostics
            .iter()
            .map(|diagnostic| (
                diagnostic.severity,
                diagnostic.message.clone(),
                diagnostic.snippet.clone(),
                diagnostic.snippet_buffer_row_range.clone(),
                diagnostic.diagnostic_range_in_snippet.clone(),
            ))
            .collect::<Vec<_>>(),
        vec![
            (
                Some(2),
                "row two".to_string(),
                "three".to_string(),
                2..2,
                0..5,
            ),
            (
                Some(3),
                "row four".to_string(),
                "five".to_string(),
                4..4,
                0..4,
            ),
        ]
    );
}
// Checks two bounds on what `zeta::active_buffer_diagnostics` collects:
// - how many diagnostics come back from a densely annotated buffer;
// - how long a single snippet can be in a long buffer.
#[gpui::test]
fn test_active_buffer_diagnostics_collection_limits(cx: &mut TestAppContext) {
    // Dense case: an error on each of 25 rows. Exactly 20 are kept. Row 12 is
    // among them; the first and last rows are not.
    let dense_text: String = (0..25).map(|row| format!("line {row}\n")).collect();
    let dense_buffer = cx.new(|cx| Buffer::local(&dense_text, cx));
    dense_buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let entries: Vec<_> = (0..25)
            .map(|row| DiagnosticEntry {
                range: text::PointUtf16::new(row, 0)..text::PointUtf16::new(row, 4),
                diagnostic: Diagnostic {
                    severity: DiagnosticSeverity::ERROR,
                    message: format!("row {row}"),
                    group_id: row as usize,
                    is_primary: true,
                    source_kind: language::DiagnosticSourceKind::Pushed,
                    ..Diagnostic::default()
                },
            })
            .collect();
        let diagnostics = DiagnosticSet::new(entries, &snapshot);
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });
    let snapshot = dense_buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let collected =
        zeta::active_buffer_diagnostics(&snapshot, Point::new(0, 0)..Point::new(25, 0), 12, 0);
    assert_eq!(collected.len(), 20);
    assert!(collected.iter().any(|entry| entry.message == "row 12"));
    assert!(
        !collected
            .iter()
            .any(|entry| entry.message == "row 0" || entry.message == "row 24")
    );

    // Long case: a single error in the middle of a 300-row buffer. Its snippet
    // stays within a fixed bound and is shorter than the whole buffer.
    let long_text: String = (0..300)
        .map(|row| format!("line {row} has some diagnostic context\n"))
        .collect();
    let long_buffer = cx.new(|cx| Buffer::local(&long_text, cx));
    long_buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let entry = DiagnosticEntry {
            range: text::PointUtf16::new(150, 0)..text::PointUtf16::new(150, 4),
            diagnostic: Diagnostic {
                severity: DiagnosticSeverity::ERROR,
                message: "long snippet".to_string(),
                group_id: 1,
                is_primary: true,
                source_kind: language::DiagnosticSourceKind::Pushed,
                ..Diagnostic::default()
            },
        };
        let diagnostics = DiagnosticSet::new(vec![entry], &snapshot);
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });
    let snapshot = long_buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let collected = zeta::active_buffer_diagnostics(
        &snapshot,
        Point::new(100, 0)..Point::new(200, 0),
        150,
        2000,
    );
    assert_eq!(collected.len(), 1);
    let snippet_len = collected[0].snippet.len();
    assert!(snippet_len <= 512 * 3 + 2);
    assert!(snippet_len < long_text.len());
}
/// Builds a fake model response that applies `diff_to_apply` to the active
/// file's editable excerpt.
///
/// The editable range is taken from `request.input.excerpt_ranges` using the
/// default format. The response carries a fresh random request id, no cursor
/// offset and no model version.
///
/// # Panics
///
/// Panics if the editable range does not slice `cursor_excerpt`, or if the
/// diff does not apply to that excerpt. Either indicates a broken test fixture.
fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
    let editable_range =
        zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
    // `get` instead of indexing so an out-of-range fixture yields a descriptive panic.
    let excerpt = request
        .input
        .cursor_excerpt
        .get(editable_range.clone())
        .expect("editable range should lie within the request's cursor excerpt")
        .to_string();
    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt)
        .expect("test diff should apply cleanly to the editable excerpt");
    PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        editable_range,
        output: new_excerpt,
        cursor_offset: None,
        model_version: None,
    }
}
/// Builds a fake model response with a fresh random request id, an empty
/// editable range (`0..0`), empty output, and no cursor offset or model version.
fn empty_response() -> PredictEditsV3Response {
    let request_id = Uuid::new_v4().to_string();
    PredictEditsV3Response {
        model_version: None,
        cursor_offset: None,
        output: String::new(),
        editable_range: 0..0,
        request_id,
    }
}
/// Renders `request.input` with the default zeta prompt format.
///
/// # Panics
///
/// Panics if formatting fails.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    let format = zeta_prompt::ZetaFormat::default();
    let rendered = zeta_prompt::format_zeta_prompt(&request.input, format);
    rendered.expect("default zeta prompt formatting should succeed in edit prediction tests")
}
/// Asserts that no prediction request is immediately available.
///
/// The channel is polled once without waiting, and a closed channel counts as
/// "nothing ready". If a request is pending, it is taken off the channel and
/// the assertion fails.
fn assert_no_predict_request_ready(
    requests: &mut mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
) {
    let ready = requests.next().now_or_never().flatten();
    assert!(
        ready.is_none(),
        "Unexpected prediction request while throttled."
    );
}
/// Receiving ends for the requests captured by the fake HTTP client, one
/// channel per endpoint the tests observe.
struct RequestChannels {
    /// `/predict_edits/v3` requests. Each comes with a sender the test uses to
    /// supply the model response.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// `/predict_edits/reject` bodies. Each comes with a sender that completes
    /// the HTTP request when signalled; dropping it makes the request fail.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    /// `/predict_edits/settled` bodies. The fake client answers these
    /// immediately, so no responder is attached.
    settled: mpsc::UnboundedReceiver<SubmitEditPredictionSettledBody>,
}
/// Creates a test `EditPredictionStore` backed by a fake HTTP client, without
/// writing any legacy data-collection choice.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    let no_legacy_choice: Option<&str> = None;
    init_test_with_fake_client_and_legacy_data_collection(cx, no_legacy_choice)
}
/// Installs test globals and creates the global `EditPredictionStore`, wired
/// to a fake HTTP client.
///
/// The fake client behaves per endpoint as follows:
/// - `/client/llm_tokens`: answers with a static token.
/// - `/predict_edits/v3`: decodes the zstd-compressed JSON body and forwards
///   it on `RequestChannels::predict`, then waits for the test to reply.
/// - `/predict_edits/reject`: decodes the JSON body and forwards it on
///   `RequestChannels::reject`, then waits for the test to reply.
/// - `/predict_edits/settled`: forwards the body on
///   `RequestChannels::settled` and replies immediately.
///
/// Any other path panics.
///
/// When `legacy_data_collection_choice` is `Some`, it is written to the global
/// key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE` before the store
/// is created.
fn init_test_with_fake_client_and_legacy_data_collection(
    cx: &mut TestAppContext,
    legacy_data_collection_choice: Option<&str>,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        cx.set_global(AppDatabase::test_new());
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        if let Some(legacy_data_collection_choice) = legacy_data_collection_choice {
            KeyValueStore::global(cx)
                .write_kvp(
                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
                    legacy_data_collection_choice.to_string(),
                )
                .now_or_never()
                .expect("legacy data collection write should complete immediately")
                .expect("legacy data collection write should succeed");
        }

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
        let (settled_req_tx, settled_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                let settled_req_tx = settled_req_tx.clone();
                async move {
                    // Body reads fail loudly here. Ignoring the error would
                    // surface later as a misleading decode/parse panic on a
                    // truncated buffer.
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf)
                                .await
                                .expect("failed to read /predict_edits/v3 request body");
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // A dropped responder turns into an HTTP error here.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf)
                                .await
                                .expect("failed to read /predict_edits/reject request body");
                            let req = serde_json::from_slice(&buf).unwrap();
                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // A dropped responder turns into an HTTP error here.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/settled" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf)
                                .await
                                .expect("failed to read /predict_edits/settled request body");
                            let req = serde_json::from_slice(&buf).unwrap();
                            settled_req_tx.unbounded_send(req).unwrap();
                            serde_json::to_string(&SubmitEditPredictionSettledResponse {}).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };
                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(cx);
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
                settled: settled_req_rx,
            },
        )
    })
}
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
});
let edit_preview = cx
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
let prediction = EditPrediction {
edits,
cursor_position: None,
editable_range: None,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
inputs: ZetaPromptInput {
events: Default::default(),
related_files: Default::default(),
active_buffer_diagnostics: vec![],
cursor_path: Path::new("").into(),
cursor_excerpt: "".into(),
cursor_offset_in_excerpt: 0,
excerpt_start_row: None,
excerpt_ranges: Default::default(),
syntax_ranges: None,
in_open_source_repo: false,
can_collect_data: false,
repo_url: None,
},
model_version: None,
};
cx.update(|cx| {
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(2..2, "REM".into()), (6..8, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(3..3, "EM".into()), (7..9, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
vec![(4..4, "M".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Each case is (buffer contents, model output, expected buffer after applying).
    let cases = [
        // The model renames `word` to `word_1` in two places on the same line.
        (
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
        ),
        // The model completes an unterminated string statement.
        (
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
        ),
    ];

    for (buffer_content, completion_response, expected) in cases {
        let applied = apply_edit_prediction(buffer_content, completion_response, cx).await;
        assert_eq!(applied, expected);
    }
}
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);
    // The model output appends a whole line after the buffer's final newline; the
    // applied result is expected to contain that line at the very end.
    let applied = apply_edit_prediction("lorem\n", "lorem\nipsum\n", cx).await;
    assert_eq!(applied, "lorem\nipsum\n");
}
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // zeta2 normalizes trailing newlines on both the buffer side and the model output
    // before diffing. So a buffer without a trailing newline, paired with model output
    // that has one, must not gain a spurious newline.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    // Single-line file with no trailing newline.
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({ "foo.txt": "hello" })).await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let project_path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(project_path, cx)
        })
        .await
        .unwrap();

    // Cursor right after "hello".
    let cursor = buffer.read_with(cx, |buffer, _cx| {
        buffer.snapshot().anchor_before(language::Point::new(0, 5))
    });
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), cursor, cx);
    });

    // The model output ends WITH a newline even though the buffer does not.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let editable_end = request.input.cursor_excerpt.len();
    respond_tx
        .send(PredictEditsV3Response {
            request_id: Uuid::new_v4().to_string(),
            output: "hello world\n".to_string(),
            editable_range: 0..editable_end,
            model_version: None,
            cursor_offset: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Only " world" is inserted; no newline is appended.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let snapshot = buffer.read(cx).snapshot();
        let offset_edits = prediction
            .edits
            .iter()
            .map(|(range, text)| (range.to_offset(&snapshot), text.clone()))
            .collect::<Vec<_>>();
        assert_eq!(offset_edits, vec![(5..5, " world".into())]);
    });
}
#[gpui::test]
async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({ "foo.txt": "hello" })).await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let foo_path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(foo_path, cx)
        })
        .await
        .unwrap();

    // Request a prediction with the cursor at the end of "hello".
    let end_of_hello = buffer.read_with(cx, |buffer, _cx| {
        buffer.snapshot().anchor_before(language::Point::new(0, 5))
    });
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), end_of_hello, cx);
    });

    // The response reports the cursor through `cursor_offset`, separately from
    // `output`; the resulting edit text must be plain " world" with no cursor marker.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let editable_end = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world".to_string(),
        editable_range: 0..editable_end,
        model_version: None,
        cursor_offset: Some(5),
    };
    respond_tx.send(response).unwrap();
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let snapshot = buffer.read(cx).snapshot();
        let mut offset_edits = Vec::new();
        for (range, text) in prediction.edits.iter() {
            offset_edits.push((range.to_offset(&snapshot), text.clone()));
        }
        assert_eq!(offset_edits, vec![(5..5, " world".into())]);
    });
}
/// Installs a test database and a test settings store as app globals.
fn init_test(cx: &mut TestAppContext) {
    cx.update(|app| {
        app.set_global(AppDatabase::test_new());
        let store = SettingsStore::test(app);
        app.set_global(store);
    });
}
/// Runs one Zeta prediction against a fresh local buffer holding `buffer_content`, with
/// the fake server answering `completion_response`. Applies every predicted edit and
/// returns the resulting buffer text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));

    let (ep_store, canned_response) = make_test_ep_store(&project, cx).await;
    *canned_response.lock() = completion_response.to_string();

    let prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
/// Registers `buffer` with `ep_store`, requests a prediction with the cursor at the start
/// of row 1, and returns the resulting [`EditPrediction`].
///
/// # Panics
///
/// Panics if the prediction request errors or does not produce a prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    // `project` is already a reference; passing it directly avoids a needless
    // `&&Entity<Project>` borrow.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, project, cx)
    });
    // Drain any background work queued so far before issuing the request.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
/// Builds an [`EditPredictionStore`] using the Zeta model, backed by a fake HTTP client.
///
/// The fake client behaves as follows:
/// - `POST /client/llm_tokens` returns a fixed token.
/// - `POST /predict_edits/v3` returns the string stored in the returned
///   `Arc<Mutex<String>>` (initially `"hello world\n"`). The editable range is the whole
///   cursor excerpt, and the ids are `request-1`, `request-2`, ….
/// - Every other request gets a 404.
///
/// A license detection watcher is also registered for each worktree of `project`.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        // Shared by every request the fake client serves. A plain integer would be copied
        // into each `async move` request future, so incrementing it there would never be
        // observed by later requests and every response would reuse the same id.
        let next_request_id = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        move |req| {
            let completion_response = completion_response.clone();
            let next_request_id = next_request_id.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (Method::POST, "/predict_edits/v3") => {
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();
                        // Ids start at 1.
                        let request_id = next_request_id
                            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
                            + 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                    cursor_offset: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });
    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;
    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }
        ep_store
    });
    (ep_store, completion_response)
}
/// Converts offset-based edits into anchor-based edits on `buffer`.
///
/// Each range start is anchored with `anchor_after` and each end with `anchor_before`,
/// presumably so that text typed exactly at either boundary stays outside the range.
fn to_completion_edits(
    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
    buffer: &Entity<Buffer>,
    cx: &App,
) -> Vec<(Range<Anchor>, Arc<str>)> {
    let buf = buffer.read(cx);
    let mut anchored = Vec::new();
    for (offsets, text) in iterator {
        let start = buf.anchor_after(offsets.start);
        let end = buf.anchor_before(offsets.end);
        anchored.push((start..end, text));
    }
    anchored
}
/// Resolves anchor-based edits back into offsets in `buffer`'s current contents.
fn from_completion_edits(
    editor_edits: &[(Range<Anchor>, Arc<str>)],
    buffer: &Entity<Buffer>,
    cx: &App,
) -> Vec<(Range<usize>, Arc<str>)> {
    let buffer = buffer.read(cx);
    let mut resolved = Vec::with_capacity(editor_edits.len());
    for (anchors, text) in editor_edits {
        let start = anchors.start.to_offset(buffer);
        let end = anchors.end.to_offset(buffer);
        resolved.push((start..end, Arc::clone(text)));
    }
    resolved
}
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n \n}\n"
        }),
    )
    .await;
    let project = Project::test(fs, [path!("/project").as_ref()], cx).await;

    // Every HTTP request is rejected as unauthorized, and no custom prediction URL is
    // configured anywhere in this test.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });
    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });
    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let main_rs = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(main_rs, cx)
        })
        .await
        .unwrap();
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| ep_store.register_buffer(&buffer, &project, cx));
    cx.background_executor.run_until_parked();

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });
    let outcome = prediction_task.await;
    assert!(
        outcome.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
/// `EditPredictionStore::next_diagnostic_location` should avoid diagnostics near a
/// collaborator's cursor unless the local cursor is near them too, and should skip
/// another file in which a collaborator has a cursor.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    /// Simulates a remote collaborator (replica 10) placing a single caret at the start
    /// of `row` in `buffer`, by applying an `UpdateSelections` operation.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        // An empty selection (start == end) is a plain caret.
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    /// Publishes one ERROR diagnostic covering columns 0..5 of each row in `rows` for the
    /// file at `uri_path`, attributed to language server 0 as a pushed diagnostic.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // Active file: 60 lines, "line 0" through "line 59".
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator caret at row 5; errors at rows 3, 25 and 50 of the active file.
    set_collaborator_cursor(&active_buffer, 5, cx);
    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);
    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator. Row 3 is near only the
    // collaborator, so it must not be chosen.
    let cursor_point = Point::new(25, 0);
    // Default (zero-length) search range. NOTE(review): presumably means "nothing
    // searched yet" — confirm against next_diagnostic_location.
    let empty_search_range: Range<Point> = Default::default();
    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");
    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, near the collaborator as well, so row 3 is eligible.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");
    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50, which is itself a diagnostic row far from the
    // collaborator.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");
    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4 (cross-file): both other files get an error at row 0, but only
    // collab_file.txt has a collaborator caret.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();
    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");
    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();
    // Search range covering rows 0..59 of the active file. NOTE(review): presumably marks
    // the whole active file as already searched, forcing a cross-file result — confirm.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");
    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
/// Exercises the settled-prediction worker. A prediction settles only after its editable
/// region has gone `EDIT_PREDICTION_SETTLED_QUIESCENCE` (Q below) without edits. Edits in
/// one region do not delay a prediction whose region is elsewhere.
#[gpui::test]
async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    // 30 lines of "line NN\n" (8 bytes each), with two clearly separated regions:
    //   Region A = lines 0-9   (offsets 0..80 before any edits)
    //   Region B = lines 20-24 (starting at offset 160 before any edits)
    // The gap between them means edits in one region never overlap the other.
    let mut content = String::new();
    for i in 0..30 {
        content.push_str(&format!("line {i:02}\n"));
    }
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    // Record every settled event so the test can observe when each prediction settles.
    type SettledEventRecord = (EditPredictionId, String);
    let settled_events: Arc<Mutex<Vec<SettledEventRecord>>> = Arc::new(Mutex::new(Vec::new()));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
        let settled_events = settled_events.clone();
        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
            settled_events.lock().push((id, text));
        }));
    });

    // --- Phase 1: edit in region A and enqueue prediction A ---
    buffer.update(cx, |buffer, cx| {
        // Edit at the start of line 0.
        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
    });
    cx.run_until_parked();
    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let empty_edits: Arc<[(Range<Anchor>, Arc<str>)]> = Vec::new().into();
    let edit_preview_a = buffer
        .read_with(cx, |buffer, cx| {
            buffer.preview_edits(empty_edits.clone(), cx)
        })
        .await;
    // Region A: first 10 lines of the buffer.
    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-a".into()),
            &project,
            &buffer,
            &snapshot_a,
            editable_region_a.clone(),
            &edit_preview_a,
            None,
            None,
            Duration::from_secs(0),
            cx,
        );
    });

    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---
    // Let the worker process the channel message before we start advancing.
    cx.run_until_parked();
    // Each iteration edits region A and then advances Q/2. With T = 0 at the first edit,
    // A is edited at T = 0, Q/2 and Q, and the loop ends at T = 3Q/2.
    for region_a_edit_offset in (5..).take(3) {
        // Edit inside region A (not at the boundary) so `last_edit_at` is
        // updated before the worker's next wake.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
                None,
                cx,
            );
        });
        cx.run_until_parked();
        cx.executor()
            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
        cx.run_until_parked();
        assert!(
            settled_events.lock().is_empty(),
            "no settled events should fire while region A is still being edited"
        );
    }
    // Still nothing settled.
    assert!(settled_events.lock().is_empty());

    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
    // Advance Q/4 (to T = 7Q/4) so B's quiescence window starts later than A's. A has
    // only been quiet for 3Q/4 at this point, so it must not have settled yet.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();
    assert!(settled_events.lock().is_empty());
    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
    });
    cx.run_until_parked();
    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let edit_preview_b = buffer
        .read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx))
        .await;
    // Region B: lines 20-24.
    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-b".into()),
            &project,
            &buffer,
            &snapshot_b2,
            editable_region_b.clone(),
            &edit_preview_b,
            None,
            None,
            Duration::from_secs(0),
            cx,
        );
    });
    cx.run_until_parked();
    assert!(
        settled_events.lock().is_empty(),
        "neither prediction should have settled yet"
    );

    // --- Phase 4: let enough time pass for region A to settle ---
    // Another Q/4 reaches T = 2Q, exactly Q after A's last edit (T = Q), so A settles.
    // B was edited at T = 7Q/4, only Q/4 ago, so it stays pending.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();
    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            1,
            "only prediction A should have settled, got: {events:?}"
        );
        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
    }

    // --- Phase 5: let more time pass for region B to settle ---
    // B settles Q after its edit at T = 7Q/4, i.e. at T = 11Q/4, which is 3Q/4 from now.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
    cx.run_until_parked();
    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            2,
            "prediction B should have settled after prediction A, got: {events:?}"
        );
        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
    }
}
/// When data collection is not enabled, the settled request sent to the server must not
/// carry buffer-derived content. It should have no settled editable region and no example,
/// even though an `ExampleSpec` full of (fake) sensitive data was enqueued.
#[gpui::test]
async fn test_edit_prediction_settled_omits_body_when_data_collection_is_disabled(
    cx: &mut TestAppContext,
) {
    // No data-collection setting is changed here, so this relies on collection being off
    // by default for the fake client (see `test_data_collection_disabled_by_default`).
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "sensitive source\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Predict replacing "sensitive" (offsets 0..9) with "replacement".
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
        cx.update(|cx| to_completion_edits([(0..9, "replacement".into())], &buffer, cx).into());
    let edit_preview = buffer
        .read_with(cx, |buffer, cx| buffer.preview_edits(edits, cx))
        .await;
    // Enqueue with the whole buffer as the editable region and an example whose fields
    // would leak content if they were forwarded.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-private".into()),
            &project,
            &buffer,
            &snapshot,
            0..snapshot.len(),
            &edit_preview,
            Some(ExampleSpec {
                name: "test example".to_string(),
                repository_url: "https://example.com/repo".to_string(),
                revision: "rev".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: String::new(),
                cursor_path: Path::new("foo.md").into(),
                cursor_position: "0".to_string(),
                edit_history: "sensitive edit history".to_string(),
                expected_patches: vec!["sensitive patch".to_string()],
                rejected_patch: None,
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }),
            Some("test-model".to_string()),
            Duration::from_millis(42),
            cx,
        );
    });
    // Wait one full quiescence period with no further edits so the prediction settles
    // and the settled request goes out.
    cx.run_until_parked();
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE);
    cx.run_until_parked();
    let settled_request = requests
        .settled
        .next()
        .await
        .expect("settled request should be sent");
    assert!(!settled_request.can_collect_data);
    assert_eq!(settled_request.settled_editable_region, None);
    assert_eq!(settled_request.example, None);
}
#[gpui::test]
fn test_buffer_path_with_id_fallback_for_untitled_buffers(cx: &mut TestAppContext) {
    // Two file-less buffers must get distinct fallback paths built from their remote ids.
    let first = cx.new(|cx| Buffer::local("one", cx));
    let second = cx.new(|cx| Buffer::local("two", cx));
    let first_snapshot = first.read_with(cx, |buffer, _| buffer.text_snapshot());
    let second_snapshot = second.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (first_path, second_path) = cx.read(|cx| {
        (
            buffer_path_with_id_fallback(None, &first_snapshot, cx),
            buffer_path_with_id_fallback(None, &second_snapshot, cx),
        )
    });

    let expected_first = format!("untitled-{}", first_snapshot.remote_id());
    let expected_second = format!("untitled-{}", second_snapshot.remote_id());
    assert_eq!(first_path.as_ref(), Path::new(&expected_first));
    assert_eq!(second_path.as_ref(), Path::new(&expected_second));
    assert_ne!(first_path.as_ref(), second_path.as_ref());
}
#[gpui::test]
async fn test_data_collection_disabled_by_default(cx: &mut TestAppContext) {
    // Nothing in this test opts in, so collection must be off.
    let (ep_store, _channels) = init_test_with_fake_client(cx);
    let enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(!enabled);
}
#[gpui::test]
async fn test_data_collection_enabled_via_legacy_kv_store(cx: &mut TestAppContext) {
    // A legacy "true" value in the KV store opts the user in.
    let (ep_store, _channels) =
        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
    let enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(enabled);
}
#[gpui::test]
async fn test_data_collection_default_uses_cached_legacy_value(cx: &mut TestAppContext) {
    let (ep_store, _channels) =
        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
    let enabled_before = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(enabled_before);

    // Removing the KV entry afterwards must not flip the answer; the legacy value is
    // expected to have been cached already.
    let kvp = cx.update(|cx| KeyValueStore::global(cx));
    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
        .await
        .unwrap();

    let enabled_after = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(enabled_after);
}
#[gpui::test]
async fn test_data_collection_setting_overrides_kv_store(cx: &mut TestAppContext) {
    let (ep_store, _channels) =
        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));

    // An explicit "No" in settings.json wins over the legacy "true" in the KV store.
    cx.update_global::<SettingsStore, _>(|store, cx| {
        store.update_user_settings(cx, |settings| {
            let edit_predictions = settings
                .project
                .all_languages
                .edit_predictions
                .get_or_insert_default();
            edit_predictions.allow_data_collection = Some(EditPredictionDataCollectionChoice::No);
        });
    });

    let enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(!enabled);
}
#[gpui::test]
async fn test_data_collection_enabled_via_setting(cx: &mut TestAppContext) {
    let (ep_store, _channels) = init_test_with_fake_client(cx);

    // Opting in through settings.json alone is enough.
    cx.update_global::<SettingsStore, _>(|store, cx| {
        store.update_user_settings(cx, |settings| {
            let edit_predictions = settings
                .project
                .all_languages
                .edit_predictions
                .get_or_insert_default();
            edit_predictions.allow_data_collection = Some(EditPredictionDataCollectionChoice::Yes);
        });
    });

    let enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(enabled);
}
#[gpui::test]
async fn test_data_collection_always_enabled_for_staff(cx: &mut TestAppContext) {
    let (ep_store, _channels) = init_test_with_fake_client(cx);
    // Marking the user as staff turns collection on without any opt-in.
    let enabled = cx.update(|cx| {
        cx.set_staff(true);
        ep_store.read(cx).is_data_collection_enabled(cx)
    });
    assert!(enabled);
}
#[gpui::test]
async fn test_data_collection_disabled_by_organization_configuration(cx: &mut TestAppContext) {
    let (ep_store, _channels) = init_test_with_fake_client(cx);

    // The user opts in through settings...
    cx.update_global::<SettingsStore, _>(|store, cx| {
        store.update_user_settings(cx, |settings| {
            let edit_predictions = settings
                .project
                .all_languages
                .edit_predictions
                .get_or_insert_default();
            edit_predictions.allow_data_collection = Some(EditPredictionDataCollectionChoice::Yes);
        });
    });

    // ...but their current organization's configuration must override that. The only
    // disabled flag below is `edit_prediction.is_feedback_enabled`.
    let organization = Arc::new(Organization {
        id: OrganizationId("org-1".into()),
        name: "Org 1".into(),
        is_personal: false,
    });
    let configuration = OrganizationConfiguration {
        is_zed_model_provider_enabled: true,
        is_agent_thread_feedback_enabled: true,
        is_collaboration_enabled: true,
        edit_prediction: OrganizationEditPredictionConfiguration {
            is_enabled: true,
            is_feedback_enabled: false,
        },
    };

    let user_store = cx.update(|cx| ep_store.read(cx).user_store.clone());
    let enabled = cx.update(|cx| {
        user_store.update(cx, |user_store, cx| {
            user_store.set_current_organization_configuration_for_test(
                organization,
                configuration,
                cx,
            );
        });
        ep_store.read(cx).is_data_collection_enabled(cx)
    });
    assert!(!enabled);
}
// A user may have data collection enabled only through the legacy KV store, with no
// explicit setting in settings.json. In that case toggle_data_collection must read the
// *resolved* state (enabled) and write
// `Some(EditPredictionDataCollectionChoice::No)`.
#[gpui::test]
async fn test_toggle_data_collection_from_kv_enabled_state(cx: &mut TestAppContext) {
    let (ep_store, _channels) =
        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));

    let was_enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(
        was_enabled,
        "data collection should be enabled via KV store before toggle"
    );

    // Mirror toggle_data_collection: persist the inverse of the resolved state.
    cx.update_global::<SettingsStore, _>(|store, cx| {
        store.update_user_settings(cx, |settings| {
            let new_choice = if was_enabled {
                EditPredictionDataCollectionChoice::No
            } else {
                EditPredictionDataCollectionChoice::Yes
            };
            settings
                .project
                .all_languages
                .edit_predictions
                .get_or_insert_default()
                .allow_data_collection = Some(new_choice);
        });
    });

    let now_enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
    assert!(
        !now_enabled,
        "data collection should be disabled after toggling off from KV-enabled state"
    );
}
#[gpui::test]
async fn test_upsell_shown_by_default(cx: &mut TestAppContext) {
    init_test(cx);
    let kvp = cx.update(|cx| KeyValueStore::global(cx));
    // Clear both keys that can suppress the upsell. Errors are ignored (`.ok()`) since
    // either key may simply be absent.
    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
        .await
        .ok();
    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.ok();

    let shown = cx.update(|cx| should_show_upsell_modal(cx));
    assert!(shown);
}
#[gpui::test]
async fn test_upsell_dismissed_when_data_collection_choice_in_kv_store(cx: &mut TestAppContext) {
    init_test(cx);
    let kvp = cx.update(|cx| KeyValueStore::global(cx));

    // Any stored data-collection choice means the old upsell was already shown,
    // whether data collection was accepted or declined.
    for value in ["true", "false"] {
        kvp.write_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), value.to_string())
            .await
            .unwrap();
        let shown = cx.update(|cx| should_show_upsell_modal(cx));
        assert!(
            !shown,
            "upsell should be suppressed when data collection choice is '{value}'"
        );
    }

    // Remove the key again.
    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
        .await
        .unwrap();
}
#[gpui::test]
async fn test_upsell_dismissed_when_dismissed_key_set(cx: &mut TestAppContext) {
    init_test(cx);
    let kvp = cx.update(|cx| KeyValueStore::global(cx));
    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
        .await
        .ok();

    // The upsell's own dismissal key alone suppresses it.
    kvp.write_kvp(ZedPredictUpsell::KEY.into(), "1".into())
        .await
        .unwrap();
    let shown = cx.update(|cx| should_show_upsell_modal(cx));
    assert!(!shown);

    // Remove the key again.
    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.unwrap();
}
#[gpui::test]
async fn test_upsell_dismissed_via_dismissable_api(cx: &mut TestAppContext) {
    init_test(cx);
    let kvp = cx.update(|cx| KeyValueStore::global(cx));
    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
        .await
        .ok();
    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.ok();

    cx.update(|cx| {
        let shown = should_show_upsell_modal(cx);
        assert!(shown);
        // Dismiss through the `Dismissable` API instead of writing the key by hand.
        ZedPredictUpsell::set_dismissed(true, cx);
    });
    // Let any pending work from the dismissal finish before checking again.
    cx.run_until_parked();

    let shown = cx.update(|cx| should_show_upsell_modal(cx));
    assert!(!shown);

    // Remove the key again.
    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.unwrap();
}
/// Initializes `zlog` test logging once, via `ctor`, when the test binary is loaded.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}