mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
zeta2: Allow provider to suggest edits in different files (#39110)
Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
parent
b7f9fd7d74
commit
cda48a3a1c
16 changed files with 689 additions and 318 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -5126,7 +5126,6 @@ dependencies = [
|
|||
"client",
|
||||
"gpui",
|
||||
"language",
|
||||
"project",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -43,15 +43,24 @@ pub struct PredictEditsRequest {
|
|||
pub prompt_format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum PromptFormat {
|
||||
#[default]
|
||||
MarkedExcerpt,
|
||||
LabeledSections,
|
||||
/// Prompt format intended for use via zeta_cli
|
||||
OnlySnippets,
|
||||
}
|
||||
|
||||
impl PromptFormat {
|
||||
pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
|
||||
}
|
||||
|
||||
impl Default for PromptFormat {
|
||||
fn default() -> Self {
|
||||
Self::DEFAULT
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptFormat {
|
||||
pub fn iter() -> impl Iterator<Item = Self> {
|
||||
<Self as strum::IntoEnumIterator>::iter()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ use anyhow::Result;
|
|||
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
||||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
|
||||
use project::Project;
|
||||
use settings::Settings;
|
||||
use std::{path::Path, time::Duration};
|
||||
|
||||
|
|
@ -84,7 +83,6 @@ impl EditPredictionProvider for CopilotCompletionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
|
|
@ -249,7 +247,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
|
|||
None
|
||||
} else {
|
||||
let position = cursor_position.bias_right(buffer);
|
||||
Some(EditPrediction {
|
||||
Some(EditPrediction::Local {
|
||||
id: None,
|
||||
edits: vec![(position..position, completion_text.into())],
|
||||
edit_preview: None,
|
||||
|
|
|
|||
|
|
@ -15,5 +15,4 @@ path = "src/edit_prediction.rs"
|
|||
client.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
project.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ use std::ops::Range;
|
|||
use client::EditPredictionUsage;
|
||||
use gpui::{App, Context, Entity, SharedString};
|
||||
use language::Buffer;
|
||||
use project::Project;
|
||||
|
||||
// TODO: Find a better home for `Direction`.
|
||||
//
|
||||
|
|
@ -16,11 +15,19 @@ pub enum Direction {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditPrediction {
|
||||
/// The ID of the completion, if it has one.
|
||||
pub id: Option<SharedString>,
|
||||
pub edits: Vec<(Range<language::Anchor>, String)>,
|
||||
pub edit_preview: Option<language::EditPreview>,
|
||||
pub enum EditPrediction {
|
||||
/// Edits within the buffer that requested the prediction
|
||||
Local {
|
||||
id: Option<SharedString>,
|
||||
edits: Vec<(Range<language::Anchor>, String)>,
|
||||
edit_preview: Option<language::EditPreview>,
|
||||
},
|
||||
/// Jump to a different file from the one that requested the prediction
|
||||
Jump {
|
||||
id: Option<SharedString>,
|
||||
snapshot: language::BufferSnapshot,
|
||||
target: language::Anchor,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum DataCollectionState {
|
||||
|
|
@ -83,7 +90,6 @@ pub trait EditPredictionProvider: 'static + Sized {
|
|||
fn is_refreshing(&self) -> bool;
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
|
|
@ -124,7 +130,6 @@ pub trait EditPredictionProviderHandle {
|
|||
fn is_refreshing(&self, cx: &App) -> bool;
|
||||
fn refresh(
|
||||
&self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
|
|
@ -198,14 +203,13 @@ where
|
|||
|
||||
fn refresh(
|
||||
&self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.update(cx, |this, cx| {
|
||||
this.refresh(project, buffer, cursor_position, debounce, cx)
|
||||
this.refresh(buffer, cursor_position, debounce, cx)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ use edit_prediction::EditPredictionProvider;
|
|||
use gpui::{Entity, prelude::*};
|
||||
use indoc::indoc;
|
||||
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
|
||||
use project::Project;
|
||||
use std::ops::Range;
|
||||
use text::{Point, ToOffset};
|
||||
|
||||
|
|
@ -261,7 +260,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
|
|||
EditPrediction::Edit { .. } => {
|
||||
// This is expected for non-Zed providers
|
||||
}
|
||||
EditPrediction::Move { .. } => {
|
||||
EditPrediction::MoveWithin { .. } | EditPrediction::MoveOutside { .. } => {
|
||||
panic!(
|
||||
"Non-Zed providers should not show Move predictions (jump functionality)"
|
||||
);
|
||||
|
|
@ -299,7 +298,7 @@ fn assert_editor_active_move_completion(
|
|||
.as_ref()
|
||||
.expect("editor has no active completion");
|
||||
|
||||
if let EditPrediction::Move { target, .. } = &completion_state.completion {
|
||||
if let EditPrediction::MoveWithin { target, .. } = &completion_state.completion {
|
||||
assert(editor.buffer().read(cx).snapshot(cx), *target);
|
||||
} else {
|
||||
panic!("expected move completion");
|
||||
|
|
@ -326,7 +325,7 @@ fn propose_edits<T: ToOffset>(
|
|||
|
||||
cx.update(|_, cx| {
|
||||
provider.update(cx, |provider, _| {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
|
||||
id: None,
|
||||
edits: edits.collect(),
|
||||
edit_preview: None,
|
||||
|
|
@ -357,7 +356,7 @@ fn propose_edits_non_zed<T: ToOffset>(
|
|||
|
||||
cx.update(|_, cx| {
|
||||
provider.update(cx, |provider, _| {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
|
||||
id: None,
|
||||
edits: edits.collect(),
|
||||
edit_preview: None,
|
||||
|
|
@ -418,7 +417,6 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
_buffer: gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
|
|
@ -492,7 +490,6 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
_buffer: gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
|
|
|
|||
|
|
@ -638,17 +638,23 @@ enum EditPrediction {
|
|||
display_mode: EditDisplayMode,
|
||||
snapshot: BufferSnapshot,
|
||||
},
|
||||
Move {
|
||||
/// Move to a specific location in the active editor
|
||||
MoveWithin {
|
||||
target: Anchor,
|
||||
snapshot: BufferSnapshot,
|
||||
},
|
||||
/// Move to a specific location in a different editor (not the active one)
|
||||
MoveOutside {
|
||||
target: language::Anchor,
|
||||
snapshot: BufferSnapshot,
|
||||
},
|
||||
}
|
||||
|
||||
struct EditPredictionState {
|
||||
inlay_ids: Vec<InlayId>,
|
||||
completion: EditPrediction,
|
||||
completion_id: Option<SharedString>,
|
||||
invalidation_range: Range<Anchor>,
|
||||
invalidation_range: Option<Range<Anchor>>,
|
||||
}
|
||||
|
||||
enum EditPredictionSettings {
|
||||
|
|
@ -7175,13 +7181,7 @@ impl Editor {
|
|||
return None;
|
||||
}
|
||||
|
||||
provider.refresh(
|
||||
self.project.clone(),
|
||||
buffer,
|
||||
cursor_buffer_position,
|
||||
debounce,
|
||||
cx,
|
||||
);
|
||||
provider.refresh(buffer, cursor_buffer_position, debounce, cx);
|
||||
Some(())
|
||||
}
|
||||
|
||||
|
|
@ -7424,10 +7424,8 @@ impl Editor {
|
|||
return;
|
||||
};
|
||||
|
||||
self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
|
||||
|
||||
match &active_edit_prediction.completion {
|
||||
EditPrediction::Move { target, .. } => {
|
||||
EditPrediction::MoveWithin { target, .. } => {
|
||||
let target = *target;
|
||||
|
||||
if let Some(position_map) = &self.last_position_map {
|
||||
|
|
@ -7469,7 +7467,19 @@ impl Editor {
|
|||
}
|
||||
}
|
||||
}
|
||||
EditPrediction::MoveOutside { snapshot, target } => {
|
||||
if let Some(workspace) = self.workspace() {
|
||||
Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
EditPrediction::Edit { edits, .. } => {
|
||||
self.report_edit_prediction_event(
|
||||
active_edit_prediction.completion_id.clone(),
|
||||
true,
|
||||
cx,
|
||||
);
|
||||
|
||||
if let Some(provider) = self.edit_prediction_provider() {
|
||||
provider.accept(cx);
|
||||
}
|
||||
|
|
@ -7522,10 +7532,8 @@ impl Editor {
|
|||
return;
|
||||
}
|
||||
|
||||
self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
|
||||
|
||||
match &active_edit_prediction.completion {
|
||||
EditPrediction::Move { target, .. } => {
|
||||
EditPrediction::MoveWithin { target, .. } => {
|
||||
let target = *target;
|
||||
self.change_selections(
|
||||
SelectionEffects::scroll(Autoscroll::newest()),
|
||||
|
|
@ -7536,7 +7544,19 @@ impl Editor {
|
|||
},
|
||||
);
|
||||
}
|
||||
EditPrediction::MoveOutside { snapshot, target } => {
|
||||
if let Some(workspace) = self.workspace() {
|
||||
Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
EditPrediction::Edit { edits, .. } => {
|
||||
self.report_edit_prediction_event(
|
||||
active_edit_prediction.completion_id.clone(),
|
||||
true,
|
||||
cx,
|
||||
);
|
||||
|
||||
// Find an insertion that starts at the cursor position.
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let cursor_offset = self.selections.newest::<usize>(cx).head();
|
||||
|
|
@ -7631,6 +7651,36 @@ impl Editor {
|
|||
);
|
||||
}
|
||||
|
||||
fn open_editor_at_anchor(
|
||||
snapshot: &language::BufferSnapshot,
|
||||
target: language::Anchor,
|
||||
workspace: &Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let path = snapshot.file().map(|file| file.full_path(cx));
|
||||
let Some(path) =
|
||||
path.and_then(|path| workspace.project().read(cx).find_project_path(path, cx))
|
||||
else {
|
||||
return Task::ready(Err(anyhow::anyhow!("Project path not found")));
|
||||
};
|
||||
let target = text::ToPoint::to_point(&target, snapshot);
|
||||
let item = workspace.open_path(path, None, true, window, cx);
|
||||
window.spawn(cx, async move |cx| {
|
||||
let Some(editor) = item.await?.downcast::<Editor>() else {
|
||||
return Ok(());
|
||||
};
|
||||
editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.go_to_singleton_buffer_point(target, window, cx);
|
||||
})
|
||||
.ok();
|
||||
anyhow::Ok(())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn has_active_edit_prediction(&self) -> bool {
|
||||
self.active_edit_prediction.is_some()
|
||||
}
|
||||
|
|
@ -7846,7 +7896,10 @@ impl Editor {
|
|||
.active_edit_prediction
|
||||
.as_ref()
|
||||
.is_some_and(|completion| {
|
||||
let invalidation_range = completion.invalidation_range.to_offset(&multibuffer);
|
||||
let Some(invalidation_range) = completion.invalidation_range.as_ref() else {
|
||||
return false;
|
||||
};
|
||||
let invalidation_range = invalidation_range.to_offset(&multibuffer);
|
||||
let invalidation_range = invalidation_range.start..=invalidation_range.end;
|
||||
!invalidation_range.contains(&offset_selection.head())
|
||||
})
|
||||
|
|
@ -7882,8 +7935,31 @@ impl Editor {
|
|||
}
|
||||
|
||||
let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
|
||||
let edits = edit_prediction
|
||||
.edits
|
||||
|
||||
let (completion_id, edits, edit_preview) = match edit_prediction {
|
||||
edit_prediction::EditPrediction::Local {
|
||||
id,
|
||||
edits,
|
||||
edit_preview,
|
||||
} => (id, edits, edit_preview),
|
||||
edit_prediction::EditPrediction::Jump {
|
||||
id,
|
||||
snapshot,
|
||||
target,
|
||||
} => {
|
||||
self.stale_edit_prediction_in_menu = None;
|
||||
self.active_edit_prediction = Some(EditPredictionState {
|
||||
inlay_ids: vec![],
|
||||
completion: EditPrediction::MoveOutside { snapshot, target },
|
||||
completion_id: id,
|
||||
invalidation_range: None,
|
||||
});
|
||||
cx.notify();
|
||||
return Some(());
|
||||
}
|
||||
};
|
||||
|
||||
let edits = edits
|
||||
.into_iter()
|
||||
.flat_map(|(range, new_text)| {
|
||||
let start = multibuffer.anchor_in_excerpt(excerpt_id, range.start)?;
|
||||
|
|
@ -7928,7 +8004,7 @@ impl Editor {
|
|||
invalidation_row_range =
|
||||
move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row);
|
||||
let target = first_edit_start;
|
||||
EditPrediction::Move { target, snapshot }
|
||||
EditPrediction::MoveWithin { target, snapshot }
|
||||
} else {
|
||||
let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true)
|
||||
&& !self.edit_predictions_hidden_for_vim_mode;
|
||||
|
|
@ -7977,7 +8053,7 @@ impl Editor {
|
|||
|
||||
EditPrediction::Edit {
|
||||
edits,
|
||||
edit_preview: edit_prediction.edit_preview,
|
||||
edit_preview,
|
||||
display_mode,
|
||||
snapshot,
|
||||
}
|
||||
|
|
@ -7994,8 +8070,8 @@ impl Editor {
|
|||
self.active_edit_prediction = Some(EditPredictionState {
|
||||
inlay_ids,
|
||||
completion,
|
||||
completion_id: edit_prediction.id,
|
||||
invalidation_range,
|
||||
completion_id,
|
||||
invalidation_range: Some(invalidation_range),
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
|
|
@ -8581,7 +8657,7 @@ impl Editor {
|
|||
}
|
||||
|
||||
match &active_edit_prediction.completion {
|
||||
EditPrediction::Move { target, .. } => {
|
||||
EditPrediction::MoveWithin { target, .. } => {
|
||||
let target_display_point = target.to_display_point(editor_snapshot);
|
||||
|
||||
if self.edit_prediction_requires_modifier() {
|
||||
|
|
@ -8666,6 +8742,28 @@ impl Editor {
|
|||
window,
|
||||
cx,
|
||||
),
|
||||
EditPrediction::MoveOutside { snapshot, .. } => {
|
||||
let file_name = snapshot
|
||||
.file()
|
||||
.map(|file| file.file_name(cx))
|
||||
.unwrap_or("untitled");
|
||||
let mut element = self
|
||||
.render_edit_prediction_line_popover(
|
||||
format!("Jump to {file_name}"),
|
||||
Some(IconName::ZedPredict),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
let origin_x = text_bounds.size.width / 2. - size.width / 2.;
|
||||
let origin_y = text_bounds.size.height - size.height - px(30.);
|
||||
let origin = text_bounds.origin + gpui::Point::new(origin_x, origin_y);
|
||||
element.prepaint_at(origin, window, cx);
|
||||
|
||||
Some((element, origin))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -8730,13 +8828,13 @@ impl Editor {
|
|||
.items_end()
|
||||
.when(flag_on_right, |el| el.items_start())
|
||||
.child(if flag_on_right {
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)?
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)
|
||||
.rounded_bl(px(0.))
|
||||
.rounded_tl(px(0.))
|
||||
.border_l_2()
|
||||
.border_color(border_color)
|
||||
} else {
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)?
|
||||
self.render_edit_prediction_line_popover("Jump", None, window, cx)
|
||||
.rounded_br(px(0.))
|
||||
.rounded_tr(px(0.))
|
||||
.border_r_2()
|
||||
|
|
@ -8776,7 +8874,7 @@ impl Editor {
|
|||
cx: &mut App,
|
||||
) -> Option<(AnyElement, gpui::Point<Pixels>)> {
|
||||
let mut element = self
|
||||
.render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)?
|
||||
.render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
|
|
@ -8816,7 +8914,7 @@ impl Editor {
|
|||
Some(IconName::ArrowUp),
|
||||
window,
|
||||
cx,
|
||||
)?
|
||||
)
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
|
|
@ -8835,7 +8933,7 @@ impl Editor {
|
|||
Some(IconName::ArrowDown),
|
||||
window,
|
||||
cx,
|
||||
)?
|
||||
)
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
|
|
@ -8882,7 +8980,7 @@ impl Editor {
|
|||
);
|
||||
|
||||
let mut element = self
|
||||
.render_edit_prediction_line_popover(label, None, window, cx)?
|
||||
.render_edit_prediction_line_popover(label, None, window, cx)
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
|
|
@ -8909,7 +9007,7 @@ impl Editor {
|
|||
};
|
||||
|
||||
element = self
|
||||
.render_edit_prediction_line_popover(label, Some(icon), window, cx)?
|
||||
.render_edit_prediction_line_popover(label, Some(icon), window, cx)
|
||||
.into_any();
|
||||
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
|
|
@ -9163,13 +9261,13 @@ impl Editor {
|
|||
icon: Option<IconName>,
|
||||
window: &mut Window,
|
||||
cx: &App,
|
||||
) -> Option<Stateful<Div>> {
|
||||
) -> Stateful<Div> {
|
||||
let padding_right = if icon.is_some() { px(4.) } else { px(8.) };
|
||||
|
||||
let keybind = self.render_edit_prediction_accept_keybind(window, cx);
|
||||
let has_keybind = keybind.is_some();
|
||||
|
||||
let result = h_flex()
|
||||
h_flex()
|
||||
.id("ep-line-popover")
|
||||
.py_0p5()
|
||||
.pl_1()
|
||||
|
|
@ -9215,9 +9313,7 @@ impl Editor {
|
|||
.mt(px(1.5))
|
||||
.child(Icon::new(icon).size(IconSize::Small)),
|
||||
)
|
||||
});
|
||||
|
||||
Some(result)
|
||||
})
|
||||
}
|
||||
|
||||
fn edit_prediction_line_popover_bg_color(cx: &App) -> Hsla {
|
||||
|
|
@ -9281,7 +9377,7 @@ impl Editor {
|
|||
.rounded_tl(px(0.))
|
||||
.overflow_hidden()
|
||||
.child(div().px_1p5().child(match &prediction.completion {
|
||||
EditPrediction::Move { target, snapshot } => {
|
||||
EditPrediction::MoveWithin { target, snapshot } => {
|
||||
use text::ToPoint as _;
|
||||
if target.text_anchor.to_point(snapshot).row > cursor_point.row
|
||||
{
|
||||
|
|
@ -9290,6 +9386,10 @@ impl Editor {
|
|||
Icon::new(IconName::ZedPredictUp)
|
||||
}
|
||||
}
|
||||
EditPrediction::MoveOutside { .. } => {
|
||||
// TODO [zeta2] custom icon for external jump?
|
||||
Icon::new(provider_icon)
|
||||
}
|
||||
EditPrediction::Edit { .. } => Icon::new(provider_icon),
|
||||
}))
|
||||
.child(
|
||||
|
|
@ -9472,7 +9572,7 @@ impl Editor {
|
|||
.unwrap_or(true);
|
||||
|
||||
match &completion.completion {
|
||||
EditPrediction::Move {
|
||||
EditPrediction::MoveWithin {
|
||||
target, snapshot, ..
|
||||
} => {
|
||||
if !supports_jump {
|
||||
|
|
@ -9494,7 +9594,20 @@ impl Editor {
|
|||
.child(Label::new("Jump to Edit")),
|
||||
)
|
||||
}
|
||||
|
||||
EditPrediction::MoveOutside { snapshot, .. } => {
|
||||
let file_name = snapshot
|
||||
.file()
|
||||
.map(|file| file.file_name(cx))
|
||||
.unwrap_or("untitled");
|
||||
Some(
|
||||
h_flex()
|
||||
.px_2()
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.child(Icon::new(IconName::ZedPredict))
|
||||
.child(Label::new(format!("Jump to {file_name}"))),
|
||||
)
|
||||
}
|
||||
EditPrediction::Edit {
|
||||
edits,
|
||||
edit_preview,
|
||||
|
|
@ -21418,7 +21531,7 @@ impl Editor {
|
|||
{
|
||||
self.hide_context_menu(window, cx);
|
||||
}
|
||||
self.discard_edit_prediction(false, cx);
|
||||
self.take_active_edit_prediction(cx);
|
||||
cx.emit(EditorEvent::Blurred);
|
||||
cx.notify();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8272,7 +8272,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
|
|||
|
||||
cx.update(|_, cx| {
|
||||
provider.update(cx, |provider, _| {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
|
||||
id: None,
|
||||
edits: vec![(edit_position..edit_position, "X".into())],
|
||||
edit_preview: None,
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ gpui.workspace = true
|
|||
language.workspace = true
|
||||
log.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
|||
use futures::StreamExt as _;
|
||||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use language::{Anchor, Buffer, BufferSnapshot};
|
||||
use project::Project;
|
||||
use std::{
|
||||
ops::{AddAssign, Range},
|
||||
path::Path,
|
||||
|
|
@ -94,7 +93,7 @@ fn completion_from_diff(
|
|||
edits.push((edit_range, edit_text));
|
||||
}
|
||||
|
||||
EditPrediction {
|
||||
EditPrediction::Local {
|
||||
id: None,
|
||||
edits,
|
||||
edit_preview: None,
|
||||
|
|
@ -132,7 +131,6 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
buffer_handle: Entity<Buffer>,
|
||||
cursor_position: Anchor,
|
||||
debounce: bool,
|
||||
|
|
|
|||
|
|
@ -205,42 +205,48 @@ fn assign_edit_prediction_provider(
|
|||
}
|
||||
}
|
||||
|
||||
if std::env::var("ZED_ZETA2").is_ok() {
|
||||
let zeta = zeta2::Zeta::global(client, &user_store, cx);
|
||||
let provider = cx.new(|cx| {
|
||||
zeta2::ZetaEditPredictionProvider::new(
|
||||
editor.project(),
|
||||
&client,
|
||||
&user_store,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
&& let Some(project) = editor.project()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
if let Some(project) = editor.project() {
|
||||
if std::env::var("ZED_ZETA2").is_ok() {
|
||||
let zeta = zeta2::Zeta::global(client, &user_store, cx);
|
||||
let provider = cx.new(|cx| {
|
||||
zeta2::ZetaEditPredictionProvider::new(
|
||||
project.clone(),
|
||||
&client,
|
||||
&user_store,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
} else {
|
||||
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
|
||||
// TODO [zeta2] handle multibuffers
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
&& let Some(project) = editor.project()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
} else {
|
||||
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
let provider = cx.new(|_| {
|
||||
zeta::ZetaEditPredictionProvider::new(
|
||||
zeta,
|
||||
project.clone(),
|
||||
singleton_buffer,
|
||||
)
|
||||
});
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
}
|
||||
|
||||
let provider =
|
||||
cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1316,12 +1316,17 @@ pub struct ZetaEditPredictionProvider {
|
|||
next_pending_completion_id: usize,
|
||||
current_completion: Option<CurrentEditPrediction>,
|
||||
last_request_timestamp: Instant,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl ZetaEditPredictionProvider {
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
|
||||
pub fn new(zeta: Entity<Zeta>, singleton_buffer: Option<Entity<Buffer>>) -> Self {
|
||||
pub fn new(
|
||||
zeta: Entity<Zeta>,
|
||||
project: Entity<Project>,
|
||||
singleton_buffer: Option<Entity<Buffer>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
zeta,
|
||||
singleton_buffer,
|
||||
|
|
@ -1329,6 +1334,7 @@ impl ZetaEditPredictionProvider {
|
|||
next_pending_completion_id: 0,
|
||||
current_completion: None,
|
||||
last_request_timestamp: Instant::now(),
|
||||
project,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1394,7 +1400,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
_debounce: bool,
|
||||
|
|
@ -1403,9 +1408,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
if self.zeta.read(cx).update_required {
|
||||
return;
|
||||
}
|
||||
let Some(project) = project else {
|
||||
return;
|
||||
};
|
||||
|
||||
if self
|
||||
.zeta
|
||||
|
|
@ -1433,6 +1435,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
self.next_pending_completion_id += 1;
|
||||
let last_request_timestamp = self.last_request_timestamp;
|
||||
|
||||
let project = self.project.clone();
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
|
||||
.checked_duration_since(Instant::now())
|
||||
|
|
@ -1604,7 +1607,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction::EditPrediction {
|
||||
Some(edit_prediction::EditPrediction::Local {
|
||||
id: Some(completion.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(completion.edit_preview.clone()),
|
||||
|
|
|
|||
|
|
@ -1,35 +1,18 @@
|
|||
use std::{borrow::Cow, ops::Range, sync::Arc};
|
||||
use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
|
||||
use gpui::{App, AsyncApp, Entity};
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
|
||||
};
|
||||
use project::Project;
|
||||
use util::ResultExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditPrediction {
|
||||
pub id: EditPredictionId,
|
||||
pub edits: Arc<[(Range<Anchor>, String)]>,
|
||||
pub snapshot: BufferSnapshot,
|
||||
pub edit_preview: EditPreview,
|
||||
}
|
||||
|
||||
impl EditPrediction {
|
||||
pub fn interpolate(
|
||||
&self,
|
||||
new_snapshot: &BufferSnapshot,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct EditPredictionId(Uuid);
|
||||
|
||||
impl From<Uuid> for EditPredictionId {
|
||||
fn from(value: Uuid) -> Self {
|
||||
EditPredictionId(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EditPredictionId> for gpui::ElementId {
|
||||
fn from(value: EditPredictionId) -> Self {
|
||||
gpui::ElementId::Uuid(value.0)
|
||||
|
|
@ -42,9 +25,122 @@ impl std::fmt::Display for EditPredictionId {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditPrediction {
|
||||
pub id: EditPredictionId,
|
||||
pub path: Arc<Path>,
|
||||
pub edits: Arc<[(Range<Anchor>, String)]>,
|
||||
pub snapshot: BufferSnapshot,
|
||||
pub edit_preview: EditPreview,
|
||||
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
|
||||
_buffer: Entity<Buffer>,
|
||||
}
|
||||
|
||||
impl EditPrediction {
|
||||
pub async fn from_response(
|
||||
response: predict_edits_v3::PredictEditsResponse,
|
||||
active_buffer_old_snapshot: &TextBufferSnapshot,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<Self> {
|
||||
// TODO only allow cloud to return one path
|
||||
let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let is_same_path = active_buffer
|
||||
.read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
|
||||
.ok()?;
|
||||
|
||||
let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
|
||||
active_buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
|
||||
let edits: Arc<[_]> =
|
||||
interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
|
||||
|
||||
Some((
|
||||
active_buffer.clone(),
|
||||
edits.clone(),
|
||||
new_snapshot,
|
||||
buffer.preview_edits(edits, cx),
|
||||
))
|
||||
})
|
||||
.ok()??
|
||||
} else {
|
||||
let buffer_handle = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(&path, cx)
|
||||
.context("Failed to find project path for zeta edit")?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})
|
||||
.ok()?
|
||||
.log_err()?
|
||||
.await
|
||||
.context("Failed to open buffer for zeta edit")
|
||||
.log_err()?;
|
||||
|
||||
buffer_handle
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let edits = edits_from_response(&response.edits, &snapshot);
|
||||
if edits.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((
|
||||
buffer_handle.clone(),
|
||||
edits.clone(),
|
||||
snapshot,
|
||||
buffer.preview_edits(edits, cx),
|
||||
))
|
||||
})
|
||||
.ok()??
|
||||
};
|
||||
|
||||
let edit_preview = edit_preview_task.await;
|
||||
|
||||
Some(EditPrediction {
|
||||
id: EditPredictionId(response.request_id),
|
||||
path,
|
||||
edits,
|
||||
snapshot,
|
||||
edit_preview,
|
||||
_buffer: buffer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn interpolate(
|
||||
&self,
|
||||
new_snapshot: &TextBufferSnapshot,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
|
||||
}
|
||||
|
||||
pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
|
||||
buffer_path_eq(buffer, &self.path, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for EditPrediction {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("EditPrediction")
|
||||
.field("id", &self.id)
|
||||
.field("path", &self.path)
|
||||
.field("edits", &self.edits)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
|
||||
buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
|
||||
}
|
||||
|
||||
pub fn interpolate_edits(
|
||||
old_snapshot: &BufferSnapshot,
|
||||
new_snapshot: &BufferSnapshot,
|
||||
old_snapshot: &TextBufferSnapshot,
|
||||
new_snapshot: &TextBufferSnapshot,
|
||||
current_edits: Arc<[(Range<Anchor>, String)]>,
|
||||
) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
let mut edits = Vec::new();
|
||||
|
|
@ -88,14 +184,13 @@ pub fn interpolate_edits(
|
|||
if edits.is_empty() { None } else { Some(edits) }
|
||||
}
|
||||
|
||||
pub fn edits_from_response(
|
||||
fn edits_from_response(
|
||||
edits: &[predict_edits_v3::Edit],
|
||||
snapshot: &BufferSnapshot,
|
||||
snapshot: &TextBufferSnapshot,
|
||||
) -> Arc<[(Range<Anchor>, String)]> {
|
||||
edits
|
||||
.iter()
|
||||
.flat_map(|edit| {
|
||||
// TODO multi-file edits
|
||||
let old_text = snapshot.text_for_range(edit.range.clone());
|
||||
|
||||
excerpt_edits_from_response(
|
||||
|
|
@ -113,7 +208,7 @@ fn excerpt_edits_from_response(
|
|||
old_text: Cow<str>,
|
||||
new_text: &str,
|
||||
offset: usize,
|
||||
snapshot: &BufferSnapshot,
|
||||
snapshot: &TextBufferSnapshot,
|
||||
) -> impl Iterator<Item = (Range<Anchor>, String)> {
|
||||
text_diff(&old_text, new_text)
|
||||
.into_iter()
|
||||
|
|
@ -221,6 +316,8 @@ mod tests {
|
|||
id: EditPredictionId(Uuid::new_v4()),
|
||||
edits,
|
||||
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
|
||||
path: Path::new("test.txt").into(),
|
||||
_buffer: buffer.clone(),
|
||||
edit_preview,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -4,76 +4,44 @@ use std::{
|
|||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use arrayvec::ArrayVec;
|
||||
use client::{Client, UserStore};
|
||||
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
|
||||
use gpui::{App, Entity, EntityId, Task, prelude::*};
|
||||
use language::{BufferSnapshot, ToPoint as _};
|
||||
use gpui::{App, Entity, Task, prelude::*};
|
||||
use language::ToPoint as _;
|
||||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{Zeta, prediction::EditPrediction};
|
||||
use crate::{BufferEditPrediction, Zeta};
|
||||
|
||||
pub struct ZetaEditPredictionProvider {
|
||||
zeta: Entity<Zeta>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
next_pending_prediction_id: usize,
|
||||
pending_predictions: ArrayVec<PendingPrediction, 2>,
|
||||
last_request_timestamp: Instant,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl ZetaEditPredictionProvider {
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
|
||||
pub fn new(
|
||||
project: Option<&Entity<Project>>,
|
||||
project: Entity<Project>,
|
||||
client: &Arc<Client>,
|
||||
user_store: &Entity<UserStore>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let zeta = Zeta::global(client, user_store, cx);
|
||||
if let Some(project) = project {
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(project, cx);
|
||||
});
|
||||
}
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(&project, cx);
|
||||
});
|
||||
|
||||
Self {
|
||||
zeta,
|
||||
current_prediction: None,
|
||||
next_pending_prediction_id: 0,
|
||||
pending_predictions: ArrayVec::new(),
|
||||
last_request_timestamp: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CurrentEditPrediction {
|
||||
buffer_id: EntityId,
|
||||
prediction: EditPrediction,
|
||||
}
|
||||
|
||||
impl CurrentEditPrediction {
|
||||
fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
|
||||
if self.buffer_id != old_prediction.buffer_id {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
|
||||
return true;
|
||||
};
|
||||
let Some(new_edits) = self.prediction.interpolate(snapshot) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if old_edits.len() == 1 && new_edits.len() == 1 {
|
||||
let (old_range, old_text) = &old_edits[0];
|
||||
let (new_range, new_text) = &new_edits[0];
|
||||
new_range == old_range && new_text.starts_with(old_text)
|
||||
} else {
|
||||
true
|
||||
project: project,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -128,42 +96,31 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<project::Project>>,
|
||||
buffer: Entity<language::Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(project) = project else {
|
||||
return;
|
||||
};
|
||||
let zeta = self.zeta.read(cx);
|
||||
|
||||
if self
|
||||
.zeta
|
||||
.read(cx)
|
||||
.user_store
|
||||
.read_with(cx, |user_store, _cx| {
|
||||
user_store.account_too_young() || user_store.has_overdue_invoices()
|
||||
})
|
||||
{
|
||||
if zeta.user_store.read_with(cx, |user_store, _cx| {
|
||||
user_store.account_too_young() || user_store.has_overdue_invoices()
|
||||
}) {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(current_prediction) = self.current_prediction.as_ref() {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
if current_prediction
|
||||
.prediction
|
||||
.interpolate(&snapshot)
|
||||
.is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let pending_prediction_id = self.next_pending_prediction_id;
|
||||
self.next_pending_prediction_id += 1;
|
||||
let last_request_timestamp = self.last_request_timestamp;
|
||||
|
||||
let project = self.project.clone();
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
|
||||
.checked_duration_since(Instant::now())
|
||||
|
|
@ -171,25 +128,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
cx.background_executor().timer(timeout).await;
|
||||
}
|
||||
|
||||
let prediction_request = this.update(cx, |this, cx| {
|
||||
let refresh_task = this.update(cx, |this, cx| {
|
||||
this.last_request_timestamp = Instant::now();
|
||||
this.zeta.update(cx, |zeta, cx| {
|
||||
zeta.request_prediction(&project, &buffer, cursor_position, cx)
|
||||
zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
|
||||
})
|
||||
});
|
||||
|
||||
let prediction = match prediction_request {
|
||||
Ok(prediction_request) => {
|
||||
let prediction_request = prediction_request.await;
|
||||
prediction_request.map(|c| {
|
||||
c.map(|prediction| CurrentEditPrediction {
|
||||
buffer_id: buffer.entity_id(),
|
||||
prediction,
|
||||
})
|
||||
})
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
};
|
||||
if let Some(refresh_task) = refresh_task.ok() {
|
||||
refresh_task.await.log_err();
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if this.pending_predictions[0].id == pending_prediction_id {
|
||||
|
|
@ -198,24 +146,6 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
this.pending_predictions.clear();
|
||||
}
|
||||
|
||||
let Some(new_prediction) = prediction
|
||||
.context("edit prediction failed")
|
||||
.log_err()
|
||||
.flatten()
|
||||
else {
|
||||
cx.notify();
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(old_prediction) = this.current_prediction.as_ref() {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
|
||||
this.current_prediction = Some(new_prediction);
|
||||
}
|
||||
} else {
|
||||
this.current_prediction = Some(new_prediction);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
|
|
@ -248,15 +178,18 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
) {
|
||||
}
|
||||
|
||||
fn accept(&mut self, _cx: &mut Context<Self>) {
|
||||
// TODO [zeta2] report accept
|
||||
self.current_prediction.take();
|
||||
fn accept(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.accept_current_prediction(&self.project);
|
||||
});
|
||||
self.pending_predictions.clear();
|
||||
}
|
||||
|
||||
fn discard(&mut self, _cx: &mut Context<Self>) {
|
||||
fn discard(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.discard_current_prediction(&self.project);
|
||||
});
|
||||
self.pending_predictions.clear();
|
||||
self.current_prediction.take();
|
||||
}
|
||||
|
||||
fn suggest(
|
||||
|
|
@ -265,36 +198,44 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
cursor_position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<edit_prediction::EditPrediction> {
|
||||
let CurrentEditPrediction {
|
||||
buffer_id,
|
||||
prediction,
|
||||
..
|
||||
} = self.current_prediction.as_mut()?;
|
||||
let prediction =
|
||||
self.zeta
|
||||
.read(cx)
|
||||
.current_prediction_for_buffer(buffer, &self.project, cx)?;
|
||||
|
||||
// Invalidate previous prediction if it was generated for a different buffer.
|
||||
if *buffer_id != buffer.entity_id() {
|
||||
self.current_prediction.take();
|
||||
return None;
|
||||
}
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let buffer = buffer.read(cx);
|
||||
let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
|
||||
self.current_prediction.take();
|
||||
let snapshot = buffer.snapshot();
|
||||
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.discard_current_prediction(&self.project);
|
||||
});
|
||||
return None;
|
||||
};
|
||||
|
||||
let cursor_row = cursor_position.to_point(buffer).row;
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit =
|
||||
closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
|
|
@ -305,7 +246,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit =
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
|
|
@ -313,7 +254,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction::EditPrediction {
|
||||
Some(edit_prediction::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ use gpui::{
|
|||
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
|
||||
http_client, prelude::*,
|
||||
};
|
||||
use language::BufferSnapshot;
|
||||
use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
|
||||
use language::{BufferSnapshot, TextBufferSnapshot};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use project::Project;
|
||||
use release_channel::AppVersion;
|
||||
|
|
@ -35,7 +35,7 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti
|
|||
mod prediction;
|
||||
mod provider;
|
||||
|
||||
use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
|
||||
use crate::prediction::EditPrediction;
|
||||
pub use provider::ZetaEditPredictionProvider;
|
||||
|
||||
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
|
||||
|
|
@ -53,7 +53,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
|
|||
excerpt: DEFAULT_EXCERPT_OPTIONS,
|
||||
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
|
||||
max_diagnostic_bytes: 2048,
|
||||
prompt_format: PromptFormat::MarkedExcerpt,
|
||||
prompt_format: PromptFormat::DEFAULT,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -94,6 +94,47 @@ struct ZetaProject {
|
|||
syntax_index: Entity<SyntaxIndex>,
|
||||
events: VecDeque<Event>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CurrentEditPrediction {
|
||||
pub requested_by_buffer_id: EntityId,
|
||||
pub prediction: EditPrediction,
|
||||
}
|
||||
|
||||
impl CurrentEditPrediction {
|
||||
fn should_replace_prediction(
|
||||
&self,
|
||||
old_prediction: &Self,
|
||||
snapshot: &TextBufferSnapshot,
|
||||
) -> bool {
|
||||
if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let Some(new_edits) = self.prediction.interpolate(snapshot) else {
|
||||
return false;
|
||||
};
|
||||
if old_edits.len() == 1 && new_edits.len() == 1 {
|
||||
let (old_range, old_text) = &old_edits[0];
|
||||
let (new_range, new_text) = &new_edits[0];
|
||||
new_range == old_range && new_text.starts_with(old_text)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A prediction from the perspective of a buffer.
|
||||
#[derive(Debug)]
|
||||
enum BufferEditPrediction<'a> {
|
||||
Local { prediction: &'a EditPrediction },
|
||||
Jump { prediction: &'a EditPrediction },
|
||||
}
|
||||
|
||||
struct RegisteredBuffer {
|
||||
|
|
@ -204,6 +245,7 @@ impl Zeta {
|
|||
syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
|
||||
events: VecDeque::new(),
|
||||
registered_buffers: HashMap::new(),
|
||||
current_prediction: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -305,7 +347,83 @@ impl Zeta {
|
|||
events.push_back(event);
|
||||
}
|
||||
|
||||
pub fn request_prediction(
|
||||
fn current_prediction_for_buffer(
|
||||
&self,
|
||||
buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<BufferEditPrediction<'_>> {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
|
||||
let CurrentEditPrediction {
|
||||
requested_by_buffer_id,
|
||||
prediction,
|
||||
} = project_state.current_prediction.as_ref()?;
|
||||
|
||||
if prediction.targets_buffer(buffer.read(cx), cx) {
|
||||
Some(BufferEditPrediction::Local { prediction })
|
||||
} else if *requested_by_buffer_id == buffer.entity_id() {
|
||||
Some(BufferEditPrediction::Jump { prediction })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn accept_current_prediction(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.current_prediction.take();
|
||||
};
|
||||
// TODO report accepted
|
||||
}
|
||||
|
||||
fn discard_current_prediction(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.current_prediction.take();
|
||||
};
|
||||
}
|
||||
|
||||
pub fn refresh_prediction(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let request_task = self.request_prediction(project, buffer, position, cx);
|
||||
let buffer = buffer.clone();
|
||||
let project = project.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(prediction) = request_task.await? {
|
||||
this.update(cx, |this, cx| {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get_mut(&project.entity_id())
|
||||
.context("Project not found")?;
|
||||
|
||||
let new_prediction = CurrentEditPrediction {
|
||||
requested_by_buffer_id: buffer.entity_id(),
|
||||
prediction: prediction,
|
||||
};
|
||||
|
||||
if project_state
|
||||
.current_prediction
|
||||
.as_ref()
|
||||
.is_none_or(|old_prediction| {
|
||||
new_prediction
|
||||
.should_replace_prediction(&old_prediction, buffer.read(cx))
|
||||
})
|
||||
{
|
||||
project_state.current_prediction = Some(new_prediction);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn request_prediction(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
|
|
@ -457,74 +575,63 @@ impl Zeta {
|
|||
.ok();
|
||||
}
|
||||
|
||||
let (response, usage) = response?;
|
||||
let edits = edits_from_response(&response.edits, &snapshot);
|
||||
|
||||
anyhow::Ok(Some((response.request_id, edits, usage)))
|
||||
anyhow::Ok(Some(response?))
|
||||
}
|
||||
});
|
||||
|
||||
let buffer = buffer.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
match request_task.await {
|
||||
Ok(Some((id, edits, usage))) => {
|
||||
if let Some(usage) = usage {
|
||||
this.update(cx, |this, cx| {
|
||||
this.user_store.update(cx, |user_store, cx| {
|
||||
user_store.update_edit_prediction_usage(usage, cx);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
// TODO telemetry: duration, etc
|
||||
let Some((edits, snapshot, edit_preview_task)) =
|
||||
buffer.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
let edits: Arc<[_]> =
|
||||
interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
|
||||
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
|
||||
})?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some(EditPrediction {
|
||||
id: id.into(),
|
||||
edits,
|
||||
snapshot,
|
||||
edit_preview: edit_preview_task.await,
|
||||
}))
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(err) => {
|
||||
if err.is::<ZedUpdateRequiredError>() {
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.update_required = true;
|
||||
cx.spawn({
|
||||
let project = project.clone();
|
||||
async move |this, cx| {
|
||||
match request_task.await {
|
||||
Ok(Some((response, usage))) => {
|
||||
if let Some(usage) = usage {
|
||||
this.update(cx, |this, cx| {
|
||||
this.user_store.update(cx, |user_store, cx| {
|
||||
user_store.update_edit_prediction_usage(usage, cx);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
let error_message: SharedString = err.to_string().into();
|
||||
show_app_notification(
|
||||
NotificationId::unique::<ZedUpdateRequiredError>(),
|
||||
cx,
|
||||
move |cx| {
|
||||
cx.new(|cx| {
|
||||
ErrorMessagePrompt::new(error_message.clone(), cx)
|
||||
.with_link_button(
|
||||
"Update Zed",
|
||||
"https://zed.dev/releases",
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
let prediction = EditPrediction::from_response(
|
||||
response, &snapshot, &buffer, &project, cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
// TODO telemetry: duration, etc
|
||||
Ok(prediction)
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(err) => {
|
||||
if err.is::<ZedUpdateRequiredError>() {
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.update_required = true;
|
||||
})
|
||||
.ok();
|
||||
|
||||
Err(err)
|
||||
let error_message: SharedString = err.to_string().into();
|
||||
show_app_notification(
|
||||
NotificationId::unique::<ZedUpdateRequiredError>(),
|
||||
cx,
|
||||
move |cx| {
|
||||
cx.new(|cx| {
|
||||
ErrorMessagePrompt::new(error_message.clone(), cx)
|
||||
.with_link_button(
|
||||
"Update Zed",
|
||||
"https://zed.dev/releases",
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -859,13 +966,113 @@ mod tests {
|
|||
};
|
||||
use indoc::indoc;
|
||||
use language::{LanguageServerId, OffsetRangeExt as _};
|
||||
use pretty_assertions::{assert_eq, assert_matches};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::Zeta;
|
||||
use crate::{BufferEditPrediction, Zeta};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_current_state(cx: &mut TestAppContext) {
|
||||
let (zeta, mut req_rx) = init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"1.txt": "Hello!\nHow\nBye",
|
||||
"2.txt": "Hola!\nComo\nAdios"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
|
||||
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(&project, cx);
|
||||
});
|
||||
|
||||
let buffer1 = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
|
||||
project.open_buffer(path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let position = snapshot1.anchor_before(language::Point::new(1, 3));
|
||||
|
||||
// Prediction for current file
|
||||
|
||||
let prediction_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer1, position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = req_rx.next().await.unwrap();
|
||||
respond_tx
|
||||
.send(predict_edits_v3::PredictEditsResponse {
|
||||
request_id: Uuid::new_v4(),
|
||||
edits: vec![predict_edits_v3::Edit {
|
||||
path: Path::new(path!("root/1.txt")).into(),
|
||||
range: 0..snapshot1.len(),
|
||||
content: "Hello!\nHow are you?\nBye".into(),
|
||||
}],
|
||||
debug_info: None,
|
||||
})
|
||||
.unwrap();
|
||||
prediction_task.await.unwrap();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
|
||||
// Prediction for another file
|
||||
|
||||
let prediction_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_prediction(&project, &buffer1, position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = req_rx.next().await.unwrap();
|
||||
respond_tx
|
||||
.send(predict_edits_v3::PredictEditsResponse {
|
||||
request_id: Uuid::new_v4(),
|
||||
edits: vec![predict_edits_v3::Edit {
|
||||
path: Path::new(path!("root/2.txt")).into(),
|
||||
range: 0..snapshot1.len(),
|
||||
content: "Hola!\nComo estas?\nAdios".into(),
|
||||
}],
|
||||
debug_info: None,
|
||||
})
|
||||
.unwrap();
|
||||
prediction_task.await.unwrap();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
prediction,
|
||||
BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
|
||||
);
|
||||
});
|
||||
|
||||
let buffer2 = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
|
||||
project.open_buffer(path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
zeta.read_with(cx, |zeta, cx| {
|
||||
let prediction = zeta
|
||||
.current_prediction_for_buffer(&buffer2, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
|
|
@ -1146,6 +1353,7 @@ mod tests {
|
|||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
let zeta = Zeta::global(&client, &user_store, cx);
|
||||
|
||||
(zeta, req_rx)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -185,7 +185,7 @@ impl Zeta2Inspector {
|
|||
cx.background_executor().timer(THROTTLE_TIME).await;
|
||||
if let Some(task) = zeta
|
||||
.update(cx, |zeta, cx| {
|
||||
zeta.request_prediction(&project, &buffer, position, cx)
|
||||
zeta.refresh_prediction(&project, &buffer, position, cx)
|
||||
})
|
||||
.ok()
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in a new issue