ep: Allow setting expected patch in rate completions modal (#56629)

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [ ] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- Added the ability to input the "expected patch", i.e. what the model
should have predicted in the `edit prediction: rate completions` modal
This commit is contained in:
Ben Kunkle 2026-05-13 11:03:01 -05:00 committed by GitHub
parent f56219e503
commit 49c6f78fc1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 439 additions and 56 deletions

View file

@ -113,5 +113,7 @@ pub struct SubmitEditPredictionFeedbackBody {
pub rating: String,
pub inputs: serde_json::Value,
pub output: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expected_output: Option<String>,
pub feedback: String,
}

View file

@ -2805,6 +2805,7 @@ impl EditPredictionStore {
prediction: &EditPrediction,
rating: EditPredictionRating,
feedback: String,
expected_output: Option<String>,
cx: &mut Context<Self>,
) {
let organization = self.user_store.read(cx).current_organization();
@ -2830,6 +2831,7 @@ impl EditPredictionStore {
},
inputs: inputs?,
output,
expected_output,
feedback,
})
.await?;

View file

@ -2608,6 +2608,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let prediction = EditPrediction {
edits,
cursor_position: None,
editable_range: None,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),

View file

@ -17,6 +17,7 @@ const FIM_CONTEXT_TOKENS: usize = 512;
struct FimRequestOutput {
request_id: String,
edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
editable_range: std::ops::Range<Anchor>,
snapshot: BufferSnapshot,
inputs: ZetaPromptInput,
buffer: Entity<Buffer>,
@ -127,9 +128,15 @@ pub fn request_prediction(
vec![(anchor..anchor, completion)]
};
let editable_range = snapshot.anchor_range_inside(
(excerpt_offset_range.start + editable_range.start)
..(excerpt_offset_range.start + editable_range.end),
);
anyhow::Ok(FimRequestOutput {
request_id,
edits,
editable_range,
snapshot,
inputs,
buffer,
@ -145,6 +152,7 @@ pub fn request_prediction(
&output.snapshot,
output.edits.into(),
None,
Some(output.editable_range),
output.inputs,
None,
cx.background_executor().now() - request_start,

View file

@ -223,7 +223,9 @@ impl Mercury {
);
}
anyhow::Ok((id, edits, snapshot, inputs))
let editable_range = snapshot.anchor_range_inside(editable_offset_range);
anyhow::Ok((id, edits, snapshot, inputs, editable_range))
});
cx.spawn(async move |ep_store, cx| {
@ -241,7 +243,7 @@ impl Mercury {
cx.notify();
})?;
let (id, edits, old_snapshot, inputs) = result?;
let (id, edits, old_snapshot, inputs, editable_range) = result?;
anyhow::Ok(Some(
EditPredictionResult::new(
EditPredictionId(id.into()),
@ -249,6 +251,7 @@ impl Mercury {
&old_snapshot,
edits.into(),
None,
Some(editable_range),
inputs,
None,
cx.background_executor().now() - request_start,

View file

@ -36,6 +36,7 @@ impl EditPredictionResult {
edited_buffer_snapshot: &BufferSnapshot,
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
cursor_position: Option<PredictedCursorPosition>,
editable_range: Option<Range<Anchor>>,
inputs: ZetaPromptInput,
model_version: Option<String>,
e2e_latency: std::time::Duration,
@ -75,6 +76,7 @@ impl EditPredictionResult {
id,
edits,
cursor_position,
editable_range,
snapshot,
edit_preview,
inputs,
@ -92,6 +94,7 @@ pub struct EditPrediction {
pub id: EditPredictionId,
pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
pub cursor_position: Option<PredictedCursorPosition>,
pub editable_range: Option<Range<Anchor>>,
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
pub buffer: Entity<Buffer>,
@ -145,6 +148,7 @@ mod tests {
id: EditPredictionId("prediction-1".into()),
edits,
cursor_position: None,
editable_range: None,
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,

View file

@ -396,6 +396,7 @@ pub fn request_prediction_with_zeta(
&edited_buffer_snapshot,
edits.into(),
cursor_position,
Some(edited_buffer_snapshot.anchor_range_inside(editable_range_in_buffer.clone())),
inputs,
model_version,
request_duration,

View file

@ -6,14 +6,18 @@ use gpui::{
App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable,
Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*,
};
use language::{Buffer, CodeLabel, LanguageRegistry, Point, ToOffset, language_settings};
use language::{
Anchor, Bias, Buffer, BufferSnapshot, CodeLabel, LanguageRegistry, Point, ToOffset, ToPoint,
language_settings::{self, InlayHintKind},
};
use markdown::{Markdown, MarkdownStyle};
use project::{
Completion, CompletionDisplayOptions, CompletionResponse, CompletionSource, InlayId,
Completion, CompletionDisplayOptions, CompletionResponse, CompletionSource, InlayHint,
InlayHintLabel, InlayId, ResolveState,
};
use settings::Settings as _;
use std::rc::Rc;
use std::{fmt::Write, sync::Arc};
use std::{fmt::Write, ops::Range, sync::Arc};
use theme_settings::ThemeSettings;
use ui::{
ContextMenu, DropdownMenu, KeyBinding, List, ListItem, ListItemSpacing, PopoverMenuHandle,
@ -62,6 +66,11 @@ pub struct RatePredictionsModal {
struct ActivePrediction {
prediction: EditPrediction,
feedback_editor: Entity<Editor>,
expected_buffer: Entity<Buffer>,
expected_editable_range: Option<Range<Anchor>>,
expected_editor: Entity<Editor>,
expected_diff_editor: Entity<Editor>,
expected_patch_preview: bool,
formatted_inputs: Entity<Markdown>,
}
@ -204,6 +213,7 @@ impl RatePredictionsModal {
&active.prediction,
EditPredictionRating::Positive,
active.feedback_editor.read(cx).text(cx),
self.expected_patch_for_active(cx),
cx,
);
}
@ -236,6 +246,7 @@ impl RatePredictionsModal {
&active.prediction,
EditPredictionRating::Negative,
active.feedback_editor.read(cx).text(cx),
self.expected_patch_for_active(cx),
cx,
);
});
@ -293,6 +304,145 @@ impl RatePredictionsModal {
self.select_completion(completion, true, window, cx);
}
fn update_diff_editor(
diff_editor: &Entity<Editor>,
new_buffer: Entity<Buffer>,
old_buffer_snapshot: BufferSnapshot,
visible_range: Range<Point>,
cx: &mut Context<Self>,
) {
diff_editor.update(cx, |editor, cx| {
let new_buffer_snapshot = new_buffer.read(cx).snapshot();
let new_buffer_id = new_buffer_snapshot.remote_id();
let language = new_buffer_snapshot.language().cloned();
let diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot.text, cx));
diff.update(cx, |diff, cx| {
let update = diff.update_diff(
new_buffer_snapshot.text.clone(),
Some(old_buffer_snapshot.text().into()),
Some(true),
language,
cx,
);
cx.spawn(async move |diff, cx| {
let update = update.await;
if let Some(task) = diff
.update(cx, |diff, cx| {
diff.set_snapshot(update, &new_buffer_snapshot.text, cx)
})
.ok()
{
task.await;
}
})
.detach();
});
editor.disable_header_for_buffer(new_buffer_id, cx);
editor.buffer().update(cx, |multibuffer, cx| {
multibuffer.clear(cx);
multibuffer.set_excerpts_for_buffer(new_buffer, [visible_range], 0, cx);
multibuffer.add_diff(diff, cx);
});
});
}
fn editable_range_for_prediction(prediction: &EditPrediction) -> Option<Range<Anchor>> {
prediction
.editable_range
.clone()
.or_else(|| Some(prediction.edits.first()?.0.start..prediction.edits.last()?.0.end))
}
fn insert_editable_region_markers(
editor: &Entity<Editor>,
buffer: &Entity<Buffer>,
marker_range: Range<usize>,
cx: &mut Context<Self>,
) {
editor.update(cx, |editor, cx| {
let buffer_snapshot = buffer.read(cx).snapshot();
let multibuffer_snapshot = editor.buffer().read(cx).snapshot(cx);
let start_buffer_anchor = buffer_snapshot
.anchor_after(buffer_snapshot.clip_offset(marker_range.start, Bias::Left));
let end_buffer_anchor = buffer_snapshot
.anchor_after(buffer_snapshot.clip_offset(marker_range.end, Bias::Right));
let Some(start_anchor) = multibuffer_snapshot.anchor_in_excerpt(start_buffer_anchor)
else {
return;
};
let Some(end_anchor) = multibuffer_snapshot.anchor_in_excerpt(end_buffer_anchor) else {
return;
};
let Some((start_hint_position, _)) =
multibuffer_snapshot.anchor_to_buffer_anchor(start_anchor)
else {
return;
};
let Some((end_hint_position, _)) =
multibuffer_snapshot.anchor_to_buffer_anchor(end_anchor)
else {
return;
};
editor.splice_inlays(
&[InlayId::Hint(0), InlayId::Hint(1)],
vec![
Inlay::hint(
InlayId::Hint(0),
start_anchor,
&InlayHint {
position: start_hint_position,
label: InlayHintLabel::String("╭─ editable region start\n".into()),
kind: Some(InlayHintKind::Parameter),
padding_left: false,
padding_right: false,
tooltip: None,
resolve_state: ResolveState::Resolved,
},
),
Inlay::hint(
InlayId::Hint(1),
end_anchor,
&InlayHint {
position: end_hint_position,
label: InlayHintLabel::String("\n╰─ editable region end".into()),
kind: Some(InlayHintKind::Parameter),
padding_left: false,
padding_right: false,
tooltip: None,
resolve_state: ResolveState::Resolved,
},
),
],
cx,
);
});
}
fn expected_patch_for_active(&self, cx: &App) -> Option<String> {
let active_prediction = self.active_prediction.as_ref()?;
let expected_text = active_prediction.expected_buffer.read(cx).snapshot().text();
let original_text = active_prediction.prediction.snapshot.text();
let diff_body = language::unified_diff(&original_text, &expected_text);
if diff_body.is_empty() {
return None;
}
let path = active_prediction
.prediction
.snapshot
.file()
.map(|file| file.path().as_unix_str());
let header = match path {
Some(path) => format!("--- a/{path}\n+++ b/{path}\n"),
None => String::new(),
};
Some(format!("{header}{diff_body}"))
}
pub fn select_completion(
&mut self,
prediction: Option<EditPrediction>,
@ -321,57 +471,49 @@ impl RatePredictionsModal {
return;
}
let editable_range = Self::editable_range_for_prediction(&prediction);
let predicted_buffer = prediction.edit_preview.build_result_buffer(cx);
let predicted_buffer_snapshot = predicted_buffer.read(cx).snapshot();
let visible_range = prediction
.edit_preview
.compute_visible_range(&prediction.edits)
.unwrap_or(Point::zero()..Point::zero());
let start = Point::new(visible_range.start.row.saturating_sub(5), 0);
let end =
Point::new(visible_range.end.row + 5, 0).min(predicted_buffer_snapshot.max_point());
Self::update_diff_editor(
&self.diff_editor,
predicted_buffer.clone(),
prediction.snapshot.clone(),
start..end,
cx,
);
if let Some(editable_range) = editable_range.as_ref() {
Self::insert_editable_region_markers(
&self.diff_editor,
&predicted_buffer,
prediction
.edit_preview
.anchor_to_offset_in_result(editable_range.start)
..prediction
.edit_preview
.anchor_to_offset_in_result(editable_range.end),
cx,
);
}
self.diff_editor.update(cx, |editor, cx| {
let new_buffer = prediction.edit_preview.build_result_buffer(cx);
let new_buffer_snapshot = new_buffer.read(cx).snapshot();
let old_buffer_snapshot = prediction.snapshot.clone();
let new_buffer_id = new_buffer_snapshot.remote_id();
let range = prediction
.edit_preview
.compute_visible_range(&prediction.edits)
.unwrap_or(Point::zero()..Point::zero());
let start = Point::new(range.start.row.saturating_sub(5), 0);
let end = Point::new(range.end.row + 5, 0).min(new_buffer_snapshot.max_point());
let language = new_buffer_snapshot.language().cloned();
let diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot.text, cx));
diff.update(cx, |diff, cx| {
let update = diff.update_diff(
new_buffer_snapshot.text.clone(),
Some(old_buffer_snapshot.text().into()),
Some(true),
language,
cx,
);
cx.spawn(async move |diff, cx| {
let update = update.await;
if let Some(task) = diff
.update(cx, |diff, cx| {
diff.set_snapshot(update, &new_buffer_snapshot.text, cx)
})
.ok()
{
task.await;
}
})
.detach();
});
editor.disable_header_for_buffer(new_buffer_id, cx);
editor.buffer().update(cx, |multibuffer, cx| {
multibuffer.clear(cx);
multibuffer.set_excerpts_for_buffer(new_buffer.clone(), [start..end], 0, cx);
multibuffer.add_diff(diff, cx);
});
if let Some(cursor_position) = prediction.cursor_position.as_ref() {
let multibuffer_snapshot = editor.buffer().read(cx).snapshot(cx);
let cursor_offset = prediction
.edit_preview
.anchor_to_offset_in_result(cursor_position.anchor)
+ cursor_position.offset;
let cursor_anchor = new_buffer.read(cx).snapshot().anchor_after(cursor_offset);
let predicted_buffer_snapshot = predicted_buffer.read(cx).snapshot();
let cursor_anchor = predicted_buffer_snapshot.anchor_after(
predicted_buffer_snapshot.clip_offset(cursor_offset, Bias::Right),
);
if let Some(anchor) = multibuffer_snapshot.anchor_in_excerpt(cursor_anchor) {
editor.splice_inlays(
@ -422,15 +564,111 @@ impl RatePredictionsModal {
write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
let mut cursor_offset = prediction
.inputs
.cursor_offset_in_excerpt
.min(prediction.inputs.cursor_excerpt.len());
while !prediction
.inputs
.cursor_excerpt
.is_char_boundary(cursor_offset)
{
cursor_offset = cursor_offset.saturating_sub(1);
}
writeln!(
&mut formatted_inputs,
"```{}\n{}<CURSOR>{}\n```\n",
prediction.inputs.cursor_path.display(),
&prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
&prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
&prediction.inputs.cursor_excerpt[..cursor_offset],
&prediction.inputs.cursor_excerpt[cursor_offset..],
)
.unwrap();
let current_editable_region = editable_range.as_ref().map(|range| {
prediction
.buffer
.read(cx)
.snapshot()
.text_for_range(range.clone())
.collect::<String>()
});
let expected_buffer = cx.new(|cx| {
let mut buffer = Buffer::local(prediction.snapshot.text(), cx);
buffer.set_language_async(prediction.snapshot.language().cloned(), cx);
buffer
});
let expected_editable_range = editable_range.as_ref().map(|editable_range| {
expected_buffer.update(cx, |buffer, cx| {
let snapshot = buffer.snapshot();
let editable_point_range = editable_range.start.to_point(&prediction.snapshot)
..editable_range.end.to_point(&prediction.snapshot);
let expected_editable_range = snapshot.anchor_before(editable_point_range.start)
..snapshot.anchor_after(editable_point_range.end);
if let Some(current_editable_region) = current_editable_region {
buffer.edit(
[(expected_editable_range.clone(), current_editable_region)],
None,
cx,
);
}
expected_editable_range
})
});
let expected_buffer_snapshot = expected_buffer.read(cx).snapshot();
let expected_excerpt_range = expected_editable_range
.as_ref()
.map(|range| {
range.start.to_point(&expected_buffer_snapshot)
..range.end.to_point(&expected_buffer_snapshot)
})
.unwrap_or_else(|| visible_range.clone());
let expected_editor = cx.new(|cx| {
let multibuffer = cx.new(|cx| {
let mut multibuffer = MultiBuffer::new(language::Capability::ReadWrite);
multibuffer.set_excerpts_for_buffer(
expected_buffer.clone(),
[expected_excerpt_range],
0,
cx,
);
multibuffer
});
let mut editor = Editor::for_multibuffer(multibuffer, None, window, cx);
let expected_buffer_id = expected_buffer.read(cx).remote_id();
editor.disable_header_for_buffer(expected_buffer_id, cx);
editor.disable_inline_diagnostics();
editor.set_show_git_diff_gutter(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_runnables(false, cx);
editor.set_show_bookmarks(false, cx);
editor.set_show_breakpoints(false, cx);
editor.set_show_wrap_guides(false, cx);
editor.set_show_edit_predictions(Some(false), window, cx);
editor
});
let expected_diff_editor = cx.new(|cx| {
let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly));
let mut editor = Editor::for_multibuffer(multibuffer, None, window, cx);
editor.disable_inline_diagnostics();
editor.set_expand_all_diff_hunks(cx);
editor.set_show_git_diff_gutter(false, cx);
editor
});
if let Some(expected_editable_range) = expected_editable_range.as_ref() {
let expected_buffer_snapshot = expected_buffer.read(cx).snapshot();
Self::insert_editable_region_markers(
&expected_editor,
&expected_buffer,
expected_editable_range
.start
.to_offset(&expected_buffer_snapshot)
..expected_editable_range
.end
.to_offset(&expected_buffer_snapshot),
cx,
);
}
self.active_prediction = Some(ActivePrediction {
prediction,
feedback_editor: cx.new(|cx| {
@ -453,6 +691,11 @@ impl RatePredictionsModal {
}
editor
}),
expected_buffer,
expected_editable_range,
expected_editor,
expected_diff_editor,
expected_patch_preview: false,
formatted_inputs: cx.new(|cx| {
Markdown::new(
formatted_inputs.into(),
@ -503,17 +746,136 @@ impl RatePredictionsModal {
)
}
fn toggle_expected_patch_preview(&mut self, cx: &mut Context<Self>) {
if let Some(active_prediction) = &mut self.active_prediction {
if active_prediction.expected_patch_preview {
active_prediction.expected_patch_preview = false;
} else {
let expected_buffer_snapshot =
active_prediction.expected_buffer.read(cx).snapshot();
let visible_range = active_prediction
.prediction
.edit_preview
.compute_visible_range(&active_prediction.prediction.edits)
.unwrap_or(Point::zero()..Point::zero());
let start = Point::new(visible_range.start.row.saturating_sub(5), 0);
let end = Point::new(visible_range.end.row + 5, 0)
.min(expected_buffer_snapshot.max_point());
Self::update_diff_editor(
&active_prediction.expected_diff_editor,
active_prediction.expected_buffer.clone(),
active_prediction.prediction.snapshot.clone(),
start..end,
cx,
);
if let Some(expected_editable_range) =
active_prediction.expected_editable_range.as_ref()
{
let expected_buffer_snapshot =
active_prediction.expected_buffer.read(cx).snapshot();
Self::insert_editable_region_markers(
&active_prediction.expected_diff_editor,
&active_prediction.expected_buffer,
expected_editable_range
.start
.to_offset(&expected_buffer_snapshot)
..expected_editable_range
.end
.to_offset(&expected_buffer_snapshot),
cx,
);
}
active_prediction.expected_patch_preview = true;
}
cx.notify();
}
}
fn render_suggested_edits(&self, cx: &mut Context<Self>) -> Option<gpui::Stateful<Div>> {
let bg_color = cx.theme().colors().editor_background;
let border_color = cx.theme().colors().border;
let active_prediction = self.active_prediction.as_ref()?;
let expected_patch_preview = active_prediction.expected_patch_preview;
Some(
div()
v_flex()
.id("diff")
.p_4()
.size_full()
.bg(bg_color)
.overflow_scroll()
.whitespace_nowrap()
.child(self.diff_editor.clone()),
.overflow_hidden()
.child(
v_flex()
.flex_1()
.min_h_0()
.child(
h_flex()
.h_8()
.px_2()
.border_b_1()
.border_color(border_color)
.child(Label::new("Predicted Patch").size(LabelSize::Small)),
)
.child(
div()
.id("predicted-patch-diff")
.p_4()
.flex_1()
.min_h_0()
.overflow_scroll()
.whitespace_nowrap()
.child(self.diff_editor.clone()),
),
)
.child(
v_flex()
.flex_1()
.min_h_0()
.border_t_1()
.border_color(border_color)
.child(
h_flex()
.h_8()
.px_2()
.gap_2()
.border_b_1()
.border_color(border_color)
.child(
Button::new(
"expected-patch-preview",
if expected_patch_preview {
"Edit"
} else {
"Preview"
},
)
.label_size(LabelSize::Small)
.on_click(cx.listener(
|this, _, _window, cx| {
this.toggle_expected_patch_preview(cx);
},
)),
)
.child(Label::new("Expected Patch").size(LabelSize::Small)),
)
.child(
div()
.id("expected-patch")
.p_4()
.flex_1()
.min_h_0()
.overflow_scroll()
.whitespace_nowrap()
.child(if expected_patch_preview {
active_prediction
.expected_diff_editor
.clone()
.into_any_element()
} else {
active_prediction.expected_editor.clone().into_any_element()
}),
),
),
)
}