Track additional metrics in settled (#52938)

Stacked on https://github.com/zed-industries/zed/pull/50566.

Begin collecting kept chars rate, as well as the count of tree-sitter
errors in the code before and after applying the prediction.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [ ] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- N/A or Added/Fixed/Improved ...
This commit is contained in:
Ben Kunkle 2026-04-08 17:39:17 -05:00 committed by GitHub
parent 364ebfcc07
commit 7597666c08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 905 additions and 129 deletions

4
Cargo.lock generated
View file

@ -5176,6 +5176,7 @@ dependencies = [
"copilot",
"copilot_ui",
"credentials_provider",
"criterion",
"ctor",
"db",
"edit_prediction_context",
@ -5189,9 +5190,11 @@ dependencies = [
"itertools 0.14.0",
"language",
"language_model",
"languages",
"log",
"lsp",
"menu",
"node_runtime",
"open_ai",
"parking_lot",
"postage",
@ -5235,7 +5238,6 @@ dependencies = [
"client",
"cloud_llm_client",
"collections",
"criterion",
"db",
"debug_adapter_extension",
"dirs 4.0.0",

View file

@ -72,10 +72,14 @@ zeta_prompt.workspace = true
zstd.workspace = true
[dev-dependencies]
criterion.workspace = true
fs = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
languages = { workspace = true, features = ["load-grammars"] }
node_runtime.workspace = true
clock = { workspace = true, features = ["test-support"] }
cloud_llm_client = { workspace = true, features = ["test-support"] }
ctor.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
@ -86,3 +90,11 @@ settings = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
zlog.workspace = true
[[bench]]
name = "kept_rate"
harness = false
[[bench]]
name = "ts_error_count"
harness = false

View file

@ -1,5 +1,5 @@
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use edit_prediction_cli::kept_rate::compute_kept_rate;
use edit_prediction::metrics::compute_kept_rate;
fn repeated_function_lines(line_count: usize) -> String {
let mut text = String::with_capacity(line_count * 32);

View file

@ -0,0 +1,454 @@
use std::sync::Arc;
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use edit_prediction::metrics::count_tree_sitter_errors;
use fs::FakeFs;
use gpui::{AppContext as _, TestAppContext};
use language::{Buffer, BufferSnapshot, LanguageRegistry};
use languages::init as init_languages;
use node_runtime::NodeRuntime;
use settings::SettingsStore;
struct ParsedCase {
label: String,
bytes: usize,
error_count: usize,
snapshot: BufferSnapshot,
}
fn replace_nth_occurrences(
source: &mut String,
needle: &str,
replacement: &str,
every: usize,
max_replacements: usize,
) {
let mut rebuilt = String::with_capacity(source.len());
let mut cursor = 0;
let mut seen = 0;
let mut replaced = 0;
while let Some(relative_index) = source[cursor..].find(needle) {
let start = cursor + relative_index;
let end = start + needle.len();
rebuilt.push_str(&source[cursor..start]);
if seen % every == 0 && replaced < max_replacements {
rebuilt.push_str(replacement);
replaced += 1;
} else {
rebuilt.push_str(needle);
}
seen += 1;
cursor = end;
}
rebuilt.push_str(&source[cursor..]);
*source = rebuilt;
}
fn rust_source(function_count: usize) -> String {
let mut source = String::from(
"pub struct Counter {\n value: usize,\n}\n\nimpl Counter {\n pub fn new() -> Self {\n Self { value: 0 }\n }\n}\n\n",
);
for index in 0..function_count {
source.push_str(&format!(
"pub fn compute_value_{index}(input: usize) -> usize {{\n let mut total = input;\n for offset in 0..32 {{\n total += offset + {index};\n }}\n if total % 2 == 0 {{\n total / 2\n }} else {{\n total * 3 + 1\n }}\n}}\n\n"
));
}
source
}
fn rust_source_with_errors(function_count: usize) -> String {
let mut source = rust_source(function_count);
replace_nth_occurrences(
&mut source,
" if total % 2 == 0 {\n",
" if total % 2 == 0 \n",
17,
48,
);
source
}
fn python_source(function_count: usize) -> String {
let mut source = String::from(
"class Counter:\n def __init__(self) -> None:\n self.value = 0\n\n\n",
);
for index in 0..function_count {
source.push_str(&format!(
"def compute_value_{index}(input_value: int) -> int:\n total = input_value\n for offset in range(32):\n total += offset + {index}\n if total % 2 == 0:\n return total // 2\n return total * 3 + 1\n\n"
));
}
source
}
fn python_source_with_errors(function_count: usize) -> String {
let mut source = python_source(function_count);
replace_nth_occurrences(
&mut source,
" if total % 2 == 0:\n",
" if total % 2 == 0\n",
19,
48,
);
source
}
fn go_source(function_count: usize) -> String {
let mut source = String::from(
"package bench\n\ntype Counter struct {\n\tvalue int\n}\n\nfunc NewCounter() Counter {\n\treturn Counter{value: 0}\n}\n\n",
);
for index in 0..function_count {
source.push_str(&format!(
"func ComputeValue{index}(inputValue int) int {{\n\ttotal := inputValue\n\tfor offset := 0; offset < 32; offset++ {{\n\t\ttotal += offset + {index}\n\t}}\n\tif total%2 == 0 {{\n\t\treturn total / 2\n\t}}\n\treturn total*3 + 1\n}}\n\n"
));
}
source
}
fn go_source_with_errors(function_count: usize) -> String {
let mut source = go_source(function_count);
replace_nth_occurrences(
&mut source,
"\tfor offset := 0; offset < 32; offset++ {\n",
"\tfor offset := 0; offset < 32; offset++ \n",
17,
48,
);
source
}
fn typescript_source(function_count: usize) -> String {
let mut source = String::from(
"export type Counter = { value: number };\n\nexport function newCounter(): Counter {\n return { value: 0 };\n}\n\n",
);
for index in 0..function_count {
source.push_str(&format!(
"export function computeValue{index}(inputValue: number): number {{\n let total = inputValue;\n for (let offset = 0; offset < 32; offset += 1) {{\n total += offset + {index};\n }}\n return total % 2 === 0 ? total / 2 : total * 3 + 1;\n}}\n\n"
));
}
source
}
fn typescript_source_with_errors(function_count: usize) -> String {
let mut source = typescript_source(function_count);
replace_nth_occurrences(
&mut source,
" return total % 2 === 0 ? total / 2 : total * 3 + 1;\n",
" return total % 2 === 0 ? total / 2 : ;\n",
17,
64,
);
source
}
fn tsx_source(component_count: usize) -> String {
let mut source = String::from(
"type ItemProps = { index: number; label: string };\n\nfunction Item({ index, label }: ItemProps) {\n return <li data-index={index}>{label}</li>;\n}\n\nexport function App() {\n return <section><ul>{[0, 1, 2].map((value) => <Item key={value} index={value} label={`item-${value}`} />)}</ul></section>;\n}\n\n",
);
for index in 0..component_count {
source.push_str(&format!(
"export function Widget{index}(): JSX.Element {{\n const items = Array.from({{ length: 16 }}, (_, value) => value + {index});\n return (\n <div className=\"widget-{index}\">\n <h2>Widget {index}</h2>\n <ul>\n {{items.map((value) => (\n <Item key={{value}} index={{value}} label={{`widget-{index}-${{value}}`}} />\n ))}}\n </ul>\n </div>\n );\n}}\n\n"
));
}
source
}
fn tsx_source_with_errors(component_count: usize) -> String {
let mut source = tsx_source(component_count);
replace_nth_occurrences(
&mut source,
" const items = Array.from({ length: 16 }, (_, value) => value + ",
" const items = Array.from({ length: 16 }, (_, value) => ); // ",
11,
32,
);
source
}
fn json_source(object_count: usize) -> String {
let mut source = String::from("{\n \"items\": [\n");
for index in 0..object_count {
let suffix = if index + 1 == object_count { "" } else { "," };
source.push_str(&format!(
" {{\n \"id\": {index},\n \"name\": \"item-{index}\",\n \"enabled\": true,\n \"tags\": [\"alpha\", \"beta\", \"gamma\"],\n \"metrics\": {{ \"count\": {}, \"ratio\": {} }}\n }}{suffix}\n",
index * 3 + 1,
index as f64 / 10.0,
));
}
source.push_str(" ]\n}\n");
source
}
fn json_source_with_errors(object_count: usize) -> String {
let mut source = json_source(object_count);
replace_nth_occurrences(
&mut source,
" \"enabled\": true,\n",
" \"enabled\": ,\n",
23,
64,
);
source
}
fn yaml_source(document_count: usize) -> String {
let mut source = String::new();
for index in 0..document_count {
source.push_str(&format!(
"- id: {index}\n name: item-{index}\n enabled: true\n tags:\n - alpha\n - beta\n - gamma\n metrics:\n count: {}\n ratio: {}\n",
index * 3 + 1,
index as f64 / 10.0,
));
}
source
}
fn yaml_source_with_errors(document_count: usize) -> String {
let mut source = yaml_source(document_count);
replace_nth_occurrences(&mut source, " count: ", " count ", 23, 64);
source
}
fn css_source(rule_count: usize) -> String {
let mut source = String::new();
for index in 0..rule_count {
source.push_str(&format!(
".widget-{index} {{\n display: grid;\n grid-template-columns: repeat(4, minmax(0, 1fr));\n gap: 12px;\n padding: 8px;\n color: rgb({}, {}, {});\n}}\n\n.widget-{index} > .item-{index} {{\n border: 1px solid rgba(0, 0, 0, 0.15);\n background: linear-gradient(90deg, #fff, #eef);\n}}\n\n",
(index * 17) % 255,
(index * 31) % 255,
(index * 47) % 255,
));
}
source
}
fn css_source_with_errors(rule_count: usize) -> String {
let mut source = css_source(rule_count);
replace_nth_occurrences(&mut source, " gap: 12px;\n", " gap 12px;\n", 29, 64);
source
}
fn build_case(
context: &mut TestAppContext,
languages: &Arc<LanguageRegistry>,
language_name: &'static str,
variant_name: &'static str,
source: String,
expect_errors: bool,
) -> ParsedCase {
let language_task = context.background_spawn({
let languages = languages.clone();
async move { languages.language_for_name(language_name).await }
});
while !language_task.is_ready() {
context.run_until_parked();
}
let language = futures::executor::block_on(language_task)
.unwrap_or_else(|error| panic!("failed to load {language_name}: {error}"));
let buffer = context.new(|cx| Buffer::local(source, cx).with_language(language, cx));
context.run_until_parked();
while buffer.read_with(context, |buffer, _| buffer.is_parsing()) {
context.run_until_parked();
}
let snapshot = buffer.read_with(context, |buffer, _| buffer.snapshot());
let full_range = 0..snapshot.text.len();
let error_count = count_tree_sitter_errors(snapshot.syntax_layers());
if expect_errors {
assert!(
error_count > 0,
"expected tree-sitter errors for {language_name}/{variant_name}",
);
} else {
assert_eq!(
error_count, 0,
"expected no tree-sitter errors for {language_name}/{variant_name}",
);
}
let label = format!(
"{}/{}_{}kb_{}e",
language_name.to_lowercase(),
variant_name,
full_range.end / 1024,
error_count,
);
ParsedCase {
label,
bytes: full_range.end,
error_count,
snapshot,
}
}
fn parsed_cases() -> Vec<ParsedCase> {
let mut context = TestAppContext::single();
context.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
});
let languages = Arc::new(LanguageRegistry::new(context.executor()));
let fs = FakeFs::new(context.executor());
let node_runtime = NodeRuntime::unavailable();
context.update(|cx| init_languages(languages.clone(), fs, node_runtime, cx));
vec![
build_case(
&mut context,
&languages,
"Rust",
"valid",
rust_source(900),
false,
),
build_case(
&mut context,
&languages,
"Rust",
"error_heavy",
rust_source_with_errors(900),
true,
),
build_case(
&mut context,
&languages,
"Python",
"valid",
python_source(1100),
false,
),
build_case(
&mut context,
&languages,
"Python",
"error_heavy",
python_source_with_errors(1100),
true,
),
build_case(
&mut context,
&languages,
"Go",
"valid",
go_source(1000),
false,
),
build_case(
&mut context,
&languages,
"Go",
"error_heavy",
go_source_with_errors(1000),
true,
),
build_case(
&mut context,
&languages,
"TypeScript",
"valid",
typescript_source(1000),
false,
),
build_case(
&mut context,
&languages,
"TypeScript",
"error_heavy",
typescript_source_with_errors(1000),
true,
),
build_case(
&mut context,
&languages,
"TSX",
"valid",
tsx_source(350),
false,
),
build_case(
&mut context,
&languages,
"TSX",
"error_heavy",
tsx_source_with_errors(350),
true,
),
build_case(
&mut context,
&languages,
"JSON",
"valid",
json_source(2200),
false,
),
build_case(
&mut context,
&languages,
"JSON",
"error_heavy",
json_source_with_errors(2200),
true,
),
build_case(
&mut context,
&languages,
"YAML",
"valid",
yaml_source(2200),
false,
),
build_case(
&mut context,
&languages,
"YAML",
"error_heavy",
yaml_source_with_errors(2200),
true,
),
build_case(
&mut context,
&languages,
"CSS",
"valid",
css_source(2400),
false,
),
build_case(
&mut context,
&languages,
"CSS",
"error_heavy",
css_source_with_errors(2400),
true,
),
]
}
fn ts_error_count_benchmark(c: &mut Criterion) {
let cases = parsed_cases();
let mut group = c.benchmark_group("ts_error_count/full_file");
for case in &cases {
group.bench_with_input(
BenchmarkId::from_parameter(&case.label),
case,
|bench, case| {
bench.iter(|| {
black_box(case.bytes);
black_box(case.error_count);
black_box(count_tree_sitter_errors(case.snapshot.syntax_layers()))
});
},
);
}
group.finish();
}
criterion_group!(benches, ts_error_count_benchmark);
criterion_main!(benches);

View file

@ -30,7 +30,7 @@ use gpui::{
};
use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
@ -61,6 +61,7 @@ pub mod example_spec;
pub mod fim;
mod license_detection;
pub mod mercury;
pub mod metrics;
pub mod ollama;
mod onboarding_modal;
pub mod open_ai_response;
@ -80,6 +81,7 @@ use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
use crate::example_spec::ExampleSpec;
use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
pub use crate::metrics::{KeptRateResult, compute_kept_rate};
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
@ -478,10 +480,13 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
}
#[derive(Clone)]
struct PendingSettledPrediction {
request_id: EditPredictionId,
editable_anchor_range: Range<Anchor>,
editable_region_before_prediction: String,
predicted_editable_region: String,
ts_error_count_before_prediction: usize,
ts_error_count_after_prediction: usize,
example: Option<ExampleSpec>,
enqueued_at: Instant,
last_edit_at: Instant,
@ -1603,63 +1608,100 @@ impl EditPredictionStore {
};
let now = cx.background_executor().now();
let mut oldest_edited_at = None;
let mut ready_predictions = Vec::new();
this.update(cx, |this, _| {
for (_, project_state) in this.projects.iter_mut() {
for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
registered_buffer
.pending_predictions
.retain_mut(|pending_prediction| {
let age =
now.saturating_duration_since(pending_prediction.enqueued_at);
let mut pending_index = 0;
while pending_index < registered_buffer.pending_predictions.len() {
let pending_prediction =
&registered_buffer.pending_predictions[pending_index];
let age = now.saturating_duration_since(pending_prediction.enqueued_at);
if age >= EDIT_PREDICTION_SETTLED_TTL {
return false;
registered_buffer.pending_predictions.remove(pending_index);
continue;
}
let quiet_for =
now.saturating_duration_since(pending_prediction.last_edit_at);
if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
let pending_prediction =
registered_buffer.pending_predictions.remove(pending_index);
let settled_editable_region = registered_buffer
.snapshot
.text_for_range(
pending_prediction.editable_anchor_range.clone(),
)
.collect::<String>();
ready_predictions
.push((pending_prediction, settled_editable_region));
continue;
}
if oldest_edited_at
.is_none_or(|time| pending_prediction.last_edit_at < time)
{
oldest_edited_at = Some(pending_prediction.last_edit_at);
}
pending_index += 1;
}
}
}
});
for (pending_prediction, settled_editable_region) in ready_predictions {
let PendingSettledPrediction {
request_id,
editable_region_before_prediction,
predicted_editable_region,
ts_error_count_before_prediction,
ts_error_count_after_prediction,
example,
e2e_latency,
..
} = pending_prediction;
let settled_editable_region_for_metrics = settled_editable_region.clone();
let kept_rate_result = cx
.background_spawn(async move {
compute_kept_rate(
&editable_region_before_prediction,
&predicted_editable_region,
&settled_editable_region_for_metrics,
)
})
.await;
#[cfg(test)]
{
let request_id = request_id.clone();
let settled_editable_region = settled_editable_region.clone();
this.update(cx, |this, _| {
if let Some(callback) = &this.settled_event_callback {
callback(
pending_prediction.request_id.clone(),
settled_editable_region.clone(),
);
callback(request_id, settled_editable_region);
}
});
}
telemetry::event!(
EDIT_PREDICTION_SETTLED_EVENT,
request_id = pending_prediction.request_id.0.clone(),
request_id = request_id.0.clone(),
settled_editable_region,
example = pending_prediction.example.take(),
e2e_latency = pending_prediction.e2e_latency.as_millis(),
ts_error_count_before_prediction,
ts_error_count_after_prediction,
edit_bytes_predicted_new = kept_rate_result.predicted_new_chars,
edit_bytes_final_new = kept_rate_result.final_new_chars,
edit_bytes_kept = kept_rate_result.kept_chars,
edit_bytes_discarded = kept_rate_result.discarded_chars,
edit_bytes_context = kept_rate_result.context_chars,
edit_bytes_kept_rate = kept_rate_result.kept_rate,
example,
e2e_latency = e2e_latency.as_millis(),
);
return false;
}
if oldest_edited_at
.is_none_or(|t| pending_prediction.last_edit_at < t)
{
oldest_edited_at = Some(pending_prediction.last_edit_at);
}
true
});
}
}
});
next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
}
}
@ -1670,21 +1712,52 @@ impl EditPredictionStore {
edited_buffer: &Entity<Buffer>,
edited_buffer_snapshot: &BufferSnapshot,
editable_offset_range: Range<usize>,
edit_preview: &EditPreview,
example: Option<ExampleSpec>,
e2e_latency: std::time::Duration,
cx: &mut Context<Self>,
) {
let this = &mut *self;
let project_state = this.get_or_init_project(project, cx);
if let Some(buffer) = project_state
let Some(registered_buffer) = project_state
.registered_buffers
.get_mut(&edited_buffer.entity_id())
{
else {
return;
};
let editable_region_before_prediction = edited_buffer_snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
let editable_anchor_range_for_result =
edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
let predicted_editable_region = edit_preview
.result_text_snapshot()
.text_for_range(editable_anchor_range_for_result.clone())
.collect();
let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
edited_buffer_snapshot
.syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
);
let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
edit_preview.result_syntax_snapshot().layers_for_range(
editable_anchor_range_for_result,
edit_preview.result_text_snapshot(),
true,
),
);
let editable_anchor_range =
edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
let now = cx.background_executor().now();
buffer.pending_predictions.push(PendingSettledPrediction {
request_id: request_id,
editable_anchor_range: edited_buffer_snapshot
.anchor_range_inside(editable_offset_range),
registered_buffer
.pending_predictions
.push(PendingSettledPrediction {
request_id,
editable_anchor_range,
editable_region_before_prediction,
predicted_editable_region,
ts_error_count_before_prediction,
ts_error_count_after_prediction,
example,
e2e_latency,
enqueued_at: now,
@ -1692,7 +1765,6 @@ impl EditPredictionStore {
});
this.settled_predictions_tx.unbounded_send(now).ok();
}
}
fn reject_current_prediction(
&mut self,

View file

@ -3252,6 +3252,12 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
cx.run_until_parked();
let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let empty_edits: Arc<[(Range<Anchor>, Arc<str>)]> = Vec::new().into();
let edit_preview_a = buffer
.read_with(cx, |buffer, cx| {
buffer.preview_edits(empty_edits.clone(), cx)
})
.await;
// Region A: first 10 lines of the buffer.
let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
@ -3263,6 +3269,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
&buffer,
&snapshot_a,
editable_region_a.clone(),
&edit_preview_a,
None,
Duration::from_secs(0),
cx,
@ -3318,6 +3325,9 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
cx.run_until_parked();
let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let edit_preview_b = buffer
.read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx))
.await;
let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
ep_store.update(cx, |ep_store, cx| {
@ -3327,6 +3337,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
&buffer,
&snapshot_b2,
editable_region_b.clone(),
&edit_preview_b,
None,
Duration::from_secs(0),
cx,

View file

@ -0,0 +1,10 @@
mod kept_rate;
mod tokenize;
mod tree_sitter;
pub use kept_rate::KeptRateResult;
#[cfg(test)]
pub use kept_rate::TokenAnnotation;
pub use kept_rate::compute_kept_rate;
pub(crate) use tokenize::tokenize;
pub use tree_sitter::count_tree_sitter_errors;

View file

@ -1,4 +1,6 @@
use crate::word_diff::tokenize;
use crate::metrics::tokenize;
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -25,19 +27,27 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
row * width + column
}
/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
/// side while sharing a single DP table construction.
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
fn fill_lcs_keep_masks(
a: &[&str],
b: &[&str],
mut keep_a: Option<&mut [bool]>,
mut keep_b: Option<&mut [bool]>,
) {
if a.is_empty() || b.is_empty() {
return (vec![false; a.len()], vec![false; b.len()]);
return;
}
if a == b {
return (vec![true; a.len()], vec![true; b.len()]);
if let Some(keep_a) = keep_a.as_mut() {
keep_a.fill(true);
}
if let Some(keep_b) = keep_b.as_mut() {
keep_b.fill(true);
}
return;
}
let mut keep_a = vec![false; a.len()];
let mut keep_b = vec![false; b.len()];
let prefix_len = a
.iter()
@ -61,22 +71,30 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
};
for index in 0..prefix_len {
if let Some(keep_a) = keep_a.as_mut() {
keep_a[index] = true;
}
if let Some(keep_b) = keep_b.as_mut() {
keep_b[index] = true;
}
}
for offset in 0..suffix_len {
let a_index = a.len() - suffix_len + offset;
let b_index = b.len() - suffix_len + offset;
if let Some(keep_a) = keep_a.as_mut() {
keep_a[a_index] = true;
}
if let Some(keep_b) = keep_b.as_mut() {
keep_b[b_index] = true;
}
}
let a_mid = &a[prefix_len..a.len() - suffix_len];
let b_mid = &b[prefix_len..b.len() - suffix_len];
if a_mid.is_empty() || b_mid.is_empty() {
return (keep_a, keep_b);
return;
}
let row_count = a_mid.len() + 1;
@ -97,6 +115,7 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
}
}
if let Some(keep_a) = keep_a.as_mut() {
let mut i = a_mid.len();
let mut j = b_mid.len();
@ -115,7 +134,9 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
}
}
}
}
if let Some(keep_b) = keep_b.as_mut() {
let mut i = a_mid.len();
let mut j = b_mid.len();
@ -134,7 +155,19 @@ fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
}
}
}
}
}
fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
let mut keep_a = vec![false; a.len()];
fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
keep_a
}
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
let mut keep_a = vec![false; a.len()];
let mut keep_b = vec![false; b.len()];
fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
(keep_a, keep_b)
}
@ -155,6 +188,12 @@ fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>
(unmasked_tokens, unmasked_chars, masked_chars)
}
fn should_bail_for_dirty_final(base: &str, predicted: &str, final_text: &str) -> bool {
let predicted_delta_chars = predicted.len().abs_diff(base.len());
let final_delta_chars = final_text.len().abs_diff(base.len());
predicted_delta_chars.abs_diff(final_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
}
pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
if base == predicted && predicted == final_text {
let predicted_tokens = tokenize(predicted);
@ -171,11 +210,26 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
};
}
if should_bail_for_dirty_final(base, predicted, final_text) {
let predicted_new_chars = predicted.len().abs_diff(base.len());
let final_new_chars = final_text.len().abs_diff(base.len());
return KeptRateResult {
predicted_new_chars,
final_new_chars,
kept_chars: 0,
discarded_chars: predicted_new_chars,
context_chars: 0,
kept_rate: 0.0,
#[cfg(test)]
token_annotations: vec![TokenAnnotation::Discarded; tokenize(predicted).len()],
};
}
let base_tokens = tokenize(base);
let predicted_tokens = tokenize(predicted);
let final_tokens = tokenize(final_text);
let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens);
let pred_base_mask = lcs_keep_mask(&predicted_tokens, &base_tokens);
let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
let context_mask: Vec<bool> = pred_base_mask
.iter()
@ -186,7 +240,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
let (stripped_predicted, predicted_new_chars, context_chars) =
analyze_masked_tokens(&predicted_tokens, &context_mask);
let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens);
let final_base_mask = lcs_keep_mask(&final_tokens, &base_tokens);
let final_context_mask: Vec<bool> = final_base_mask
.iter()
.zip(final_pred_mask.iter())
@ -196,7 +250,7 @@ pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptR
let (stripped_final, final_new_chars, _) =
analyze_masked_tokens(&final_tokens, &final_context_mask);
let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0;
let keep_mask = lcs_keep_mask(&stripped_predicted, &stripped_final);
let kept_chars: usize = stripped_predicted
.iter()
@ -265,8 +319,8 @@ mod test_kept_rate {
let a = ["x", "a", "x", "b"];
let b = ["a", "x", "b", "x"];
let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
assert_eq!(a_mask, lcs_keep_masks(&a, &b).0);
assert_eq!(b_mask, lcs_keep_masks(&b, &a).0);
assert_eq!(a_mask, lcs_keep_mask(&a, &b));
assert_eq!(b_mask, lcs_keep_mask(&b, &a));
}
#[test]
@ -342,6 +396,21 @@ mod test_kept_rate {
assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
}
#[test]
fn test_bails_for_dirty_final() {
let base = "fn example() {\n work();\n}\n";
let predicted = "fn example() {\n work();\n predicted();\n}\n";
let final_text = format!(
"fn example() {{\n work();\n {}\n}}\n",
"settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
);
let result = compute_kept_rate(base, predicted, &final_text);
assert_eq!(result.kept_rate, 0.0);
assert_eq!(result.kept_chars, 0);
assert_eq!(result.discarded_chars, result.predicted_new_chars);
}
#[test]
fn test_eprintln_token_alignment() {
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";

View file

@ -0,0 +1,54 @@
fn char_class(character: char) -> u8 {
if character.is_alphanumeric() || character == '_' {
0
} else if character.is_whitespace() {
1
} else {
2
}
}
pub(crate) fn tokenize(text: &str) -> Vec<&str> {
let mut tokens = Vec::new();
let mut characters = text.char_indices().peekable();
while let Some((start, character)) = characters.next() {
let class = char_class(character);
if class == 2 {
tokens.push(&text[start..start + character.len_utf8()]);
continue;
}
let mut end = start + character.len_utf8();
while let Some(&(_, next_character)) = characters.peek() {
if char_class(next_character) != class {
break;
}
end += next_character.len_utf8();
characters.next();
}
tokens.push(&text[start..end]);
}
tokens
}
#[cfg(test)]
mod tests {
use super::tokenize;
#[test]
fn tokenizes_code_like_text() {
assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
assert_eq!(
tokenize("foo_bar123 + baz"),
vec!["foo_bar123", " ", "+", " ", "baz"]
);
assert_eq!(
tokenize("print(\"hello\")"),
vec!["print", "(", "\"", "hello", "\"", ")"]
);
assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
}
}

View file

@ -0,0 +1,88 @@
use language::SyntaxLayer;
pub fn count_tree_sitter_errors<'a>(layers: impl Iterator<Item = SyntaxLayer<'a>>) -> usize {
let mut total_count: usize = 0;
for layer in layers {
let node = layer.node();
let mut cursor = node.walk();
'layer: loop {
let current = cursor.node();
if current.is_error() || current.is_missing() {
total_count += 1;
}
if current.has_error() && cursor.goto_first_child() {
continue;
}
if cursor.goto_next_sibling() {
continue;
}
loop {
if !cursor.goto_parent() {
break 'layer;
}
if cursor.goto_next_sibling() {
continue;
}
}
}
}
total_count
}
#[cfg(test)]
mod tests {
use std::ops::Range;
use super::count_tree_sitter_errors;
use gpui::{AppContext as _, TestAppContext};
use language::{Buffer, BufferSnapshot, rust_lang};
fn error_count_in_range(edited_buffer_snapshot: &BufferSnapshot, range: Range<usize>) -> usize {
let layers = edited_buffer_snapshot.syntax_layers_for_range(range, true);
count_tree_sitter_errors(layers)
}
fn rust_snapshot(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
while buffer.read_with(cx, |buffer, _| buffer.is_parsing()) {
cx.run_until_parked();
}
buffer.read_with(cx, |buffer, _| buffer.snapshot())
}
#[gpui::test]
async fn counts_no_errors_for_valid_rust(cx: &mut TestAppContext) {
let text = "fn helper(value: usize) -> usize {\n value + 1\n}\n";
let snapshot = rust_snapshot(text, cx);
assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 0);
}
#[gpui::test]
async fn counts_errors_for_invalid_rust(cx: &mut TestAppContext) {
let text = "fn helper(value: usize) -> usize {\n let total = ;\n total\n}\n";
let snapshot = rust_snapshot(text, cx);
assert_eq!(error_count_in_range(&snapshot, 0..snapshot.text.len()), 1);
}
#[gpui::test]
async fn counts_no_errors_for_subrange_of_valid_rust(cx: &mut TestAppContext) {
let text = "fn first() -> usize {\n let value = 1;\n value + 1\n}\n";
let snapshot = rust_snapshot(text, cx);
let body_start = text.find("let value").unwrap();
let body_end = body_start + "let value = 1;".len();
assert_eq!(error_count_in_range(&snapshot, body_start..body_end), 0);
}
#[gpui::test]
async fn counts_errors_for_subrange_of_invalid_rust(cx: &mut TestAppContext) {
let text = "fn second() -> usize {\n let broken = ;\n broken\n}\n";
let snapshot = rust_snapshot(text, cx);
let error_start = text.find("let broken = ;").unwrap();
let error_end = error_start + "let broken = ;".len();
assert_eq!(error_count_in_range(&snapshot, error_start..error_end), 1);
}
}

View file

@ -85,7 +85,6 @@ pub fn request_prediction_with_zeta(
} else {
None
};
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let organization_id = store
@ -383,11 +382,26 @@ pub fn request_prediction_with_zeta(
}));
};
if can_collect_data {
let result = EditPredictionResult::new(
id,
&edited_buffer,
&edited_buffer_snapshot,
edits.into(),
cursor_position,
inputs,
model_version,
request_duration,
cx,
)
.await;
if can_collect_data && let Ok(prediction) = &result.prediction {
let weak_this = this.clone();
let id = id.clone();
let request_id = prediction.id.clone();
let edited_buffer = edited_buffer.clone();
let edited_buffer_snapshot = edited_buffer_snapshot.clone();
let editable_range_in_buffer = editable_range_in_buffer.clone();
let edit_preview = prediction.edit_preview.clone();
let example_task = capture_data.and_then(|stored_events| {
cx.update(|cx| {
crate::capture_example(
@ -410,11 +424,12 @@ pub fn request_prediction_with_zeta(
weak_this
.update(cx, |this, cx| {
this.enqueue_settled_prediction(
id.clone(),
request_id.clone(),
&project,
&edited_buffer,
&edited_buffer_snapshot,
editable_range_in_buffer,
&edit_preview,
example_spec,
request_duration,
cx,
@ -425,20 +440,7 @@ pub fn request_prediction_with_zeta(
.detach();
}
Ok(Some(
EditPredictionResult::new(
id,
&edited_buffer,
&edited_buffer_snapshot,
edits.into(),
cursor_position,
inputs,
model_version,
request_duration,
cx,
)
.await,
))
Ok(Some(result))
})
}

View file

@ -83,7 +83,6 @@ dynamic_prompts = []
ignored = ["wasmtime"]
[dev-dependencies]
criterion.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
pretty_assertions.workspace = true
@ -91,6 +90,3 @@ project = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
workspace = { workspace = true, features = ["test-support"] }
[[bench]]
name = "kept_rate"
harness = false

View file

@ -1,4 +1,2 @@
#[allow(dead_code)]
mod word_diff;
pub mod kept_rate;

View file

@ -5,7 +5,7 @@ mod filter_languages;
mod format_prompt;
mod git;
mod headless;
mod kept_rate;
mod load_project;
mod metrics;
mod openai_client;

View file

@ -1298,4 +1298,4 @@ index abc123..def456 100644
}
}
pub use crate::kept_rate::compute_kept_rate;
pub use edit_prediction::metrics::compute_kept_rate;

View file

@ -878,6 +878,14 @@ impl EditPreview {
})
}
pub fn result_text_snapshot(&self) -> &text::BufferSnapshot {
&self.applied_edits_snapshot
}
pub fn result_syntax_snapshot(&self) -> &SyntaxSnapshot {
&self.syntax_snapshot
}
pub fn anchor_to_offset_in_result(&self, anchor: Anchor) -> usize {
anchor
.bias_right(&self.old_snapshot)