zeta2: Allow provider to suggest edits in different files (#39110)

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2025-09-29 17:48:58 +02:00 committed by GitHub
parent b7f9fd7d74
commit cda48a3a1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 689 additions and 318 deletions

1
Cargo.lock generated
View file

@ -5126,7 +5126,6 @@ dependencies = [
"client",
"gpui",
"language",
"project",
"workspace-hack",
]

View file

@ -43,15 +43,24 @@ pub struct PredictEditsRequest {
pub prompt_format: PromptFormat,
}
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
#[default]
MarkedExcerpt,
LabeledSections,
/// Prompt format intended for use via zeta_cli
OnlySnippets,
}
impl PromptFormat {
pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
}
impl Default for PromptFormat {
fn default() -> Self {
Self::DEFAULT
}
}
impl PromptFormat {
pub fn iter() -> impl Iterator<Item = Self> {
<Self as strum::IntoEnumIterator>::iter()

View file

@ -3,7 +3,6 @@ use anyhow::Result;
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
use project::Project;
use settings::Settings;
use std::{path::Path, time::Duration};
@ -84,7 +83,6 @@ impl EditPredictionProvider for CopilotCompletionProvider {
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
@ -249,7 +247,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
None
} else {
let position = cursor_position.bias_right(buffer);
Some(EditPrediction {
Some(EditPrediction::Local {
id: None,
edits: vec![(position..position, completion_text.into())],
edit_preview: None,

View file

@ -15,5 +15,4 @@ path = "src/edit_prediction.rs"
client.workspace = true
gpui.workspace = true
language.workspace = true
project.workspace = true
workspace-hack.workspace = true

View file

@ -3,7 +3,6 @@ use std::ops::Range;
use client::EditPredictionUsage;
use gpui::{App, Context, Entity, SharedString};
use language::Buffer;
use project::Project;
// TODO: Find a better home for `Direction`.
//
@ -16,11 +15,19 @@ pub enum Direction {
}
#[derive(Clone)]
pub struct EditPrediction {
/// The ID of the completion, if it has one.
pub id: Option<SharedString>,
pub edits: Vec<(Range<language::Anchor>, String)>,
pub edit_preview: Option<language::EditPreview>,
pub enum EditPrediction {
/// Edits within the buffer that requested the prediction
Local {
id: Option<SharedString>,
edits: Vec<(Range<language::Anchor>, String)>,
edit_preview: Option<language::EditPreview>,
},
/// Jump to a different file from the one that requested the prediction
Jump {
id: Option<SharedString>,
snapshot: language::BufferSnapshot,
target: language::Anchor,
},
}
pub enum DataCollectionState {
@ -83,7 +90,6 @@ pub trait EditPredictionProvider: 'static + Sized {
fn is_refreshing(&self) -> bool;
fn refresh(
&mut self,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
@ -124,7 +130,6 @@ pub trait EditPredictionProviderHandle {
fn is_refreshing(&self, cx: &App) -> bool;
fn refresh(
&self,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
@ -198,14 +203,13 @@ where
fn refresh(
&self,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
cx: &mut App,
) {
self.update(cx, |this, cx| {
this.refresh(project, buffer, cursor_position, debounce, cx)
this.refresh(buffer, cursor_position, debounce, cx)
})
}

View file

@ -2,7 +2,6 @@ use edit_prediction::EditPredictionProvider;
use gpui::{Entity, prelude::*};
use indoc::indoc;
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
use project::Project;
use std::ops::Range;
use text::{Point, ToOffset};
@ -261,7 +260,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
EditPrediction::Edit { .. } => {
// This is expected for non-Zed providers
}
EditPrediction::Move { .. } => {
EditPrediction::MoveWithin { .. } | EditPrediction::MoveOutside { .. } => {
panic!(
"Non-Zed providers should not show Move predictions (jump functionality)"
);
@ -299,7 +298,7 @@ fn assert_editor_active_move_completion(
.as_ref()
.expect("editor has no active completion");
if let EditPrediction::Move { target, .. } = &completion_state.completion {
if let EditPrediction::MoveWithin { target, .. } = &completion_state.completion {
assert(editor.buffer().read(cx).snapshot(cx), *target);
} else {
panic!("expected move completion");
@ -326,7 +325,7 @@ fn propose_edits<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@ -357,7 +356,7 @@ fn propose_edits_non_zed<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@ -418,7 +417,6 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_debounce: bool,
@ -492,7 +490,6 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_debounce: bool,

View file

@ -638,17 +638,23 @@ enum EditPrediction {
display_mode: EditDisplayMode,
snapshot: BufferSnapshot,
},
Move {
/// Move to a specific location in the active editor
MoveWithin {
target: Anchor,
snapshot: BufferSnapshot,
},
/// Move to a specific location in a different editor (not the active one)
MoveOutside {
target: language::Anchor,
snapshot: BufferSnapshot,
},
}
struct EditPredictionState {
inlay_ids: Vec<InlayId>,
completion: EditPrediction,
completion_id: Option<SharedString>,
invalidation_range: Range<Anchor>,
invalidation_range: Option<Range<Anchor>>,
}
enum EditPredictionSettings {
@ -7175,13 +7181,7 @@ impl Editor {
return None;
}
provider.refresh(
self.project.clone(),
buffer,
cursor_buffer_position,
debounce,
cx,
);
provider.refresh(buffer, cursor_buffer_position, debounce, cx);
Some(())
}
@ -7424,10 +7424,8 @@ impl Editor {
return;
};
self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
match &active_edit_prediction.completion {
EditPrediction::Move { target, .. } => {
EditPrediction::MoveWithin { target, .. } => {
let target = *target;
if let Some(position_map) = &self.last_position_map {
@ -7469,7 +7467,19 @@ impl Editor {
}
}
}
EditPrediction::MoveOutside { snapshot, target } => {
if let Some(workspace) = self.workspace() {
Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
.detach_and_log_err(cx);
}
}
EditPrediction::Edit { edits, .. } => {
self.report_edit_prediction_event(
active_edit_prediction.completion_id.clone(),
true,
cx,
);
if let Some(provider) = self.edit_prediction_provider() {
provider.accept(cx);
}
@ -7522,10 +7532,8 @@ impl Editor {
return;
}
self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
match &active_edit_prediction.completion {
EditPrediction::Move { target, .. } => {
EditPrediction::MoveWithin { target, .. } => {
let target = *target;
self.change_selections(
SelectionEffects::scroll(Autoscroll::newest()),
@ -7536,7 +7544,19 @@ impl Editor {
},
);
}
EditPrediction::MoveOutside { snapshot, target } => {
if let Some(workspace) = self.workspace() {
Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
.detach_and_log_err(cx);
}
}
EditPrediction::Edit { edits, .. } => {
self.report_edit_prediction_event(
active_edit_prediction.completion_id.clone(),
true,
cx,
);
// Find an insertion that starts at the cursor position.
let snapshot = self.buffer.read(cx).snapshot(cx);
let cursor_offset = self.selections.newest::<usize>(cx).head();
@ -7631,6 +7651,36 @@ impl Editor {
);
}
fn open_editor_at_anchor(
snapshot: &language::BufferSnapshot,
target: language::Anchor,
workspace: &Entity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Task<Result<()>> {
workspace.update(cx, |workspace, cx| {
let path = snapshot.file().map(|file| file.full_path(cx));
let Some(path) =
path.and_then(|path| workspace.project().read(cx).find_project_path(path, cx))
else {
return Task::ready(Err(anyhow::anyhow!("Project path not found")));
};
let target = text::ToPoint::to_point(&target, snapshot);
let item = workspace.open_path(path, None, true, window, cx);
window.spawn(cx, async move |cx| {
let Some(editor) = item.await?.downcast::<Editor>() else {
return Ok(());
};
editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(target, window, cx);
})
.ok();
anyhow::Ok(())
})
})
}
pub fn has_active_edit_prediction(&self) -> bool {
self.active_edit_prediction.is_some()
}
@ -7846,7 +7896,10 @@ impl Editor {
.active_edit_prediction
.as_ref()
.is_some_and(|completion| {
let invalidation_range = completion.invalidation_range.to_offset(&multibuffer);
let Some(invalidation_range) = completion.invalidation_range.as_ref() else {
return false;
};
let invalidation_range = invalidation_range.to_offset(&multibuffer);
let invalidation_range = invalidation_range.start..=invalidation_range.end;
!invalidation_range.contains(&offset_selection.head())
})
@ -7882,8 +7935,31 @@ impl Editor {
}
let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
let edits = edit_prediction
.edits
let (completion_id, edits, edit_preview) = match edit_prediction {
edit_prediction::EditPrediction::Local {
id,
edits,
edit_preview,
} => (id, edits, edit_preview),
edit_prediction::EditPrediction::Jump {
id,
snapshot,
target,
} => {
self.stale_edit_prediction_in_menu = None;
self.active_edit_prediction = Some(EditPredictionState {
inlay_ids: vec![],
completion: EditPrediction::MoveOutside { snapshot, target },
completion_id: id,
invalidation_range: None,
});
cx.notify();
return Some(());
}
};
let edits = edits
.into_iter()
.flat_map(|(range, new_text)| {
let start = multibuffer.anchor_in_excerpt(excerpt_id, range.start)?;
@ -7928,7 +8004,7 @@ impl Editor {
invalidation_row_range =
move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row);
let target = first_edit_start;
EditPrediction::Move { target, snapshot }
EditPrediction::MoveWithin { target, snapshot }
} else {
let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true)
&& !self.edit_predictions_hidden_for_vim_mode;
@ -7977,7 +8053,7 @@ impl Editor {
EditPrediction::Edit {
edits,
edit_preview: edit_prediction.edit_preview,
edit_preview,
display_mode,
snapshot,
}
@ -7994,8 +8070,8 @@ impl Editor {
self.active_edit_prediction = Some(EditPredictionState {
inlay_ids,
completion,
completion_id: edit_prediction.id,
invalidation_range,
completion_id,
invalidation_range: Some(invalidation_range),
});
cx.notify();
@ -8581,7 +8657,7 @@ impl Editor {
}
match &active_edit_prediction.completion {
EditPrediction::Move { target, .. } => {
EditPrediction::MoveWithin { target, .. } => {
let target_display_point = target.to_display_point(editor_snapshot);
if self.edit_prediction_requires_modifier() {
@ -8666,6 +8742,28 @@ impl Editor {
window,
cx,
),
EditPrediction::MoveOutside { snapshot, .. } => {
let file_name = snapshot
.file()
.map(|file| file.file_name(cx))
.unwrap_or("untitled");
let mut element = self
.render_edit_prediction_line_popover(
format!("Jump to {file_name}"),
Some(IconName::ZedPredict),
window,
cx,
)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
let origin_x = text_bounds.size.width / 2. - size.width / 2.;
let origin_y = text_bounds.size.height - size.height - px(30.);
let origin = text_bounds.origin + gpui::Point::new(origin_x, origin_y);
element.prepaint_at(origin, window, cx);
Some((element, origin))
}
}
}
@ -8730,13 +8828,13 @@ impl Editor {
.items_end()
.when(flag_on_right, |el| el.items_start())
.child(if flag_on_right {
self.render_edit_prediction_line_popover("Jump", None, window, cx)?
self.render_edit_prediction_line_popover("Jump", None, window, cx)
.rounded_bl(px(0.))
.rounded_tl(px(0.))
.border_l_2()
.border_color(border_color)
} else {
self.render_edit_prediction_line_popover("Jump", None, window, cx)?
self.render_edit_prediction_line_popover("Jump", None, window, cx)
.rounded_br(px(0.))
.rounded_tr(px(0.))
.border_r_2()
@ -8776,7 +8874,7 @@ impl Editor {
cx: &mut App,
) -> Option<(AnyElement, gpui::Point<Pixels>)> {
let mut element = self
.render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)?
.render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@ -8816,7 +8914,7 @@ impl Editor {
Some(IconName::ArrowUp),
window,
cx,
)?
)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@ -8835,7 +8933,7 @@ impl Editor {
Some(IconName::ArrowDown),
window,
cx,
)?
)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@ -8882,7 +8980,7 @@ impl Editor {
);
let mut element = self
.render_edit_prediction_line_popover(label, None, window, cx)?
.render_edit_prediction_line_popover(label, None, window, cx)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@ -8909,7 +9007,7 @@ impl Editor {
};
element = self
.render_edit_prediction_line_popover(label, Some(icon), window, cx)?
.render_edit_prediction_line_popover(label, Some(icon), window, cx)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@ -9163,13 +9261,13 @@ impl Editor {
icon: Option<IconName>,
window: &mut Window,
cx: &App,
) -> Option<Stateful<Div>> {
) -> Stateful<Div> {
let padding_right = if icon.is_some() { px(4.) } else { px(8.) };
let keybind = self.render_edit_prediction_accept_keybind(window, cx);
let has_keybind = keybind.is_some();
let result = h_flex()
h_flex()
.id("ep-line-popover")
.py_0p5()
.pl_1()
@ -9215,9 +9313,7 @@ impl Editor {
.mt(px(1.5))
.child(Icon::new(icon).size(IconSize::Small)),
)
});
Some(result)
})
}
fn edit_prediction_line_popover_bg_color(cx: &App) -> Hsla {
@ -9281,7 +9377,7 @@ impl Editor {
.rounded_tl(px(0.))
.overflow_hidden()
.child(div().px_1p5().child(match &prediction.completion {
EditPrediction::Move { target, snapshot } => {
EditPrediction::MoveWithin { target, snapshot } => {
use text::ToPoint as _;
if target.text_anchor.to_point(snapshot).row > cursor_point.row
{
@ -9290,6 +9386,10 @@ impl Editor {
Icon::new(IconName::ZedPredictUp)
}
}
EditPrediction::MoveOutside { .. } => {
// TODO [zeta2] custom icon for external jump?
Icon::new(provider_icon)
}
EditPrediction::Edit { .. } => Icon::new(provider_icon),
}))
.child(
@ -9472,7 +9572,7 @@ impl Editor {
.unwrap_or(true);
match &completion.completion {
EditPrediction::Move {
EditPrediction::MoveWithin {
target, snapshot, ..
} => {
if !supports_jump {
@ -9494,7 +9594,20 @@ impl Editor {
.child(Label::new("Jump to Edit")),
)
}
EditPrediction::MoveOutside { snapshot, .. } => {
let file_name = snapshot
.file()
.map(|file| file.file_name(cx))
.unwrap_or("untitled");
Some(
h_flex()
.px_2()
.gap_2()
.flex_1()
.child(Icon::new(IconName::ZedPredict))
.child(Label::new(format!("Jump to {file_name}"))),
)
}
EditPrediction::Edit {
edits,
edit_preview,
@ -21418,7 +21531,7 @@ impl Editor {
{
self.hide_context_menu(window, cx);
}
self.discard_edit_prediction(false, cx);
self.take_active_edit_prediction(cx);
cx.emit(EditorEvent::Blurred);
cx.notify();
}

View file

@ -8272,7 +8272,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
id: None,
edits: vec![(edit_position..edit_position, "X".into())],
edit_preview: None,

View file

@ -22,7 +22,6 @@ gpui.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true
project.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true

View file

@ -4,7 +4,6 @@ use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
use futures::StreamExt as _;
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Anchor, Buffer, BufferSnapshot};
use project::Project;
use std::{
ops::{AddAssign, Range},
path::Path,
@ -94,7 +93,7 @@ fn completion_from_diff(
edits.push((edit_range, edit_text));
}
EditPrediction {
EditPrediction::Local {
id: None,
edits,
edit_preview: None,
@ -132,7 +131,6 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
buffer_handle: Entity<Buffer>,
cursor_position: Anchor,
debounce: bool,

View file

@ -205,42 +205,48 @@ fn assign_edit_prediction_provider(
}
}
if std::env::var("ZED_ZETA2").is_ok() {
let zeta = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
zeta2::ZetaEditPredictionProvider::new(
editor.project(),
&client,
&user_store,
cx,
)
});
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
&& let Some(project) = editor.project()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
if let Some(project) = editor.project() {
if std::env::var("ZED_ZETA2").is_ok() {
let zeta = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
zeta2::ZetaEditPredictionProvider::new(
project.clone(),
&client,
&user_store,
cx,
)
});
}
editor.set_edit_prediction_provider(Some(provider), window, cx);
} else {
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
// TODO [zeta2] handle multibuffers
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
});
}
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
&& let Some(project) = editor.project()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
editor.set_edit_prediction_provider(Some(provider), window, cx);
} else {
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(buffer, project, cx);
});
}
let provider = cx.new(|_| {
zeta::ZetaEditPredictionProvider::new(
zeta,
project.clone(),
singleton_buffer,
)
});
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
let provider =
cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
}

View file

@ -1316,12 +1316,17 @@ pub struct ZetaEditPredictionProvider {
next_pending_completion_id: usize,
current_completion: Option<CurrentEditPrediction>,
last_request_timestamp: Instant,
project: Entity<Project>,
}
impl ZetaEditPredictionProvider {
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
pub fn new(zeta: Entity<Zeta>, singleton_buffer: Option<Entity<Buffer>>) -> Self {
pub fn new(
zeta: Entity<Zeta>,
project: Entity<Project>,
singleton_buffer: Option<Entity<Buffer>>,
) -> Self {
Self {
zeta,
singleton_buffer,
@ -1329,6 +1334,7 @@ impl ZetaEditPredictionProvider {
next_pending_completion_id: 0,
current_completion: None,
last_request_timestamp: Instant::now(),
project,
}
}
}
@ -1394,7 +1400,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
fn refresh(
&mut self,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
position: language::Anchor,
_debounce: bool,
@ -1403,9 +1408,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
if self.zeta.read(cx).update_required {
return;
}
let Some(project) = project else {
return;
};
if self
.zeta
@ -1433,6 +1435,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
self.next_pending_completion_id += 1;
let last_request_timestamp = self.last_request_timestamp;
let project = self.project.clone();
let task = cx.spawn(async move |this, cx| {
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
.checked_duration_since(Instant::now())
@ -1604,7 +1607,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
}
}
Some(edit_prediction::EditPrediction {
Some(edit_prediction::EditPrediction::Local {
id: Some(completion.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(completion.edit_preview.clone()),

View file

@ -1,35 +1,18 @@
use std::{borrow::Cow, ops::Range, sync::Arc};
use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3;
use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
use gpui::{App, AsyncApp, Entity};
use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
};
use project::Project;
use util::ResultExt;
use uuid::Uuid;
#[derive(Clone)]
pub struct EditPrediction {
pub id: EditPredictionId,
pub edits: Arc<[(Range<Anchor>, String)]>,
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
}
impl EditPrediction {
pub fn interpolate(
&self,
new_snapshot: &BufferSnapshot,
) -> Option<Vec<(Range<Anchor>, String)>> {
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
}
}
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
impl From<Uuid> for EditPredictionId {
fn from(value: Uuid) -> Self {
EditPredictionId(value)
}
}
impl From<EditPredictionId> for gpui::ElementId {
fn from(value: EditPredictionId) -> Self {
gpui::ElementId::Uuid(value.0)
@ -42,9 +25,122 @@ impl std::fmt::Display for EditPredictionId {
}
}
#[derive(Clone)]
pub struct EditPrediction {
pub id: EditPredictionId,
pub path: Arc<Path>,
pub edits: Arc<[(Range<Anchor>, String)]>,
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
_buffer: Entity<Buffer>,
}
impl EditPrediction {
pub async fn from_response(
response: predict_edits_v3::PredictEditsResponse,
active_buffer_old_snapshot: &TextBufferSnapshot,
active_buffer: &Entity<Buffer>,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Option<Self> {
// TODO only allow cloud to return one path
let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
return None;
};
let is_same_path = active_buffer
.read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
.ok()?;
let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
active_buffer
.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
let edits: Arc<[_]> =
interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
Some((
active_buffer.clone(),
edits.clone(),
new_snapshot,
buffer.preview_edits(edits, cx),
))
})
.ok()??
} else {
let buffer_handle = project
.update(cx, |project, cx| {
let project_path = project
.find_project_path(&path, cx)
.context("Failed to find project path for zeta edit")?;
anyhow::Ok(project.open_buffer(project_path, cx))
})
.ok()?
.log_err()?
.await
.context("Failed to open buffer for zeta edit")
.log_err()?;
buffer_handle
.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot();
let edits = edits_from_response(&response.edits, &snapshot);
if edits.is_empty() {
return None;
}
Some((
buffer_handle.clone(),
edits.clone(),
snapshot,
buffer.preview_edits(edits, cx),
))
})
.ok()??
};
let edit_preview = edit_preview_task.await;
Some(EditPrediction {
id: EditPredictionId(response.request_id),
path,
edits,
snapshot,
edit_preview,
_buffer: buffer,
})
}
pub fn interpolate(
&self,
new_snapshot: &TextBufferSnapshot,
) -> Option<Vec<(Range<Anchor>, String)>> {
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
}
pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
buffer_path_eq(buffer, &self.path, cx)
}
}
impl std::fmt::Debug for EditPrediction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EditPrediction")
.field("id", &self.id)
.field("path", &self.path)
.field("edits", &self.edits)
.finish()
}
}
pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
}
pub fn interpolate_edits(
old_snapshot: &BufferSnapshot,
new_snapshot: &BufferSnapshot,
old_snapshot: &TextBufferSnapshot,
new_snapshot: &TextBufferSnapshot,
current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
let mut edits = Vec::new();
@ -88,14 +184,13 @@ pub fn interpolate_edits(
if edits.is_empty() { None } else { Some(edits) }
}
pub fn edits_from_response(
fn edits_from_response(
edits: &[predict_edits_v3::Edit],
snapshot: &BufferSnapshot,
snapshot: &TextBufferSnapshot,
) -> Arc<[(Range<Anchor>, String)]> {
edits
.iter()
.flat_map(|edit| {
// TODO multi-file edits
let old_text = snapshot.text_for_range(edit.range.clone());
excerpt_edits_from_response(
@ -113,7 +208,7 @@ fn excerpt_edits_from_response(
old_text: Cow<str>,
new_text: &str,
offset: usize,
snapshot: &BufferSnapshot,
snapshot: &TextBufferSnapshot,
) -> impl Iterator<Item = (Range<Anchor>, String)> {
text_diff(&old_text, new_text)
.into_iter()
@ -221,6 +316,8 @@ mod tests {
id: EditPredictionId(Uuid::new_v4()),
edits,
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
path: Path::new("test.txt").into(),
_buffer: buffer.clone(),
edit_preview,
};

View file

@ -4,76 +4,44 @@ use std::{
time::{Duration, Instant},
};
use anyhow::Context as _;
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
use gpui::{App, Entity, EntityId, Task, prelude::*};
use language::{BufferSnapshot, ToPoint as _};
use gpui::{App, Entity, Task, prelude::*};
use language::ToPoint as _;
use project::Project;
use util::ResultExt as _;
use crate::{Zeta, prediction::EditPrediction};
use crate::{BufferEditPrediction, Zeta};
pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
last_request_timestamp: Instant,
project: Entity<Project>,
}
impl ZetaEditPredictionProvider {
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
pub fn new(
project: Option<&Entity<Project>>,
project: Entity<Project>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
cx: &mut App,
) -> Self {
let zeta = Zeta::global(client, user_store, cx);
if let Some(project) = project {
zeta.update(cx, |zeta, cx| {
zeta.register_project(project, cx);
});
}
zeta.update(cx, |zeta, cx| {
zeta.register_project(&project, cx);
});
Self {
zeta,
current_prediction: None,
next_pending_prediction_id: 0,
pending_predictions: ArrayVec::new(),
last_request_timestamp: Instant::now(),
}
}
}
#[derive(Clone)]
struct CurrentEditPrediction {
buffer_id: EntityId,
prediction: EditPrediction,
}
impl CurrentEditPrediction {
fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
if self.buffer_id != old_prediction.buffer_id {
return true;
}
let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
return true;
};
let Some(new_edits) = self.prediction.interpolate(snapshot) else {
return false;
};
if old_edits.len() == 1 && new_edits.len() == 1 {
let (old_range, old_text) = &old_edits[0];
let (new_range, new_text) = &new_edits[0];
new_range == old_range && new_text.starts_with(old_text)
} else {
true
project: project,
}
}
}
@ -128,42 +96,31 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
fn refresh(
&mut self,
project: Option<Entity<project::Project>>,
buffer: Entity<language::Buffer>,
cursor_position: language::Anchor,
_debounce: bool,
cx: &mut Context<Self>,
) {
let Some(project) = project else {
return;
};
let zeta = self.zeta.read(cx);
if self
.zeta
.read(cx)
.user_store
.read_with(cx, |user_store, _cx| {
user_store.account_too_young() || user_store.has_overdue_invoices()
})
{
if zeta.user_store.read_with(cx, |user_store, _cx| {
user_store.account_too_young() || user_store.has_overdue_invoices()
}) {
return;
}
if let Some(current_prediction) = self.current_prediction.as_ref() {
let snapshot = buffer.read(cx).snapshot();
if current_prediction
.prediction
.interpolate(&snapshot)
.is_some()
{
return;
}
if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
&& let BufferEditPrediction::Local { prediction } = current
&& prediction.interpolate(buffer.read(cx)).is_some()
{
return;
}
let pending_prediction_id = self.next_pending_prediction_id;
self.next_pending_prediction_id += 1;
let last_request_timestamp = self.last_request_timestamp;
let project = self.project.clone();
let task = cx.spawn(async move |this, cx| {
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
.checked_duration_since(Instant::now())
@ -171,25 +128,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
cx.background_executor().timer(timeout).await;
}
let prediction_request = this.update(cx, |this, cx| {
let refresh_task = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
zeta.request_prediction(&project, &buffer, cursor_position, cx)
zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
})
});
let prediction = match prediction_request {
Ok(prediction_request) => {
let prediction_request = prediction_request.await;
prediction_request.map(|c| {
c.map(|prediction| CurrentEditPrediction {
buffer_id: buffer.entity_id(),
prediction,
})
})
}
Err(error) => Err(error),
};
if let Some(refresh_task) = refresh_task.ok() {
refresh_task.await.log_err();
}
this.update(cx, |this, cx| {
if this.pending_predictions[0].id == pending_prediction_id {
@ -198,24 +146,6 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
this.pending_predictions.clear();
}
let Some(new_prediction) = prediction
.context("edit prediction failed")
.log_err()
.flatten()
else {
cx.notify();
return;
};
if let Some(old_prediction) = this.current_prediction.as_ref() {
let snapshot = buffer.read(cx).snapshot();
if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
this.current_prediction = Some(new_prediction);
}
} else {
this.current_prediction = Some(new_prediction);
}
cx.notify();
})
.ok();
@ -248,15 +178,18 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
) {
}
fn accept(&mut self, _cx: &mut Context<Self>) {
// TODO [zeta2] report accept
self.current_prediction.take();
fn accept(&mut self, cx: &mut Context<Self>) {
self.zeta.update(cx, |zeta, _cx| {
zeta.accept_current_prediction(&self.project);
});
self.pending_predictions.clear();
}
fn discard(&mut self, _cx: &mut Context<Self>) {
fn discard(&mut self, cx: &mut Context<Self>) {
self.zeta.update(cx, |zeta, _cx| {
zeta.discard_current_prediction(&self.project);
});
self.pending_predictions.clear();
self.current_prediction.take();
}
fn suggest(
@ -265,36 +198,44 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) -> Option<edit_prediction::EditPrediction> {
let CurrentEditPrediction {
buffer_id,
prediction,
..
} = self.current_prediction.as_mut()?;
let prediction =
self.zeta
.read(cx)
.current_prediction_for_buffer(buffer, &self.project, cx)?;
// Invalidate previous prediction if it was generated for a different buffer.
if *buffer_id != buffer.entity_id() {
self.current_prediction.take();
return None;
}
let prediction = match prediction {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
return Some(edit_prediction::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
});
}
};
let buffer = buffer.read(cx);
let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
self.current_prediction.take();
let snapshot = buffer.snapshot();
let Some(edits) = prediction.interpolate(&snapshot) else {
self.zeta.update(cx, |zeta, _cx| {
zeta.discard_current_prediction(&self.project);
});
return None;
};
let cursor_row = cursor_position.to_point(buffer).row;
let cursor_row = cursor_position.to_point(&snapshot).row;
let (closest_edit_ix, (closest_edit_range, _)) =
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
cmp::min(distance_from_start, distance_from_end)
})?;
let mut edit_start_ix = closest_edit_ix;
for (range, _) in edits[..edit_start_ix].iter().rev() {
let distance_from_closest_edit =
closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
- range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_start_ix -= 1;
} else {
@ -305,7 +246,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
let mut edit_end_ix = closest_edit_ix + 1;
for (range, _) in &edits[edit_end_ix..] {
let distance_from_closest_edit =
range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_end_ix += 1;
} else {
@ -313,7 +254,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
}
Some(edit_prediction::EditPrediction {
Some(edit_prediction::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),

View file

@ -17,8 +17,8 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
use language::BufferSnapshot;
use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, TextBufferSnapshot};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
use release_channel::AppVersion;
@ -35,7 +35,7 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti
mod prediction;
mod provider;
use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
use crate::prediction::EditPrediction;
pub use provider::ZetaEditPredictionProvider;
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
@ -53,7 +53,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
excerpt: DEFAULT_EXCERPT_OPTIONS,
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
max_diagnostic_bytes: 2048,
prompt_format: PromptFormat::MarkedExcerpt,
prompt_format: PromptFormat::DEFAULT,
};
#[derive(Clone)]
@ -94,6 +94,47 @@ struct ZetaProject {
syntax_index: Entity<SyntaxIndex>,
events: VecDeque<Event>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
}
#[derive(Clone)]
struct CurrentEditPrediction {
pub requested_by_buffer_id: EntityId,
pub prediction: EditPrediction,
}
impl CurrentEditPrediction {
fn should_replace_prediction(
&self,
old_prediction: &Self,
snapshot: &TextBufferSnapshot,
) -> bool {
if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
return true;
}
let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
return true;
};
let Some(new_edits) = self.prediction.interpolate(snapshot) else {
return false;
};
if old_edits.len() == 1 && new_edits.len() == 1 {
let (old_range, old_text) = &old_edits[0];
let (new_range, new_text) = &new_edits[0];
new_range == old_range && new_text.starts_with(old_text)
} else {
true
}
}
}
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
Local { prediction: &'a EditPrediction },
Jump { prediction: &'a EditPrediction },
}
struct RegisteredBuffer {
@ -204,6 +245,7 @@ impl Zeta {
syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
events: VecDeque::new(),
registered_buffers: HashMap::new(),
current_prediction: None,
})
}
@ -305,7 +347,83 @@ impl Zeta {
events.push_back(event);
}
pub fn request_prediction(
fn current_prediction_for_buffer(
&self,
buffer: &Entity<Buffer>,
project: &Entity<Project>,
cx: &App,
) -> Option<BufferEditPrediction<'_>> {
let project_state = self.projects.get(&project.entity_id())?;
let CurrentEditPrediction {
requested_by_buffer_id,
prediction,
} = project_state.current_prediction.as_ref()?;
if prediction.targets_buffer(buffer.read(cx), cx) {
Some(BufferEditPrediction::Local { prediction })
} else if *requested_by_buffer_id == buffer.entity_id() {
Some(BufferEditPrediction::Jump { prediction })
} else {
None
}
}
fn accept_current_prediction(&mut self, project: &Entity<Project>) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.current_prediction.take();
};
// TODO report accepted
}
fn discard_current_prediction(&mut self, project: &Entity<Project>) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.current_prediction.take();
};
}
pub fn refresh_prediction(
&mut self,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let request_task = self.request_prediction(project, buffer, position, cx);
let buffer = buffer.clone();
let project = project.clone();
cx.spawn(async move |this, cx| {
if let Some(prediction) = request_task.await? {
this.update(cx, |this, cx| {
let project_state = this
.projects
.get_mut(&project.entity_id())
.context("Project not found")?;
let new_prediction = CurrentEditPrediction {
requested_by_buffer_id: buffer.entity_id(),
prediction: prediction,
};
if project_state
.current_prediction
.as_ref()
.is_none_or(|old_prediction| {
new_prediction
.should_replace_prediction(&old_prediction, buffer.read(cx))
})
{
project_state.current_prediction = Some(new_prediction);
}
anyhow::Ok(())
})??;
}
Ok(())
})
}
fn request_prediction(
&mut self,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
@ -457,74 +575,63 @@ impl Zeta {
.ok();
}
let (response, usage) = response?;
let edits = edits_from_response(&response.edits, &snapshot);
anyhow::Ok(Some((response.request_id, edits, usage)))
anyhow::Ok(Some(response?))
}
});
let buffer = buffer.clone();
cx.spawn(async move |this, cx| {
match request_task.await {
Ok(Some((id, edits, usage))) => {
if let Some(usage) = usage {
this.update(cx, |this, cx| {
this.user_store.update(cx, |user_store, cx| {
user_store.update_edit_prediction_usage(usage, cx);
});
})
.ok();
}
// TODO telemetry: duration, etc
let Some((edits, snapshot, edit_preview_task)) =
buffer.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[_]> =
interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
})?
else {
return Ok(None);
};
Ok(Some(EditPrediction {
id: id.into(),
edits,
snapshot,
edit_preview: edit_preview_task.await,
}))
}
Ok(None) => Ok(None),
Err(err) => {
if err.is::<ZedUpdateRequiredError>() {
cx.update(|cx| {
this.update(cx, |this, _cx| {
this.update_required = true;
cx.spawn({
let project = project.clone();
async move |this, cx| {
match request_task.await {
Ok(Some((response, usage))) => {
if let Some(usage) = usage {
this.update(cx, |this, cx| {
this.user_store.update(cx, |user_store, cx| {
user_store.update_edit_prediction_usage(usage, cx);
});
})
.ok();
}
let error_message: SharedString = err.to_string().into();
show_app_notification(
NotificationId::unique::<ZedUpdateRequiredError>(),
cx,
move |cx| {
cx.new(|cx| {
ErrorMessagePrompt::new(error_message.clone(), cx)
.with_link_button(
"Update Zed",
"https://zed.dev/releases",
)
})
},
);
})
.ok();
let prediction = EditPrediction::from_response(
response, &snapshot, &buffer, &project, cx,
)
.await;
// TODO telemetry: duration, etc
Ok(prediction)
}
Ok(None) => Ok(None),
Err(err) => {
if err.is::<ZedUpdateRequiredError>() {
cx.update(|cx| {
this.update(cx, |this, _cx| {
this.update_required = true;
})
.ok();
Err(err)
let error_message: SharedString = err.to_string().into();
show_app_notification(
NotificationId::unique::<ZedUpdateRequiredError>(),
cx,
move |cx| {
cx.new(|cx| {
ErrorMessagePrompt::new(error_message.clone(), cx)
.with_link_button(
"Update Zed",
"https://zed.dev/releases",
)
})
},
);
})
.ok();
}
Err(err)
}
}
}
})
@ -859,13 +966,113 @@ mod tests {
};
use indoc::indoc;
use language::{LanguageServerId, OffsetRangeExt as _};
use pretty_assertions::{assert_eq, assert_matches};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
use uuid::Uuid;
use crate::Zeta;
use crate::{BufferEditPrediction, Zeta};
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
let (zeta, mut req_rx) = init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"1.txt": "Hello!\nHow\nBye",
"2.txt": "Hola!\nComo\nAdios"
}),
)
.await;
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
zeta.update(cx, |zeta, cx| {
zeta.register_project(&project, cx);
});
let buffer1 = project
.update(cx, |project, cx| {
let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
project.open_buffer(path, cx)
})
.await
.unwrap();
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot1.anchor_before(language::Point::new(1, 3));
// Prediction for current file
let prediction_task = zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer1, position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
respond_tx
.send(predict_edits_v3::PredictEditsResponse {
request_id: Uuid::new_v4(),
edits: vec![predict_edits_v3::Edit {
path: Path::new(path!("root/1.txt")).into(),
range: 0..snapshot1.len(),
content: "Hello!\nHow are you?\nBye".into(),
}],
debug_info: None,
})
.unwrap();
prediction_task.await.unwrap();
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
.current_prediction_for_buffer(&buffer1, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
// Prediction for another file
let prediction_task = zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer1, position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
respond_tx
.send(predict_edits_v3::PredictEditsResponse {
request_id: Uuid::new_v4(),
edits: vec![predict_edits_v3::Edit {
path: Path::new(path!("root/2.txt")).into(),
range: 0..snapshot1.len(),
content: "Hola!\nComo estas?\nAdios".into(),
}],
debug_info: None,
})
.unwrap();
prediction_task.await.unwrap();
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
.current_prediction_for_buffer(&buffer1, &project, cx)
.unwrap();
assert_matches!(
prediction,
BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
);
});
let buffer2 = project
.update(cx, |project, cx| {
let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
project.open_buffer(path, cx)
})
.await
.unwrap();
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
.current_prediction_for_buffer(&buffer2, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
}
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
@ -1146,6 +1353,7 @@ mod tests {
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let zeta = Zeta::global(&client, &user_store, cx);
(zeta, req_rx)
})
}

View file

@ -185,7 +185,7 @@ impl Zeta2Inspector {
cx.background_executor().timer(THROTTLE_TIME).await;
if let Some(task) = zeta
.update(cx, |zeta, cx| {
zeta.request_prediction(&project, &buffer, position, cx)
zeta.refresh_prediction(&project, &buffer, position, cx)
})
.ok()
{