ep: Change kept_rate definition to a more intuitive one (#54306)

This change contains a number of fixes to make kept_rate more intuitive.
It also adds a CLI utility to print debug info on how the metric is
computed.

Release Notes:

- N/A
This commit is contained in:
Oleksiy Syvokon 2026-04-20 14:03:26 +03:00 committed by GitHub
parent f92178e6e5
commit 72b41263f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1148 additions and 93 deletions

1
Cargo.lock generated
View file

@ -5338,6 +5338,7 @@ dependencies = [
"language",
"pretty_assertions",
"serde",
"serde_json",
"similar",
"tree-sitter",
"zeta_prompt",

View file

@ -14,6 +14,7 @@ path = "src/edit_prediction_metrics.rs"
[dependencies]
language.workspace = true
serde.workspace = true
serde_json = "1.0"
similar = "2.7.0"
tree-sitter.workspace = true
zeta_prompt.workspace = true

View file

@ -4,9 +4,10 @@ mod reversal;
mod tokenize;
mod tree_sitter;
pub use kept_rate::AnnotatedToken;
pub use kept_rate::KeptRateResult;
#[cfg(test)]
pub use kept_rate::TokenAnnotation;
pub use kept_rate::annotate_kept_rate_tokens;
pub use kept_rate::compute_kept_rate;
pub use patch_metrics::ClassificationMetrics;
pub use patch_metrics::Counts;
@ -20,5 +21,6 @@ pub use patch_metrics::exact_lines_match;
pub use patch_metrics::extract_changed_lines_from_diff;
pub use patch_metrics::has_isolated_whitespace_changes;
pub use patch_metrics::is_editable_region_correct;
pub use patch_metrics::reconstruct_texts_from_diff;
pub use reversal::compute_prediction_reversal_ratio_from_history;
pub use tree_sitter::count_tree_sitter_errors;

View file

@ -3,14 +3,20 @@ use serde::Serialize;
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenAnnotation {
Context,
Kept,
Discarded,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AnnotatedToken {
pub token: String,
pub annotation: TokenAnnotation,
}
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize)]
pub struct KeptRateResult {
@ -40,8 +46,7 @@ pub struct KeptRateResult {
/// This includes both kept newly introduced characters and correctly
/// deleted base characters.
pub recall_rate: f64,
/// Per-token classification for candidate tokens used by tests.
#[cfg(test)]
/// Per-token classification for candidate tokens.
pub token_annotations: Vec<TokenAnnotation>,
}
@ -51,9 +56,9 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
fn fill_lcs_keep_masks(
a: &[&str],
b: &[&str],
fn fill_lcs_keep_masks<T: Eq>(
a: &[T],
b: &[T],
mut keep_a: Option<&mut [bool]>,
mut keep_b: Option<&mut [bool]>,
) {
@ -124,10 +129,10 @@ fn fill_lcs_keep_masks(
let mut dp = vec![0u32; row_count * column_count];
for i in 1..row_count {
let token_a = a_mid[i - 1];
let token_a = &a_mid[i - 1];
for j in 1..column_count {
let index = dp_index(column_count, i, j);
if token_a == b_mid[j - 1] {
if token_a == &b_mid[j - 1] {
dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
} else {
let up = dp[dp_index(column_count, i - 1, j)];
@ -180,41 +185,91 @@ fn fill_lcs_keep_masks(
}
}
fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
fn lcs_keep_mask<T: Eq>(a: &[T], b: &[T]) -> Vec<bool> {
let mut keep_a = vec![false; a.len()];
fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
keep_a
}
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
fn lcs_keep_masks<T: Eq>(a: &[T], b: &[T]) -> (Vec<bool>, Vec<bool>) {
let mut keep_a = vec![false; a.len()];
let mut keep_b = vec![false; b.len()];
fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
(keep_a, keep_b)
}
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
let mut unmasked_tokens = Vec::with_capacity(tokens.len());
let mut unmasked_chars = 0;
let mut masked_chars = 0;
/// A span of consecutive tokens that is compared as a single item.
///
/// `text` holds the text being compared, while `token_start..token_end`
/// records which token indices the unit covers.
#[derive(Debug, Clone)]
struct ComparisonUnit {
/// Text of the unit (the covered tokens concatenated).
text: String,
/// Index of the first token covered by this unit.
token_start: usize,
/// Index one past the last token covered by this unit (exclusive bound).
token_end: usize,
}
for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
if is_masked {
masked_chars += token.len();
/// Returns `true` when `token` is non-empty and every character in it is
/// alphanumeric or an underscore.
fn is_identifier_token(token: &str) -> bool {
    let mut characters = token.chars().peekable();
    // An empty token is never an identifier.
    if characters.peek().is_none() {
        return false;
    }
    // Reject as soon as a character that cannot be part of an identifier shows up.
    !characters.any(|character| character != '_' && !character.is_alphanumeric())
}
fn build_comparison_units(tokens: &[&str]) -> Vec<ComparisonUnit> {
let mut units = Vec::new();
let mut index = 0;
while index < tokens.len() {
let token_start = index;
if is_identifier_token(tokens[index]) {
let mut text = String::new();
while index < tokens.len() && is_identifier_token(tokens[index]) {
text.push_str(tokens[index]);
index += 1;
}
units.push(ComparisonUnit {
text,
token_start,
token_end: index,
});
} else {
unmasked_tokens.push(token);
unmasked_chars += token.len();
units.push(ComparisonUnit {
text: tokens[index].to_string(),
token_start,
token_end: index + 1,
});
index += 1;
}
}
(unmasked_tokens, unmasked_chars, masked_chars)
units
}
fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
tokens
/// Splits `units` by `mask` and measures both halves.
///
/// Returns `(texts of unmasked units, total byte length of unmasked units,
/// total byte length of masked units)`. Units and mask entries are paired
/// positionally; pairing stops at the shorter of the two slices.
fn analyze_masked_units<'a>(
    units: &'a [ComparisonUnit],
    mask: &[bool],
) -> (Vec<&'a str>, usize, usize) {
    let mut kept_texts = Vec::with_capacity(units.len());
    let (mut kept_len, mut masked_len) = (0usize, 0usize);

    for (unit, masked) in units.iter().zip(mask.iter().copied()) {
        let text = unit.text.as_str();
        if masked {
            masked_len += text.len();
            continue;
        }
        kept_len += text.len();
        kept_texts.push(text);
    }

    (kept_texts, kept_len, masked_len)
}
fn count_unmasked_unit_chars(units: &[ComparisonUnit], mask: &[bool]) -> usize {
units
.iter()
.zip(mask.iter())
.filter_map(|(&token, &is_masked)| (!is_masked).then_some(token.len()))
.filter_map(|(unit, &is_masked)| (!is_masked).then_some(unit.text.len()))
.sum()
}
@ -239,7 +294,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
context_chars,
kept_rate: 1.0,
recall_rate: 1.0,
#[cfg(test)]
token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
};
}
@ -258,7 +312,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
context_chars: 0,
kept_rate: 0.0,
recall_rate: 0.0,
#[cfg(test)]
token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
};
}
@ -267,29 +320,29 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
let candidate_tokens = tokenize(candidate);
let reference_tokens = tokenize(reference);
let (candidate_base_mask, base_candidate_mask) =
lcs_keep_masks(&candidate_tokens, &base_tokens);
let (candidate_reference_mask, reference_candidate_mask) =
lcs_keep_masks(&candidate_tokens, &reference_tokens);
let context_mask: Vec<bool> = candidate_base_mask
let candidate_units = build_comparison_units(&candidate_tokens);
let base_units = build_comparison_units(&base_tokens);
let reference_units = build_comparison_units(&reference_tokens);
let candidate_unit_texts: Vec<&str> = candidate_units
.iter()
.zip(candidate_reference_mask.iter())
.map(|(&in_base, &in_reference)| in_base && in_reference)
.map(|unit| unit.text.as_str())
.collect();
let base_unit_texts: Vec<&str> = base_units.iter().map(|unit| unit.text.as_str()).collect();
let reference_unit_texts: Vec<&str> = reference_units
.iter()
.map(|unit| unit.text.as_str())
.collect();
let (candidate_base_mask, base_candidate_mask) =
lcs_keep_masks(&candidate_unit_texts, &base_unit_texts);
let (stripped_candidate, candidate_new_chars, context_chars) =
analyze_masked_tokens(&candidate_tokens, &context_mask);
analyze_masked_units(&candidate_units, &candidate_base_mask);
let (reference_base_mask, base_reference_mask) =
lcs_keep_masks(&reference_tokens, &base_tokens);
let reference_context_mask: Vec<bool> = reference_base_mask
.iter()
.zip(reference_candidate_mask.iter())
.map(|(&in_base, &in_candidate)| in_base && in_candidate)
.collect();
lcs_keep_masks(&reference_unit_texts, &base_unit_texts);
let (stripped_reference, reference_new_chars, _) =
analyze_masked_tokens(&reference_tokens, &reference_context_mask);
analyze_masked_units(&reference_units, &reference_base_mask);
let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);
@ -299,13 +352,13 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
.filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
.sum();
let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
let correctly_deleted_chars: usize = base_tokens
let candidate_deleted_chars = count_unmasked_unit_chars(&base_units, &base_candidate_mask);
let reference_deleted_chars = count_unmasked_unit_chars(&base_units, &base_reference_mask);
let correctly_deleted_chars: usize = base_units
.iter()
.zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
.filter_map(|(&token, (&in_candidate, &in_reference))| {
(!in_candidate && !in_reference).then_some(token.len())
.filter_map(|(unit, (&in_candidate, &in_reference))| {
(!in_candidate && !in_reference).then_some(unit.text.len())
})
.sum();
@ -326,24 +379,28 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
matched_edit_chars as f64 / reference_edit_chars as f64
};
#[cfg(test)]
let token_annotations = {
let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
let mut token_annotations = vec![TokenAnnotation::Context; candidate_tokens.len()];
let mut new_index = 0;
for (token_index, _token) in candidate_tokens.iter().enumerate() {
if context_mask[token_index] {
token_annotations.push(TokenAnnotation::Context);
for (unit_index, unit) in candidate_units.iter().enumerate() {
let annotation = if candidate_base_mask[unit_index] {
TokenAnnotation::Context
} else {
let annotation = if keep_mask[new_index] {
TokenAnnotation::Kept
} else {
TokenAnnotation::Discarded
};
#[cfg(test)]
token_annotations.push(annotation);
new_index += 1;
annotation
};
for token_index in unit.token_start..unit.token_end {
token_annotations[token_index] = annotation;
}
}
token_annotations
};
@ -358,14 +415,30 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
context_chars,
kept_rate,
recall_rate,
#[cfg(test)]
token_annotations,
}
}
/// Pairs every token of `candidate` with its kept-rate annotation.
///
/// Runs [`compute_kept_rate`] on the three texts, tokenizes `candidate`
/// again, and zips the tokens with `token_annotations` from the result.
/// The output length is the shorter of the two sequences.
pub fn annotate_kept_rate_tokens(
base: &str,
candidate: &str,
reference: &str,
) -> Vec<AnnotatedToken> {
let result = compute_kept_rate(base, candidate, reference);
tokenize(candidate)
.into_iter()
.zip(result.token_annotations)
.map(|(token, annotation)| AnnotatedToken {
token: token.to_string(),
annotation,
})
.collect()
}
#[cfg(test)]
mod test_kept_rate {
use super::*;
use indoc::indoc;
#[test]
fn test_lcs_keep_masks() {
@ -439,16 +512,24 @@ mod test_kept_rate {
#[test]
fn test_missing_deletion() {
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
let base = indoc! {"
fn example() {
epr
"};
let candidate = indoc! {r#"
fn example() {
epr
eprintln!("");
"#};
let reference = indoc! {r#"
fn example() {
eprintln!("");
"#};
let result = compute_kept_rate(base, candidate, reference);
assert!(
result.kept_rate < 0.85,
"expected kept_rate < 0.85, got {}",
result.kept_rate
);
assert!(result.discarded_chars > 0);
assert!((result.kept_rate - (14.0 / 15.0)).abs() < 1e-6);
assert_eq!(result.kept_chars, 14);
assert_eq!(result.discarded_chars, 1);
}
#[test]
@ -472,8 +553,17 @@ mod test_kept_rate {
#[test]
fn test_bails_for_dirty_final() {
let base = "fn example() {\n work();\n}\n";
let candidate = "fn example() {\n work();\n predicted();\n}\n";
let base = indoc! {"
fn example() {
work();
}
"};
let candidate = indoc! {"
fn example() {
work();
predicted();
}
"};
let reference = format!(
"fn example() {{\n work();\n {}\n}}\n",
"settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
@ -488,9 +578,19 @@ mod test_kept_rate {
#[test]
fn test_eprintln_token_alignment() {
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
let base = indoc! {"
fn example() {
epr
"};
let candidate = indoc! {r#"
fn example() {
eprintln!("hello world!");
"#};
let reference = indoc! {r#"
fn example() {
eprintln!("");
"#};
let result = compute_kept_rate(base, candidate, reference);
assert!(result.discarded_chars > 0);
assert!(result.kept_chars > 0);
@ -499,6 +599,42 @@ mod test_kept_rate {
assert_eq!(result.discarded_chars, 12);
}
#[test]
fn test_kept_rate_treats_unchanged_stale_text_as_context() {
let base = indoc! {"
a=fomr
b=old
"};
let candidate = indoc! {"
a=formula;
b=old
"};
let reference = indoc! {"
a=formula;
b=new
"};
let result = compute_kept_rate(base, candidate, reference);
let candidate_tokens = tokenize(candidate);
assert_eq!(result.candidate_new_chars, "formula".len() + ";".len());
assert_eq!(result.kept_chars, "formula".len() + ";".len());
assert_eq!(result.discarded_chars, 0);
assert_eq!(result.candidate_deleted_chars, "fomr".len());
assert_eq!(result.correctly_deleted_chars, "fomr".len());
assert!((result.kept_rate - 1.0).abs() < 1e-6);
assert!((result.recall_rate - (2.0 / 3.0)).abs() < 1e-6);
let old_index = candidate_tokens
.iter()
.position(|&token| token == "old")
.expect("old token not found");
assert_eq!(
result.token_annotations[old_index],
TokenAnnotation::Context
);
}
#[test]
fn test_annotations_rename() {
let base = " foo(old_name)\n";
@ -514,7 +650,7 @@ mod test_kept_rate {
assert_eq!(result.token_annotations.len(), tokenize(candidate).len());
for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
if token == "new_name" {
if matches!(token, "new" | "_" | "name") {
assert_eq!(annotation, TokenAnnotation::Kept);
} else {
assert_eq!(annotation, TokenAnnotation::Context);
@ -524,9 +660,18 @@ mod test_kept_rate {
#[test]
fn test_annotations_eprintln_coloring() {
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
let base = indoc! {"
fn example() {
epr
"};
let candidate = indoc! {r#"
fn example() {
eprintln!("hello world!");
"#};
let reference = indoc! {r#"
fn example() {
eprintln!("");
"#};
let result = compute_kept_rate(base, candidate, reference);
let candidate_tokens = tokenize(candidate);

View file

@ -0,0 +1,710 @@
use std::env;
use std::fmt::Write as _;
use std::fs;
use std::path::Path;
use std::process;
use edit_prediction_metrics::{
ClassificationMetrics, DeltaChrFMetrics, KeptRateResult, TokenAnnotation,
annotate_kept_rate_tokens, braces_disbalance, compute_kept_rate, count_patch_token_changes,
delta_chr_f, exact_lines_match, extract_changed_lines_from_diff,
has_isolated_whitespace_changes, is_editable_region_correct,
};
use serde::Deserialize;
/// Entry point: runs the CLI and, on failure, reports the error on stderr
/// and exits with status 1.
fn main() {
    match run() {
        Ok(()) => {}
        Err(error) => {
            eprintln!("error: {error}");
            process::exit(1);
        }
    }
}
/// Parses the command line, loads the inputs, and prints the metrics report.
///
/// In file mode the base text and both patches are read from disk and the
/// patches are applied with a row offset of zero. In JSON mode the base
/// text, its starting row, the first expected patch, and the selected
/// prediction's patch are taken from the parsed example.
///
/// # Errors
///
/// Returns a message when no arguments are given (after printing usage),
/// when argument parsing fails, when a file cannot be read or parsed, when
/// the JSON lacks the required patch entries, or when a patch fails to apply.
fn run() -> Result<(), String> {
let args: Vec<String> = env::args().skip(1).collect();
if args.is_empty() {
print_usage();
return Err("missing arguments".to_string());
}
let input = CliInput::parse(&args)?;
let report = match input {
CliInput::Files {
base_path,
expected_patch_path,
actual_patch_path,
} => {
let base = fs::read_to_string(&base_path)
.map_err(|err| format!("failed to read {}: {err}", base_path.display()))?;
let expected_patch = fs::read_to_string(&expected_patch_path).map_err(|err| {
format!("failed to read {}: {err}", expected_patch_path.display())
})?;
let actual_patch = fs::read_to_string(&actual_patch_path)
.map_err(|err| format!("failed to read {}: {err}", actual_patch_path.display()))?;
// File mode: the base file is treated as starting at row 0.
let expected = apply_patch_to_excerpt(&base, &expected_patch, 0)?;
let actual = apply_patch_to_excerpt(&base, &actual_patch, 0)?;
EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
}
CliInput::Json {
json_path,
prediction_index,
} => {
let json = fs::read_to_string(&json_path)
.map_err(|err| format!("failed to read {}: {err}", json_path.display()))?;
let example: JsonExample = serde_json::from_str(&json)
.map_err(|err| format!("failed to parse {}: {err}", json_path.display()))?;
let base = example.prompt_inputs.cursor_excerpt;
let excerpt_start_row = example.prompt_inputs.excerpt_start_row;
// Only the first expected patch is used as the reference.
let expected_patch = example
.expected_patches
.into_iter()
.next()
.ok_or_else(|| "JSON input is missing expected_patches[0]".to_string())?;
let actual_patch = example
.predictions
.into_iter()
.nth(prediction_index)
.ok_or_else(|| {
format!("JSON input does not contain predictions[{prediction_index}]")
})?
.actual_patch;
let expected = apply_patch_to_excerpt(&base, &expected_patch, excerpt_start_row)?;
let actual = apply_patch_to_excerpt(&base, &actual_patch, excerpt_start_row)?;
EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
}
};
print_report(&report);
Ok(())
}
/// Prints the two supported invocation forms (file mode and JSON mode) to stderr.
fn print_usage() {
eprintln!(
"Usage:\n edit_prediction_metrics --base <base.txt> --expected-patch <expected.diff> --actual-patch <actual.diff>\n edit_prediction_metrics --json <example.json> [--prediction-index <n>]"
);
}
/// The two mutually exclusive ways of supplying inputs to the CLI.
enum CliInput {
/// Base text and both patches are given as separate files
/// (`--base`, `--expected-patch`, `--actual-patch`).
Files {
base_path: std::path::PathBuf,
expected_patch_path: std::path::PathBuf,
actual_patch_path: std::path::PathBuf,
},
/// Everything is read from a single JSON example (`--json`);
/// `prediction_index` selects which entry of its predictions to evaluate.
Json {
json_path: std::path::PathBuf,
prediction_index: usize,
},
}
impl CliInput {
/// Parses the command-line arguments (without the program name).
///
/// Every recognized flag except `--help`/`-h` consumes the following
/// argument as its value. `--prediction-index` defaults to 0. `--help`
/// prints the usage text and exits the process with status 0.
///
/// # Errors
///
/// Returns a message for an unrecognized argument, a flag without a value,
/// a non-numeric `--prediction-index`, `--json` combined with any file-mode
/// flag, or an incomplete set of file-mode flags.
fn parse(args: &[String]) -> Result<Self, String> {
let mut base_path = None;
let mut expected_patch_path = None;
let mut actual_patch_path = None;
let mut json_path = None;
let mut prediction_index = 0usize;
let mut index = 0;
while index < args.len() {
match args[index].as_str() {
"--base" => {
// Advance to the flag's value; the shared increment below skips past it.
index += 1;
base_path = Some(path_arg(args, index, "--base")?);
}
"--expected-patch" => {
index += 1;
expected_patch_path = Some(path_arg(args, index, "--expected-patch")?);
}
"--actual-patch" => {
index += 1;
actual_patch_path = Some(path_arg(args, index, "--actual-patch")?);
}
"--json" => {
index += 1;
json_path = Some(path_arg(args, index, "--json")?);
}
"--prediction-index" => {
index += 1;
let raw = string_arg(args, index, "--prediction-index")?;
prediction_index = raw.parse::<usize>().map_err(|err| {
format!("invalid value for --prediction-index ({raw}): {err}")
})?;
}
"--help" | "-h" => {
print_usage();
process::exit(0);
}
unknown => {
return Err(format!("unrecognized argument: {unknown}"));
}
}
index += 1;
}
// JSON mode takes precedence but must not be mixed with file-mode flags.
if let Some(json_path) = json_path {
if base_path.is_some() || expected_patch_path.is_some() || actual_patch_path.is_some() {
return Err(
"--json cannot be combined with --base/--expected-patch/--actual-patch"
.to_string(),
);
}
return Ok(CliInput::Json {
json_path,
prediction_index,
});
}
// File mode requires all three paths.
match (base_path, expected_patch_path, actual_patch_path) {
(Some(base_path), Some(expected_patch_path), Some(actual_patch_path)) => {
Ok(CliInput::Files {
base_path,
expected_patch_path,
actual_patch_path,
})
}
_ => Err(
"expected either --json <file> or all of --base, --expected-patch, and --actual-patch"
.to_string(),
),
}
}
}
/// Reads the argument at `index` as a filesystem path.
///
/// Fails with the same "missing value" message as `string_arg` when there
/// is no argument at that position.
fn path_arg(args: &[String], index: usize, flag: &str) -> Result<std::path::PathBuf, String> {
    let raw = string_arg(args, index, flag)?;
    let path = Path::new(raw);
    Ok(path.to_path_buf())
}
/// Returns the argument at `index` as a string slice.
///
/// `flag` is only used to name the offending flag in the error message
/// produced when `index` is past the end of `args`.
fn string_arg<'a>(args: &'a [String], index: usize, flag: &str) -> Result<&'a str, String> {
    match args.get(index) {
        Some(value) => Ok(value.as_str()),
        None => Err(format!("missing value for {flag}")),
    }
}
/// All values printed by `print_report` for one expected/actual patch pair.
#[derive(Debug)]
struct EvaluationReport {
/// Text both patches were applied to.
base: String,
/// `base` with the expected patch applied.
expected: String,
/// `base` with the actual (predicted) patch applied.
actual: String,
/// Kept-rate result with `actual` as the candidate and `expected` as the reference.
kept_rate: KeptRateResult,
/// Result of `exact_lines_match` over the two patches.
exact_lines: ClassificationMetrics,
/// Result of `delta_chr_f` over base, expected, and actual texts.
delta_chr_f: DeltaChrFMetrics,
/// Sum of the per-entry counts extracted from the expected patch.
expected_changed_lines: usize,
/// Sum of the per-entry counts extracted from the actual patch.
actual_changed_lines: usize,
/// Token change counts computed from the actual patch.
token_changes: edit_prediction_metrics::TokenChangeCounts,
/// Result of `has_isolated_whitespace_changes` for the actual patch.
isolated_whitespace_changes: bool,
/// Result of `is_editable_region_correct` for the actual patch.
editable_region_correct: bool,
/// Result of `braces_disbalance` for the expected text.
expected_braces_disbalance: usize,
/// Result of `braces_disbalance` for the actual text.
actual_braces_disbalance: usize,
}
impl EvaluationReport {
/// Computes every metric of the report.
///
/// `expected` and `actual` are the texts obtained by applying
/// `expected_patch` and `actual_patch` to `base`. Text-based metrics use
/// the three texts; patch-based metrics use the raw patches. The patches
/// themselves are not stored in the report.
fn new(
base: String,
expected_patch: String,
actual_patch: String,
expected: String,
actual: String,
) -> Self {
// The prediction (`actual`) is the candidate; `expected` is the reference.
let kept_rate = compute_kept_rate(&base, &actual, &expected);
let exact_lines = exact_lines_match(&expected_patch, &actual_patch);
let delta_chr_f = delta_chr_f(&base, &expected, &actual);
let expected_changed_lines = extract_changed_lines_from_diff(&expected_patch)
.values()
.sum();
let actual_changed_lines = extract_changed_lines_from_diff(&actual_patch)
.values()
.sum();
let token_changes = count_patch_token_changes(&actual_patch);
let isolated_whitespace_changes = has_isolated_whitespace_changes(&actual_patch, None);
let editable_region_correct = is_editable_region_correct(&actual_patch);
let expected_braces_disbalance = braces_disbalance(&expected);
let actual_braces_disbalance = braces_disbalance(&actual);
Self {
base,
expected,
actual,
kept_rate,
exact_lines,
delta_chr_f,
expected_changed_lines,
actual_changed_lines,
token_changes,
isolated_whitespace_changes,
editable_region_correct,
expected_braces_disbalance,
actual_braces_disbalance,
}
}
}
/// Prints the whole report to stdout.
///
/// Output is grouped into sections (overall metrics, exact line match,
/// patch structure, final text checks, kept-rate breakdown) and ends with
/// the token-annotated kept-rate explanation.
fn print_report(report: &EvaluationReport) {
println!("Metrics");
println!("=======");
println!("kept_rate: {:.6}", report.kept_rate.kept_rate);
println!("kept_rate_recall: {:.6}", report.kept_rate.recall_rate);
println!("delta_chr_f: {:.6}", report.delta_chr_f.score);
println!("delta_chr_f_precision: {:.6}", report.delta_chr_f.precision);
println!("delta_chr_f_recall: {:.6}", report.delta_chr_f.recall);
println!("delta_chr_f_beta: {:.6}", report.delta_chr_f.beta);
println!();
println!("Exact line match");
println!("----------------");
println!("true_positives: {}", report.exact_lines.true_positives);
println!("false_positives: {}", report.exact_lines.false_positives);
println!("false_negatives: {}", report.exact_lines.false_negatives);
println!("precision: {:.6}", report.exact_lines.precision());
println!("recall: {:.6}", report.exact_lines.recall());
println!("f1: {:.6}", report.exact_lines.f1());
println!("expected_changed_lines: {}", report.expected_changed_lines);
println!("actual_changed_lines: {}", report.actual_changed_lines);
println!();
println!("Patch structure");
println!("---------------");
println!("inserted_tokens: {}", report.token_changes.inserted_tokens);
println!("deleted_tokens: {}", report.token_changes.deleted_tokens);
println!(
"isolated_whitespace_changes: {}",
report.isolated_whitespace_changes
);
println!(
"editable_region_correct: {}",
report.editable_region_correct
);
println!();
println!("Final text checks");
println!("-----------------");
println!(
"expected_braces_disbalance: {}",
report.expected_braces_disbalance
);
println!(
"actual_braces_disbalance: {}",
report.actual_braces_disbalance
);
println!();
println!("Kept-rate breakdown");
println!("-------------------");
println!(
"candidate_new_chars: {}",
report.kept_rate.candidate_new_chars
);
println!(
"reference_new_chars: {}",
report.kept_rate.reference_new_chars
);
println!(
"candidate_deleted_chars: {}",
report.kept_rate.candidate_deleted_chars
);
println!(
"reference_deleted_chars: {}",
report.kept_rate.reference_deleted_chars
);
println!("kept_chars: {}", report.kept_rate.kept_chars);
println!(
"correctly_deleted_chars: {}",
report.kept_rate.correctly_deleted_chars
);
println!("discarded_chars: {}", report.kept_rate.discarded_chars);
println!("context_chars: {}", report.kept_rate.context_chars);
println!();
print_kept_rate_explanation(&report.base, &report.actual, &report.expected);
}
/// Prints the actual final text to stdout with each token styled by its
/// kept-rate annotation, preceded by a legend for the styles.
///
/// `actual` is passed as the candidate and `expected` as the reference
/// when computing the annotations.
fn print_kept_rate_explanation(base: &str, actual: &str, expected: &str) {
println!("Kept-rate explanation");
println!("---------------------");
println!("Legend: context = default, kept = green background, discarded = red background");
println!();
let annotated = annotate_kept_rate_tokens(base, actual, expected);
println!("Actual final text with token annotations:");
println!("{}", render_annotated_tokens(&annotated));
println!();
}
/// Concatenates the tokens into one string, wrapping kept and discarded
/// tokens in ANSI escape sequences and leaving context tokens unstyled.
///
/// Whitespace inside every token is made visible via `visualize_whitespace`.
fn render_annotated_tokens(tokens: &[edit_prediction_metrics::AnnotatedToken]) -> String {
    const RESET: &str = "\x1b[0m";
    const KEPT_STYLE: &str = "\x1b[30;42m";
    const DISCARDED_STYLE: &str = "\x1b[30;41m";

    let mut output = String::new();
    for annotated in tokens {
        let highlight = match annotated.annotation {
            TokenAnnotation::Context => None,
            TokenAnnotation::Kept => Some(KEPT_STYLE),
            TokenAnnotation::Discarded => Some(DISCARDED_STYLE),
        };
        let visible = visualize_whitespace(&annotated.token);
        match highlight {
            Some(style) => {
                output.push_str(style);
                output.push_str(&visible);
                output.push_str(RESET);
            }
            None => output.push_str(&visible),
        }
    }
    output
}
/// Makes whitespace visible: spaces become `·` and tabs become `⇥`.
/// Newlines and all other characters are passed through unchanged.
fn visualize_whitespace(token: &str) -> String {
    token
        .chars()
        .map(|character| match character {
            ' ' => '·',
            '\t' => '⇥',
            other => other,
        })
        .collect()
}
/// The subset of a JSON example that this tool reads.
#[derive(Debug, Deserialize)]
struct JsonExample {
/// Source of the base excerpt and its starting row.
prompt_inputs: PromptInputs,
/// Reference patches; only the first entry is used.
expected_patches: Vec<String>,
/// Predictions; the one at `--prediction-index` is evaluated.
predictions: Vec<Prediction>,
}
/// Excerpt data taken from the JSON example.
#[derive(Debug, Deserialize)]
struct PromptInputs {
/// Text used as the base that both patches are applied to.
cursor_excerpt: String,
/// Row offset of the excerpt, passed to `apply_patch_to_excerpt`.
excerpt_start_row: u32,
}
/// One prediction entry from the JSON example.
#[derive(Debug, Deserialize)]
struct Prediction {
/// Patch text evaluated against the expected patch.
actual_patch: String,
}
/// One hunk of a unified diff.
#[derive(Debug, Clone)]
struct ParsedHunk {
/// Start of the hunk on the old side. `parse_diff_hunks` stores the number
/// from the hunk header as-is, whereas `filter_hunk_to_excerpt` stores the
/// row of the first retained line after subtracting one from the header
/// value.
old_start: u32,
/// Body lines of the hunk in diff order.
lines: Vec<HunkLine>,
}
/// A single hunk body line with its leading diff marker removed.
#[derive(Debug, Clone)]
enum HunkLine {
/// Line present on both the old and new side (` ` prefix, or an empty line).
Context(String),
/// Line present only on the new side (`+` prefix).
Addition(String),
/// Line present only on the old side (`-` prefix).
Deletion(String),
}
/// Applies the unified diff `patch` to `base`, an excerpt whose first line
/// sits at row `excerpt_start_row` of the original file.
///
/// The hunks are first applied with line numbers interpreted relative to
/// the whole file. Predicted patches may instead number lines relative to
/// the excerpt; in that case every hunk misses the excerpt window and the
/// first attempt either fails or leaves the text unchanged. When that
/// happens (and the offset is non-zero and there is at least one hunk) the
/// hunks are retried with a zero offset, and the retry is used only if it
/// succeeds and actually changes the text. Otherwise the first attempt's
/// outcome is returned.
fn apply_patch_to_excerpt(
    base: &str,
    patch: &str,
    excerpt_start_row: u32,
) -> Result<String, String> {
    let hunks = parse_diff_hunks(patch);
    let primary = try_apply_hunks(base, &hunks, excerpt_start_row);

    // Nothing to retry with: the offset is already zero or there are no hunks.
    if excerpt_start_row == 0 || hunks.is_empty() {
        return primary;
    }

    // The file-global interpretation produced a real edit; keep it.
    let primary_changed_text = primary.as_ref().is_ok_and(|text| text != base);
    if primary_changed_text {
        return primary;
    }

    match try_apply_hunks(base, &hunks, 0) {
        Ok(text) if text != base => Ok(text),
        _ => primary,
    }
}
/// Applies `hunks` in order to `base`, treating the first line of `base` as
/// row `excerpt_start_row` of the original file.
///
/// Each hunk is first narrowed to the part overlapping the excerpt; hunks
/// with no overlap are skipped. For the remaining part, the old side
/// (context and deletion lines) must match the current text exactly before
/// it is replaced by the new side (context and addition lines).
///
/// # Errors
///
/// Returns a message when a hunk would start before or past the excerpt,
/// would extend beyond its end, or when its old side does not match the text.
fn try_apply_hunks(
base: &str,
hunks: &[ParsedHunk],
excerpt_start_row: u32,
) -> Result<String, String> {
let base_has_trailing_newline = base.ends_with('\n');
let mut lines = split_preserving_final_empty_line(base);
let original_line_count = lines.len() as u32;
let excerpt_end_row = excerpt_start_row + original_line_count;
// Net number of lines added by the hunks applied so far; shifts later hunks.
let mut line_delta: i64 = 0;
for hunk in hunks {
let filtered = match filter_hunk_to_excerpt(hunk, excerpt_start_row, excerpt_end_row) {
Some(filtered) => filtered,
None => continue,
};
// Translate the hunk's row into an index into the current `lines`.
let local_start = filtered.old_start.saturating_sub(excerpt_start_row) as i64 + line_delta;
if local_start < 0 {
return Err(format!(
"patch application moved before excerpt start at source row {}",
filtered.old_start
));
}
let local_start = local_start as usize;
if local_start > lines.len() {
return Err(format!(
"patch application starts past excerpt end at local line {}",
local_start + 1
));
}
// Old side = context + deletions; new side = context + additions.
let old_len = filtered
.lines
.iter()
.filter(|line| !matches!(line, HunkLine::Addition(_)))
.count();
let new_len = filtered
.lines
.iter()
.filter(|line| !matches!(line, HunkLine::Deletion(_)))
.count();
let old_segment: Vec<&str> = filtered
.lines
.iter()
.filter_map(|line| match line {
HunkLine::Context(text) | HunkLine::Deletion(text) => Some(text.as_str()),
HunkLine::Addition(_) => None,
})
.collect();
let new_segment: Vec<String> = filtered
.lines
.iter()
.filter_map(|line| match line {
HunkLine::Context(text) | HunkLine::Addition(text) => Some(text.clone()),
HunkLine::Deletion(_) => None,
})
.collect();
if local_start + old_len > lines.len() {
return Err(format!(
"patch application exceeds excerpt bounds near source row {}",
filtered.old_start
));
}
// Refuse to apply the hunk unless its old side matches the text verbatim.
let current_segment: Vec<&str> = lines[local_start..local_start + old_len]
.iter()
.map(String::as_str)
.collect();
if current_segment != old_segment {
let mut details = String::new();
let _ = write!(
details,
"patch context mismatch near source row {}: expected {:?}, found {:?}",
filtered.old_start, old_segment, current_segment
);
return Err(details);
}
lines.splice(local_start..local_start + old_len, new_segment);
line_delta += new_len as i64 - old_len as i64;
}
Ok(join_lines(&lines, base_has_trailing_newline))
}
/// Splits `text` into lines, keeping an explicit empty final entry when the
/// text ends with a newline.
///
/// `str::lines` drops the terminator of the last line, so `"a\n"` and `"a"`
/// both yield `["a"]`. Appending one empty entry for a newline-terminated
/// text keeps the two distinguishable and makes joining the result with
/// `"\n"` reproduce the input (apart from `\r\n` endings, which `str::lines`
/// normalizes).
///
/// The entry is appended unconditionally. Skipping it when the last line is
/// already empty would turn `"a\n\n"` into `["a", ""]`, silently dropping a
/// trailing blank line once the lines are joined again.
fn split_preserving_final_empty_line(text: &str) -> Vec<String> {
    let mut lines: Vec<String> = text.lines().map(ToString::to_string).collect();
    if text.ends_with('\n') {
        lines.push(String::new());
    }
    lines
}
/// Joins `lines` with `\n` and makes the presence of a final newline agree
/// with `had_trailing_newline`.
///
/// An empty slice always produces an empty string. Otherwise a newline is
/// appended when one is wanted but missing, and a single final newline is
/// removed when one is present but not wanted.
fn join_lines(lines: &[String], had_trailing_newline: bool) -> String {
    if lines.is_empty() {
        return String::new();
    }

    let mut text = lines.join("\n");
    match (had_trailing_newline, text.ends_with('\n')) {
        (true, false) => text.push('\n'),
        (false, true) => {
            text.pop();
        }
        _ => {}
    }
    text
}
/// Narrows `hunk` to the lines that fall inside the excerpt window
/// `excerpt_start_row..excerpt_end_row`.
///
/// Old-side rows are counted from `hunk.old_start - 1` (saturating at 0).
/// Context and deletion lines occupy one old-side row each and are kept
/// when that row lies inside the window. Addition lines occupy no old-side
/// row; they are kept when their insertion point lies inside the window or
/// exactly at its end.
///
/// Returns `None` when no line is kept. Otherwise the returned hunk's
/// `old_start` is the row of the first kept line.
fn filter_hunk_to_excerpt(
    hunk: &ParsedHunk,
    excerpt_start_row: u32,
    excerpt_end_row: u32,
) -> Option<ParsedHunk> {
    let mut old_row = hunk.old_start.saturating_sub(1);
    let mut first_kept_row: Option<u32> = None;
    let mut kept_lines = Vec::new();

    for line in &hunk.lines {
        let consumes_old_row = !matches!(line, HunkLine::Addition(_));
        let inside_excerpt = if consumes_old_row {
            (excerpt_start_row..excerpt_end_row).contains(&old_row)
        } else {
            // Insertions directly after the last excerpt line are still accepted.
            (excerpt_start_row..=excerpt_end_row).contains(&old_row)
        };

        if inside_excerpt {
            first_kept_row.get_or_insert(old_row);
            kept_lines.push(line.clone());
        }
        if consumes_old_row {
            old_row += 1;
        }
    }

    first_kept_row.map(|old_start| ParsedHunk {
        old_start,
        lines: kept_lines,
    })
}
/// Parses the hunks of a unified diff.
///
/// A line accepted by `parse_hunk_header` starts a new hunk; only its
/// old-side start is retained. Lines before the first header are ignored.
/// Inside a hunk, `+`, `-`, and space prefixes mark additions, deletions,
/// and context; an empty line counts as empty context. Lines starting with
/// `+++` or `---`, and lines with any other prefix, are skipped.
fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
    let mut hunks: Vec<ParsedHunk> = Vec::new();

    for line in diff.lines() {
        if let Some((old_start, _old_count, _new_start, _new_count)) = parse_hunk_header(line) {
            hunks.push(ParsedHunk {
                old_start,
                lines: Vec::new(),
            });
            continue;
        }

        // Body lines always belong to the most recently opened hunk.
        let Some(hunk) = hunks.last_mut() else {
            continue;
        };

        let entry = if let Some(text) = line.strip_prefix('+') {
            (!line.starts_with("+++")).then(|| HunkLine::Addition(text.to_string()))
        } else if let Some(text) = line.strip_prefix('-') {
            (!line.starts_with("---")).then(|| HunkLine::Deletion(text.to_string()))
        } else if let Some(text) = line.strip_prefix(' ') {
            Some(HunkLine::Context(text.to_string()))
        } else if line.is_empty() {
            Some(HunkLine::Context(String::new()))
        } else {
            None
        };

        if let Some(entry) = entry {
            hunk.lines.push(entry);
        }
    }

    hunks
}
/// Parses a unified-diff hunk header of the form `@@ -<old> +<new> @@…`.
///
/// Returns `(old_start, old_count, new_start, new_count)`, or `None` when
/// the line does not have that shape or a range fails to parse. Anything
/// after the closing ` @@` is ignored.
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
    let body = line.strip_prefix("@@ -")?;
    let (old_range, after_old) = body.split_once(' ')?;
    let (new_range, _trailer) = after_old.strip_prefix('+')?.split_once(" @@")?;

    parse_hunk_range(old_range)
        .zip(parse_hunk_range(new_range))
        .map(|((old_start, old_count), (new_start, new_count))| {
            (old_start, old_count, new_start, new_count)
        })
}
/// Parses one side of a hunk header: either `start,count` or a bare `start`.
///
/// A bare `start` implies a count of 1. Returns `None` if any number fails to
/// parse as a `u32`.
fn parse_hunk_range(part: &str) -> Option<(u32, u32)> {
    match part.split_once(',') {
        Some((start_text, count_text)) => {
            let start = start_text.parse::<u32>().ok()?;
            let count = count_text.parse::<u32>().ok()?;
            Some((start, count))
        }
        None => part.parse::<u32>().ok().map(|start| (start, 1)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The excerpt starts at row 0, so hunk line numbers address it directly.
    #[test]
    fn applies_patch_in_file_mode() {
        let excerpt = "fn main() {\n println!(\"hello\");\n}\n";
        let diff = "@@ -1,3 +1,3 @@\n fn main() {\n- println!(\"hello\");\n+ println!(\"world\");\n }\n";
        let patched = apply_patch_to_excerpt(excerpt, diff, 0).unwrap();
        assert_eq!(patched, "fn main() {\n println!(\"world\");\n}\n");
    }

    /// The excerpt starts at row 1 and the hunk targets line 2, i.e. the
    /// first line of the excerpt.
    #[test]
    fn applies_patch_in_json_excerpt_mode() {
        let excerpt = "b\nc\nd\n";
        let diff = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
        let patched = apply_patch_to_excerpt(excerpt, diff, 1).unwrap();
        assert_eq!(patched, "x\ny\nd\n");
    }

    #[test]
    fn applies_patch_with_excerpt_relative_line_numbers() {
        let excerpt = "a\nb\nc\nd\n";
        // The hunk addresses line 2 of the excerpt itself, although the
        // excerpt begins at file row 100.
        let diff = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
        let patched = apply_patch_to_excerpt(excerpt, diff, 100).unwrap();
        assert_eq!(patched, "a\nx\ny\nd\n");
    }

    #[test]
    fn prefers_file_global_line_numbers_over_excerpt_relative() {
        let excerpt = "a\nb\nc\n";
        // File-global numbering: the excerpt begins at row 5 and the hunk
        // targets line 6 (1-based), which is row 5 (0-based), the first
        // line of the excerpt.
        let diff = "@@ -6,2 +6,2 @@\n-a\n-b\n+x\n+y\n";
        let patched = apply_patch_to_excerpt(excerpt, diff, 5).unwrap();
        assert_eq!(patched, "x\ny\nc\n");
    }
}

View file

@ -687,6 +687,35 @@ fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
.collect()
}
/// Rebuilds the "before" and "after" texts described by a unified diff.
///
/// Lines are gathered from every hunk in order. Context lines go to both
/// sides, deletions only to the old side, additions only to the new side, and
/// garbage lines are dropped. Each side is joined with `\n` (no trailing
/// newline) and returned as `(old_text, new_text)`.
pub fn reconstruct_texts_from_diff(patch_str: &str) -> (String, String) {
    let patch = Patch::parse_unified_diff(patch_str);
    let mut before: Vec<&str> = Vec::new();
    let mut after: Vec<&str> = Vec::new();

    for line in patch.hunks.iter().flat_map(|hunk| &hunk.lines) {
        match line {
            PatchLine::Context(content) => {
                before.push(content);
                after.push(content);
            }
            PatchLine::Deletion(content) => before.push(content),
            PatchLine::Addition(content) => after.push(content),
            PatchLine::Garbage(_) => {}
        }
    }

    (before.join("\n"), after.join("\n"))
}
#[derive(Debug, Default, Clone)]
struct Patch {
hunks: Vec<Hunk>,

View file

@ -1,33 +1,158 @@
fn char_class(character: char) -> u8 {
if character.is_alphanumeric() || character == '_' {
0
use std::{iter::Peekable, str::CharIndices};
/// Coarse category of a single character, as assigned by `char_class`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CharClass {
/// Alphanumeric characters and `_`.
Identifier,
/// `\n` and `\r`.
Newline,
/// Any other whitespace character.
Whitespace,
/// Everything that falls into none of the other classes.
Punctuation,
}
/// Punctuation sequences that `push_punctuation_token` emits as one token
/// instead of one token per character.
///
/// Lookup takes the first entry that the remaining text starts with, so an
/// entry must be listed before any shorter entry that is a prefix of it
/// (e.g. `>>=` before `>>`).
const MULTI_CHAR_PUNCTUATION: &[&str] = &[
">>>=", "<<=", ">>=", "...", "..=", "??=", "**=", ">>>", "::", "->", "=>", "==", "!=", "<=",
">=", "&&", "||", "<<", ">>", "..", "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", "++", "--",
"**", "??", "?.", ":=", "<-", "//", "/*", "*/",
];
fn char_class(character: char) -> CharClass {
if character == '\n' || character == '\r' {
CharClass::Newline
} else if character.is_whitespace() {
1
CharClass::Whitespace
} else if character.is_alphanumeric() || character == '_' {
CharClass::Identifier
} else {
2
CharClass::Punctuation
}
}
/// Reports whether `current` starts a new word inside an identifier.
///
/// A boundary exists only before an uppercase character, in two situations:
/// it follows a lowercase or numeric character (`camel|Case`, `Version2|Update`),
/// or it follows an uppercase character and is itself followed by a lowercase
/// one, which marks the end of an acronym (`HTTP|Server`).
fn is_identifier_boundary(previous: char, current: char, next: Option<char>) -> bool {
    if !current.is_uppercase() {
        return false;
    }
    if previous.is_lowercase() || previous.is_numeric() {
        return true;
    }
    previous.is_uppercase() && matches!(next, Some(following) if following.is_lowercase())
}
/// Splits one identifier into word tokens and appends them to `tokens`.
///
/// Each run of consecutive underscores becomes a single token. Between
/// underscore runs, the text is cut wherever `is_identifier_boundary` reports
/// a word boundary. Every pushed token is a non-empty slice of `identifier`.
fn push_identifier_tokens<'a>(identifier: &'a str, tokens: &mut Vec<&'a str>) {
    let indexed: Vec<(usize, char)> = identifier.char_indices().collect();
    let mut word_start = 0;
    let mut position = 0;

    while let Some(&(offset, current)) = indexed.get(position) {
        if current == '_' {
            // Flush the word that precedes the underscore run, if any.
            if word_start < offset {
                tokens.push(&identifier[word_start..offset]);
            }

            // Consume the whole run; it ends where the next character starts,
            // or at the end of the identifier.
            position += indexed[position..]
                .iter()
                .take_while(|entry| entry.1 == '_')
                .count();
            let run_end = indexed
                .get(position)
                .map_or(identifier.len(), |&(next_offset, _)| next_offset);

            tokens.push(&identifier[offset..run_end]);
            word_start = run_end;
            continue;
        }

        // The first character of a word can never be a boundary.
        if offset > word_start {
            let previous = indexed[position - 1].1;
            let next = indexed.get(position + 1).map(|&(_, following)| following);
            if is_identifier_boundary(previous, current, next) {
                tokens.push(&identifier[word_start..offset]);
                word_start = offset;
            }
        }
        position += 1;
    }

    if word_start < identifier.len() {
        tokens.push(&identifier[word_start..]);
    }
}
/// Appends the punctuation token that begins at byte offset `start` of `text`.
///
/// `character` is the character at `start`, which the caller has already taken
/// from `characters`. If the text at `start` begins with an entry of
/// `MULTI_CHAR_PUNCTUATION` (first match wins), that whole sequence becomes
/// the token and its remaining characters are consumed from `characters`.
/// Otherwise the token is just `character`.
fn push_punctuation_token<'a>(
    text: &'a str,
    start: usize,
    character: char,
    characters: &mut Peekable<CharIndices<'a>>,
    tokens: &mut Vec<&'a str>,
) {
    let remaining = &text[start..];
    let grouped = MULTI_CHAR_PUNCTUATION
        .iter()
        .find(|candidate| remaining.starts_with(**candidate));

    let token_length = match grouped {
        Some(punctuation) => {
            // The first character is already consumed; skip the rest of the group.
            for _ in 1..punctuation.chars().count() {
                characters.next();
            }
            punctuation.len()
        }
        None => character.len_utf8(),
    };

    tokens.push(&remaining[..token_length]);
}
pub(crate) fn tokenize(text: &str) -> Vec<&str> {
let mut tokens = Vec::new();
let mut characters = text.char_indices().peekable();
while let Some((start, character)) = characters.next() {
let class = char_class(character);
if class == 2 {
tokens.push(&text[start..start + character.len_utf8()]);
continue;
}
match char_class(character) {
CharClass::Identifier => {
let mut end = start + character.len_utf8();
let mut end = start + character.len_utf8();
while let Some(&(_, next_character)) = characters.peek() {
if char_class(next_character) != class {
break;
while let Some(&(next_start, next_character)) = characters.peek() {
if char_class(next_character) != CharClass::Identifier {
break;
}
end = next_start + next_character.len_utf8();
characters.next();
}
push_identifier_tokens(&text[start..end], &mut tokens);
}
CharClass::Newline => {
let mut end = start + character.len_utf8();
while let Some(&(next_start, next_character)) = characters.peek() {
if char_class(next_character) != CharClass::Newline {
break;
}
end = next_start + next_character.len_utf8();
characters.next();
}
tokens.push(&text[start..end]);
}
CharClass::Whitespace => {
let mut end = start + character.len_utf8();
while let Some(&(next_start, next_character)) = characters.peek() {
if char_class(next_character) != CharClass::Whitespace {
break;
}
end = next_start + next_character.len_utf8();
characters.next();
}
tokens.push(&text[start..end]);
}
CharClass::Punctuation => {
push_punctuation_token(text, start, character, &mut characters, &mut tokens);
}
end += next_character.len_utf8();
characters.next();
}
tokens.push(&text[start..end]);
}
tokens
@ -38,17 +163,58 @@ mod tests {
use super::tokenize;
#[test]
fn tokenizes_code_like_text() {
fn tokenizes_code() {
assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
assert_eq!(
tokenize("foo_bar123 + baz"),
vec!["foo_bar123", " ", "+", " ", "baz"]
vec!["foo", "_", "bar123", " ", "+", " ", "baz"]
);
assert_eq!(
tokenize("print(\"hello\")"),
vec!["print", "(", "\"", "hello", "\"", ")"]
);
assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
assert_eq!(tokenize("hello_world"), vec!["hello", "_", "world"]);
assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
}
/// Identifiers are split at camelCase, PascalCase, acronym and underscore boundaries.
#[test]
fn tokenizes_identifier_case_styles() {
    let cases = [
        (
            "camelCase PascalCase snake_case",
            vec![
                "camel", "Case", " ", "Pascal", "Case", " ", "snake", "_", "case",
            ],
        ),
        (
            "myHTTPServer __private_value foo__bar",
            vec![
                "my", "HTTP", "Server", " ", "__", "private", "_", "value", " ", "foo", "__", "bar",
            ],
        ),
        (
            "XMLHttpRequest Version2Update",
            vec!["XML", "Http", "Request", " ", "Version2", "Update"],
        ),
    ];

    for (input, expected) in cases {
        assert_eq!(tokenize(input), expected);
    }
}
/// Multi-character operators come out as single tokens.
#[test]
fn tokenizes_grouped_punctuation() {
    let cases = [
        (
            "a::b -> c != d ..= e",
            vec![
                "a", "::", "b", " ", "->", " ", "c", " ", "!=", " ", "d", " ", "..=", " ", "e",
            ],
        ),
        (
            "foo?.bar ?? baz",
            vec!["foo", "?.", "bar", " ", "??", " ", "baz"],
        ),
    ];

    for (input, expected) in cases {
        assert_eq!(tokenize(input), expected);
    }
}
/// Whitespace runs and newline runs are each grouped into their own token.
#[test]
fn tokenize_whitespace_runs() {
    let cases = [
        (" ", vec![" "]),
        (" \n foo", vec![" ", "\n", " ", "foo"]),
        ("\r\n\nfoo", vec!["\r\n\n", "foo"]),
    ];

    for (input, expected) in cases {
        assert_eq!(tokenize(input), expected);
    }
}
}

View file

@ -63,6 +63,7 @@ extend-exclude = [
"crates/gpui_macos/src/dispatcher.rs",
# Tests contain partially incomplete words (by design)
"crates/edit_prediction_cli/src/split_commit.rs",
"crates/edit_prediction_metrics/src/kept_rate.rs",
# Eval examples contain intentionally partial words (e.g. "secur" for "secure")
"crates/edit_prediction_cli/evals/",
# Tests contain `baˇr` that cause `"ba" should be "by" or "be".`-like false-positives