mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
ep: Change kept_rate definition to a more intuitive one (#54306)
This change contains a number of fixes to make kept_rate more intuitive. It also adds a CLI utility to print debug info on how the metric is computed. Release Notes: - N/A
This commit is contained in:
parent
f92178e6e5
commit
72b41263f3
8 changed files with 1148 additions and 93 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -5338,6 +5338,7 @@ dependencies = [
|
|||
"language",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"similar",
|
||||
"tree-sitter",
|
||||
"zeta_prompt",
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ path = "src/edit_prediction_metrics.rs"
|
|||
[dependencies]
|
||||
language.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json = "1.0"
|
||||
similar = "2.7.0"
|
||||
tree-sitter.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ mod reversal;
|
|||
mod tokenize;
|
||||
mod tree_sitter;
|
||||
|
||||
pub use kept_rate::AnnotatedToken;
|
||||
pub use kept_rate::KeptRateResult;
|
||||
#[cfg(test)]
|
||||
pub use kept_rate::TokenAnnotation;
|
||||
pub use kept_rate::annotate_kept_rate_tokens;
|
||||
pub use kept_rate::compute_kept_rate;
|
||||
pub use patch_metrics::ClassificationMetrics;
|
||||
pub use patch_metrics::Counts;
|
||||
|
|
@ -20,5 +21,6 @@ pub use patch_metrics::exact_lines_match;
|
|||
pub use patch_metrics::extract_changed_lines_from_diff;
|
||||
pub use patch_metrics::has_isolated_whitespace_changes;
|
||||
pub use patch_metrics::is_editable_region_correct;
|
||||
pub use patch_metrics::reconstruct_texts_from_diff;
|
||||
pub use reversal::compute_prediction_reversal_ratio_from_history;
|
||||
pub use tree_sitter::count_tree_sitter_errors;
|
||||
|
|
|
|||
|
|
@ -3,14 +3,20 @@ use serde::Serialize;
|
|||
|
||||
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TokenAnnotation {
|
||||
Context,
|
||||
Kept,
|
||||
Discarded,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub struct AnnotatedToken {
|
||||
pub token: String,
|
||||
pub annotation: TokenAnnotation,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct KeptRateResult {
|
||||
|
|
@ -40,8 +46,7 @@ pub struct KeptRateResult {
|
|||
/// This includes both kept newly introduced characters and correctly
|
||||
/// deleted base characters.
|
||||
pub recall_rate: f64,
|
||||
/// Per-token classification for candidate tokens used by tests.
|
||||
#[cfg(test)]
|
||||
/// Per-token classification for candidate tokens.
|
||||
pub token_annotations: Vec<TokenAnnotation>,
|
||||
}
|
||||
|
||||
|
|
@ -51,9 +56,9 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
|
|||
|
||||
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
|
||||
/// while sharing a single DP table construction.
|
||||
fn fill_lcs_keep_masks(
|
||||
a: &[&str],
|
||||
b: &[&str],
|
||||
fn fill_lcs_keep_masks<T: Eq>(
|
||||
a: &[T],
|
||||
b: &[T],
|
||||
mut keep_a: Option<&mut [bool]>,
|
||||
mut keep_b: Option<&mut [bool]>,
|
||||
) {
|
||||
|
|
@ -124,10 +129,10 @@ fn fill_lcs_keep_masks(
|
|||
let mut dp = vec![0u32; row_count * column_count];
|
||||
|
||||
for i in 1..row_count {
|
||||
let token_a = a_mid[i - 1];
|
||||
let token_a = &a_mid[i - 1];
|
||||
for j in 1..column_count {
|
||||
let index = dp_index(column_count, i, j);
|
||||
if token_a == b_mid[j - 1] {
|
||||
if token_a == &b_mid[j - 1] {
|
||||
dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
|
||||
} else {
|
||||
let up = dp[dp_index(column_count, i - 1, j)];
|
||||
|
|
@ -180,41 +185,91 @@ fn fill_lcs_keep_masks(
|
|||
}
|
||||
}
|
||||
|
||||
fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
|
||||
fn lcs_keep_mask<T: Eq>(a: &[T], b: &[T]) -> Vec<bool> {
|
||||
let mut keep_a = vec![false; a.len()];
|
||||
fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
|
||||
keep_a
|
||||
}
|
||||
|
||||
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
|
||||
fn lcs_keep_masks<T: Eq>(a: &[T], b: &[T]) -> (Vec<bool>, Vec<bool>) {
|
||||
let mut keep_a = vec![false; a.len()];
|
||||
let mut keep_b = vec![false; b.len()];
|
||||
fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
|
||||
(keep_a, keep_b)
|
||||
}
|
||||
|
||||
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
|
||||
let mut unmasked_tokens = Vec::with_capacity(tokens.len());
|
||||
let mut unmasked_chars = 0;
|
||||
let mut masked_chars = 0;
|
||||
/// A piece of text that is compared as a whole, together with the half-open
/// range `token_start..token_end` of the tokens it was built from.
#[derive(Debug, Clone)]
struct ComparisonUnit {
    /// Concatenated text of the covered tokens.
    text: String,
    /// Index of the first covered token.
    token_start: usize,
    /// Index one past the last covered token.
    token_end: usize,
}
|
||||
|
||||
for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
|
||||
if is_masked {
|
||||
masked_chars += token.len();
|
||||
/// Returns `true` when `token` is non-empty and every character in it is
/// alphanumeric (per `char::is_alphanumeric`) or an underscore.
fn is_identifier_token(token: &str) -> bool {
    !token.is_empty()
        && token
            .chars()
            .all(|character| character.is_alphanumeric() || character == '_')
}
|
||||
|
||||
fn build_comparison_units(tokens: &[&str]) -> Vec<ComparisonUnit> {
|
||||
let mut units = Vec::new();
|
||||
let mut index = 0;
|
||||
|
||||
while index < tokens.len() {
|
||||
let token_start = index;
|
||||
|
||||
if is_identifier_token(tokens[index]) {
|
||||
let mut text = String::new();
|
||||
|
||||
while index < tokens.len() && is_identifier_token(tokens[index]) {
|
||||
text.push_str(tokens[index]);
|
||||
index += 1;
|
||||
}
|
||||
|
||||
units.push(ComparisonUnit {
|
||||
text,
|
||||
token_start,
|
||||
token_end: index,
|
||||
});
|
||||
} else {
|
||||
unmasked_tokens.push(token);
|
||||
unmasked_chars += token.len();
|
||||
units.push(ComparisonUnit {
|
||||
text: tokens[index].to_string(),
|
||||
token_start,
|
||||
token_end: index + 1,
|
||||
});
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
(unmasked_tokens, unmasked_chars, masked_chars)
|
||||
units
|
||||
}
|
||||
|
||||
fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
|
||||
tokens
|
||||
fn analyze_masked_units<'a>(
|
||||
units: &'a [ComparisonUnit],
|
||||
mask: &[bool],
|
||||
) -> (Vec<&'a str>, usize, usize) {
|
||||
let mut unmasked_units = Vec::with_capacity(units.len());
|
||||
let mut unmasked_chars = 0;
|
||||
let mut masked_chars = 0;
|
||||
|
||||
for (unit, &is_masked) in units.iter().zip(mask.iter()) {
|
||||
if is_masked {
|
||||
masked_chars += unit.text.len();
|
||||
} else {
|
||||
unmasked_units.push(unit.text.as_str());
|
||||
unmasked_chars += unit.text.len();
|
||||
}
|
||||
}
|
||||
|
||||
(unmasked_units, unmasked_chars, masked_chars)
|
||||
}
|
||||
|
||||
fn count_unmasked_unit_chars(units: &[ComparisonUnit], mask: &[bool]) -> usize {
|
||||
units
|
||||
.iter()
|
||||
.zip(mask.iter())
|
||||
.filter_map(|(&token, &is_masked)| (!is_masked).then_some(token.len()))
|
||||
.filter_map(|(unit, &is_masked)| (!is_masked).then_some(unit.text.len()))
|
||||
.sum()
|
||||
}
|
||||
|
||||
|
|
@ -239,7 +294,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
|
|||
context_chars,
|
||||
kept_rate: 1.0,
|
||||
recall_rate: 1.0,
|
||||
#[cfg(test)]
|
||||
token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
|
||||
};
|
||||
}
|
||||
|
|
@ -258,7 +312,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
|
|||
context_chars: 0,
|
||||
kept_rate: 0.0,
|
||||
recall_rate: 0.0,
|
||||
#[cfg(test)]
|
||||
token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
|
||||
};
|
||||
}
|
||||
|
|
@ -267,29 +320,29 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
|
|||
let candidate_tokens = tokenize(candidate);
|
||||
let reference_tokens = tokenize(reference);
|
||||
|
||||
let (candidate_base_mask, base_candidate_mask) =
|
||||
lcs_keep_masks(&candidate_tokens, &base_tokens);
|
||||
let (candidate_reference_mask, reference_candidate_mask) =
|
||||
lcs_keep_masks(&candidate_tokens, &reference_tokens);
|
||||
let context_mask: Vec<bool> = candidate_base_mask
|
||||
let candidate_units = build_comparison_units(&candidate_tokens);
|
||||
let base_units = build_comparison_units(&base_tokens);
|
||||
let reference_units = build_comparison_units(&reference_tokens);
|
||||
|
||||
let candidate_unit_texts: Vec<&str> = candidate_units
|
||||
.iter()
|
||||
.zip(candidate_reference_mask.iter())
|
||||
.map(|(&in_base, &in_reference)| in_base && in_reference)
|
||||
.map(|unit| unit.text.as_str())
|
||||
.collect();
|
||||
let base_unit_texts: Vec<&str> = base_units.iter().map(|unit| unit.text.as_str()).collect();
|
||||
let reference_unit_texts: Vec<&str> = reference_units
|
||||
.iter()
|
||||
.map(|unit| unit.text.as_str())
|
||||
.collect();
|
||||
|
||||
let (candidate_base_mask, base_candidate_mask) =
|
||||
lcs_keep_masks(&candidate_unit_texts, &base_unit_texts);
|
||||
let (stripped_candidate, candidate_new_chars, context_chars) =
|
||||
analyze_masked_tokens(&candidate_tokens, &context_mask);
|
||||
analyze_masked_units(&candidate_units, &candidate_base_mask);
|
||||
|
||||
let (reference_base_mask, base_reference_mask) =
|
||||
lcs_keep_masks(&reference_tokens, &base_tokens);
|
||||
let reference_context_mask: Vec<bool> = reference_base_mask
|
||||
.iter()
|
||||
.zip(reference_candidate_mask.iter())
|
||||
.map(|(&in_base, &in_candidate)| in_base && in_candidate)
|
||||
.collect();
|
||||
|
||||
lcs_keep_masks(&reference_unit_texts, &base_unit_texts);
|
||||
let (stripped_reference, reference_new_chars, _) =
|
||||
analyze_masked_tokens(&reference_tokens, &reference_context_mask);
|
||||
analyze_masked_units(&reference_units, &reference_base_mask);
|
||||
|
||||
let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);
|
||||
|
||||
|
|
@ -299,13 +352,13 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
|
|||
.filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
|
||||
.sum();
|
||||
|
||||
let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
|
||||
let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
|
||||
let correctly_deleted_chars: usize = base_tokens
|
||||
let candidate_deleted_chars = count_unmasked_unit_chars(&base_units, &base_candidate_mask);
|
||||
let reference_deleted_chars = count_unmasked_unit_chars(&base_units, &base_reference_mask);
|
||||
let correctly_deleted_chars: usize = base_units
|
||||
.iter()
|
||||
.zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
|
||||
.filter_map(|(&token, (&in_candidate, &in_reference))| {
|
||||
(!in_candidate && !in_reference).then_some(token.len())
|
||||
.filter_map(|(unit, (&in_candidate, &in_reference))| {
|
||||
(!in_candidate && !in_reference).then_some(unit.text.len())
|
||||
})
|
||||
.sum();
|
||||
|
||||
|
|
@ -326,24 +379,28 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
|
|||
matched_edit_chars as f64 / reference_edit_chars as f64
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
let token_annotations = {
|
||||
let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
|
||||
let mut token_annotations = vec![TokenAnnotation::Context; candidate_tokens.len()];
|
||||
let mut new_index = 0;
|
||||
for (token_index, _token) in candidate_tokens.iter().enumerate() {
|
||||
if context_mask[token_index] {
|
||||
token_annotations.push(TokenAnnotation::Context);
|
||||
|
||||
for (unit_index, unit) in candidate_units.iter().enumerate() {
|
||||
let annotation = if candidate_base_mask[unit_index] {
|
||||
TokenAnnotation::Context
|
||||
} else {
|
||||
let annotation = if keep_mask[new_index] {
|
||||
TokenAnnotation::Kept
|
||||
} else {
|
||||
TokenAnnotation::Discarded
|
||||
};
|
||||
#[cfg(test)]
|
||||
token_annotations.push(annotation);
|
||||
new_index += 1;
|
||||
annotation
|
||||
};
|
||||
|
||||
for token_index in unit.token_start..unit.token_end {
|
||||
token_annotations[token_index] = annotation;
|
||||
}
|
||||
}
|
||||
|
||||
token_annotations
|
||||
};
|
||||
|
||||
|
|
@ -358,14 +415,30 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
|
|||
context_chars,
|
||||
kept_rate,
|
||||
recall_rate,
|
||||
#[cfg(test)]
|
||||
token_annotations,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn annotate_kept_rate_tokens(
|
||||
base: &str,
|
||||
candidate: &str,
|
||||
reference: &str,
|
||||
) -> Vec<AnnotatedToken> {
|
||||
let result = compute_kept_rate(base, candidate, reference);
|
||||
tokenize(candidate)
|
||||
.into_iter()
|
||||
.zip(result.token_annotations)
|
||||
.map(|(token, annotation)| AnnotatedToken {
|
||||
token: token.to_string(),
|
||||
annotation,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_kept_rate {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
|
||||
#[test]
|
||||
fn test_lcs_keep_masks() {
|
||||
|
|
@ -439,16 +512,24 @@ mod test_kept_rate {
|
|||
|
||||
#[test]
|
||||
fn test_missing_deletion() {
|
||||
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
|
||||
let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
|
||||
let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
|
||||
let base = indoc! {"
|
||||
fn example() {
|
||||
epr
|
||||
"};
|
||||
let candidate = indoc! {r#"
|
||||
fn example() {
|
||||
epr
|
||||
eprintln!("");
|
||||
"#};
|
||||
let reference = indoc! {r#"
|
||||
fn example() {
|
||||
eprintln!("");
|
||||
"#};
|
||||
|
||||
let result = compute_kept_rate(base, candidate, reference);
|
||||
assert!(
|
||||
result.kept_rate < 0.85,
|
||||
"expected kept_rate < 0.85, got {}",
|
||||
result.kept_rate
|
||||
);
|
||||
assert!(result.discarded_chars > 0);
|
||||
assert!((result.kept_rate - (14.0 / 15.0)).abs() < 1e-6);
|
||||
assert_eq!(result.kept_chars, 14);
|
||||
assert_eq!(result.discarded_chars, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -472,8 +553,17 @@ mod test_kept_rate {
|
|||
|
||||
#[test]
|
||||
fn test_bails_for_dirty_final() {
|
||||
let base = "fn example() {\n work();\n}\n";
|
||||
let candidate = "fn example() {\n work();\n predicted();\n}\n";
|
||||
let base = indoc! {"
|
||||
fn example() {
|
||||
work();
|
||||
}
|
||||
"};
|
||||
let candidate = indoc! {"
|
||||
fn example() {
|
||||
work();
|
||||
predicted();
|
||||
}
|
||||
"};
|
||||
let reference = format!(
|
||||
"fn example() {{\n work();\n {}\n}}\n",
|
||||
"settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
|
||||
|
|
@ -488,9 +578,19 @@ mod test_kept_rate {
|
|||
|
||||
#[test]
|
||||
fn test_eprintln_token_alignment() {
|
||||
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
|
||||
let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
|
||||
let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
|
||||
let base = indoc! {"
|
||||
fn example() {
|
||||
epr
|
||||
"};
|
||||
let candidate = indoc! {r#"
|
||||
fn example() {
|
||||
eprintln!("hello world!");
|
||||
"#};
|
||||
let reference = indoc! {r#"
|
||||
fn example() {
|
||||
eprintln!("");
|
||||
"#};
|
||||
|
||||
let result = compute_kept_rate(base, candidate, reference);
|
||||
assert!(result.discarded_chars > 0);
|
||||
assert!(result.kept_chars > 0);
|
||||
|
|
@ -499,6 +599,42 @@ mod test_kept_rate {
|
|||
assert_eq!(result.discarded_chars, 12);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kept_rate_treats_unchanged_stale_text_as_context() {
|
||||
let base = indoc! {"
|
||||
a=fomr
|
||||
b=old
|
||||
"};
|
||||
let candidate = indoc! {"
|
||||
a=formula;
|
||||
b=old
|
||||
"};
|
||||
let reference = indoc! {"
|
||||
a=formula;
|
||||
b=new
|
||||
"};
|
||||
|
||||
let result = compute_kept_rate(base, candidate, reference);
|
||||
let candidate_tokens = tokenize(candidate);
|
||||
|
||||
assert_eq!(result.candidate_new_chars, "formula".len() + ";".len());
|
||||
assert_eq!(result.kept_chars, "formula".len() + ";".len());
|
||||
assert_eq!(result.discarded_chars, 0);
|
||||
assert_eq!(result.candidate_deleted_chars, "fomr".len());
|
||||
assert_eq!(result.correctly_deleted_chars, "fomr".len());
|
||||
assert!((result.kept_rate - 1.0).abs() < 1e-6);
|
||||
assert!((result.recall_rate - (2.0 / 3.0)).abs() < 1e-6);
|
||||
|
||||
let old_index = candidate_tokens
|
||||
.iter()
|
||||
.position(|&token| token == "old")
|
||||
.expect("old token not found");
|
||||
assert_eq!(
|
||||
result.token_annotations[old_index],
|
||||
TokenAnnotation::Context
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_annotations_rename() {
|
||||
let base = " foo(old_name)\n";
|
||||
|
|
@ -514,7 +650,7 @@ mod test_kept_rate {
|
|||
assert_eq!(result.token_annotations.len(), tokenize(candidate).len());
|
||||
|
||||
for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
|
||||
if token == "new_name" {
|
||||
if matches!(token, "new" | "_" | "name") {
|
||||
assert_eq!(annotation, TokenAnnotation::Kept);
|
||||
} else {
|
||||
assert_eq!(annotation, TokenAnnotation::Context);
|
||||
|
|
@ -524,9 +660,18 @@ mod test_kept_rate {
|
|||
|
||||
#[test]
|
||||
fn test_annotations_eprintln_coloring() {
|
||||
let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
|
||||
let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
|
||||
let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
|
||||
let base = indoc! {"
|
||||
fn example() {
|
||||
epr
|
||||
"};
|
||||
let candidate = indoc! {r#"
|
||||
fn example() {
|
||||
eprintln!("hello world!");
|
||||
"#};
|
||||
let reference = indoc! {r#"
|
||||
fn example() {
|
||||
eprintln!("");
|
||||
"#};
|
||||
let result = compute_kept_rate(base, candidate, reference);
|
||||
let candidate_tokens = tokenize(candidate);
|
||||
|
||||
|
|
|
|||
710
crates/edit_prediction_metrics/src/main.rs
Normal file
710
crates/edit_prediction_metrics/src/main.rs
Normal file
|
|
@ -0,0 +1,710 @@
|
|||
use std::env;
|
||||
use std::fmt::Write as _;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process;
|
||||
|
||||
use edit_prediction_metrics::{
|
||||
ClassificationMetrics, DeltaChrFMetrics, KeptRateResult, TokenAnnotation,
|
||||
annotate_kept_rate_tokens, braces_disbalance, compute_kept_rate, count_patch_token_changes,
|
||||
delta_chr_f, exact_lines_match, extract_changed_lines_from_diff,
|
||||
has_isolated_whitespace_changes, is_editable_region_correct,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
fn main() {
|
||||
if let Err(error) = run() {
|
||||
eprintln!("error: {error}");
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn run() -> Result<(), String> {
|
||||
let args: Vec<String> = env::args().skip(1).collect();
|
||||
if args.is_empty() {
|
||||
print_usage();
|
||||
return Err("missing arguments".to_string());
|
||||
}
|
||||
|
||||
let input = CliInput::parse(&args)?;
|
||||
let report = match input {
|
||||
CliInput::Files {
|
||||
base_path,
|
||||
expected_patch_path,
|
||||
actual_patch_path,
|
||||
} => {
|
||||
let base = fs::read_to_string(&base_path)
|
||||
.map_err(|err| format!("failed to read {}: {err}", base_path.display()))?;
|
||||
let expected_patch = fs::read_to_string(&expected_patch_path).map_err(|err| {
|
||||
format!("failed to read {}: {err}", expected_patch_path.display())
|
||||
})?;
|
||||
let actual_patch = fs::read_to_string(&actual_patch_path)
|
||||
.map_err(|err| format!("failed to read {}: {err}", actual_patch_path.display()))?;
|
||||
|
||||
let expected = apply_patch_to_excerpt(&base, &expected_patch, 0)?;
|
||||
let actual = apply_patch_to_excerpt(&base, &actual_patch, 0)?;
|
||||
|
||||
EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
|
||||
}
|
||||
CliInput::Json {
|
||||
json_path,
|
||||
prediction_index,
|
||||
} => {
|
||||
let json = fs::read_to_string(&json_path)
|
||||
.map_err(|err| format!("failed to read {}: {err}", json_path.display()))?;
|
||||
let example: JsonExample = serde_json::from_str(&json)
|
||||
.map_err(|err| format!("failed to parse {}: {err}", json_path.display()))?;
|
||||
|
||||
let base = example.prompt_inputs.cursor_excerpt;
|
||||
let excerpt_start_row = example.prompt_inputs.excerpt_start_row;
|
||||
let expected_patch = example
|
||||
.expected_patches
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| "JSON input is missing expected_patches[0]".to_string())?;
|
||||
let actual_patch = example
|
||||
.predictions
|
||||
.into_iter()
|
||||
.nth(prediction_index)
|
||||
.ok_or_else(|| {
|
||||
format!("JSON input does not contain predictions[{prediction_index}]")
|
||||
})?
|
||||
.actual_patch;
|
||||
|
||||
let expected = apply_patch_to_excerpt(&base, &expected_patch, excerpt_start_row)?;
|
||||
let actual = apply_patch_to_excerpt(&base, &actual_patch, excerpt_start_row)?;
|
||||
|
||||
EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
|
||||
}
|
||||
};
|
||||
|
||||
print_report(&report);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prints the two supported invocation forms to stderr.
fn print_usage() {
    eprintln!(
        "Usage:\n edit_prediction_metrics --base <base.txt> --expected-patch <expected.diff> --actual-patch <actual.diff>\n edit_prediction_metrics --json <example.json> [--prediction-index <n>]"
    );
}
|
||||
|
||||
/// The two ways inputs can be supplied on the command line.
enum CliInput {
    /// Base text and both patches are read from separate files.
    Files {
        /// Value of `--base`.
        base_path: std::path::PathBuf,
        /// Value of `--expected-patch`.
        expected_patch_path: std::path::PathBuf,
        /// Value of `--actual-patch`.
        actual_patch_path: std::path::PathBuf,
    },
    /// Everything is read from a single JSON example file.
    Json {
        /// Value of `--json`.
        json_path: std::path::PathBuf,
        /// Value of `--prediction-index`; 0 when the flag is absent.
        prediction_index: usize,
    },
}
|
||||
|
||||
impl CliInput {
|
||||
fn parse(args: &[String]) -> Result<Self, String> {
|
||||
let mut base_path = None;
|
||||
let mut expected_patch_path = None;
|
||||
let mut actual_patch_path = None;
|
||||
let mut json_path = None;
|
||||
let mut prediction_index = 0usize;
|
||||
|
||||
let mut index = 0;
|
||||
while index < args.len() {
|
||||
match args[index].as_str() {
|
||||
"--base" => {
|
||||
index += 1;
|
||||
base_path = Some(path_arg(args, index, "--base")?);
|
||||
}
|
||||
"--expected-patch" => {
|
||||
index += 1;
|
||||
expected_patch_path = Some(path_arg(args, index, "--expected-patch")?);
|
||||
}
|
||||
"--actual-patch" => {
|
||||
index += 1;
|
||||
actual_patch_path = Some(path_arg(args, index, "--actual-patch")?);
|
||||
}
|
||||
"--json" => {
|
||||
index += 1;
|
||||
json_path = Some(path_arg(args, index, "--json")?);
|
||||
}
|
||||
"--prediction-index" => {
|
||||
index += 1;
|
||||
let raw = string_arg(args, index, "--prediction-index")?;
|
||||
prediction_index = raw.parse::<usize>().map_err(|err| {
|
||||
format!("invalid value for --prediction-index ({raw}): {err}")
|
||||
})?;
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
print_usage();
|
||||
process::exit(0);
|
||||
}
|
||||
unknown => {
|
||||
return Err(format!("unrecognized argument: {unknown}"));
|
||||
}
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
|
||||
if let Some(json_path) = json_path {
|
||||
if base_path.is_some() || expected_patch_path.is_some() || actual_patch_path.is_some() {
|
||||
return Err(
|
||||
"--json cannot be combined with --base/--expected-patch/--actual-patch"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
return Ok(CliInput::Json {
|
||||
json_path,
|
||||
prediction_index,
|
||||
});
|
||||
}
|
||||
|
||||
match (base_path, expected_patch_path, actual_patch_path) {
|
||||
(Some(base_path), Some(expected_patch_path), Some(actual_patch_path)) => {
|
||||
Ok(CliInput::Files {
|
||||
base_path,
|
||||
expected_patch_path,
|
||||
actual_patch_path,
|
||||
})
|
||||
}
|
||||
_ => Err(
|
||||
"expected either --json <file> or all of --base, --expected-patch, and --actual-patch"
|
||||
.to_string(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn path_arg(args: &[String], index: usize, flag: &str) -> Result<std::path::PathBuf, String> {
|
||||
Ok(Path::new(string_arg(args, index, flag)?).to_path_buf())
|
||||
}
|
||||
|
||||
/// Returns the argument at `index` as a string slice.
///
/// # Errors
///
/// Returns `missing value for {flag}` when `index` is past the end of
/// `args`.
fn string_arg<'a>(args: &'a [String], index: usize, flag: &str) -> Result<&'a str, String> {
    args.get(index)
        .map(String::as_str)
        .ok_or_else(|| format!("missing value for {flag}"))
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EvaluationReport {
|
||||
base: String,
|
||||
expected: String,
|
||||
actual: String,
|
||||
kept_rate: KeptRateResult,
|
||||
exact_lines: ClassificationMetrics,
|
||||
delta_chr_f: DeltaChrFMetrics,
|
||||
expected_changed_lines: usize,
|
||||
actual_changed_lines: usize,
|
||||
token_changes: edit_prediction_metrics::TokenChangeCounts,
|
||||
isolated_whitespace_changes: bool,
|
||||
editable_region_correct: bool,
|
||||
expected_braces_disbalance: usize,
|
||||
actual_braces_disbalance: usize,
|
||||
}
|
||||
|
||||
impl EvaluationReport {
|
||||
fn new(
|
||||
base: String,
|
||||
expected_patch: String,
|
||||
actual_patch: String,
|
||||
expected: String,
|
||||
actual: String,
|
||||
) -> Self {
|
||||
let kept_rate = compute_kept_rate(&base, &actual, &expected);
|
||||
let exact_lines = exact_lines_match(&expected_patch, &actual_patch);
|
||||
let delta_chr_f = delta_chr_f(&base, &expected, &actual);
|
||||
let expected_changed_lines = extract_changed_lines_from_diff(&expected_patch)
|
||||
.values()
|
||||
.sum();
|
||||
let actual_changed_lines = extract_changed_lines_from_diff(&actual_patch)
|
||||
.values()
|
||||
.sum();
|
||||
let token_changes = count_patch_token_changes(&actual_patch);
|
||||
let isolated_whitespace_changes = has_isolated_whitespace_changes(&actual_patch, None);
|
||||
let editable_region_correct = is_editable_region_correct(&actual_patch);
|
||||
let expected_braces_disbalance = braces_disbalance(&expected);
|
||||
let actual_braces_disbalance = braces_disbalance(&actual);
|
||||
|
||||
Self {
|
||||
base,
|
||||
expected,
|
||||
actual,
|
||||
kept_rate,
|
||||
exact_lines,
|
||||
delta_chr_f,
|
||||
expected_changed_lines,
|
||||
actual_changed_lines,
|
||||
token_changes,
|
||||
isolated_whitespace_changes,
|
||||
editable_region_correct,
|
||||
expected_braces_disbalance,
|
||||
actual_braces_disbalance,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn print_report(report: &EvaluationReport) {
|
||||
println!("Metrics");
|
||||
println!("=======");
|
||||
println!("kept_rate: {:.6}", report.kept_rate.kept_rate);
|
||||
println!("kept_rate_recall: {:.6}", report.kept_rate.recall_rate);
|
||||
println!("delta_chr_f: {:.6}", report.delta_chr_f.score);
|
||||
println!("delta_chr_f_precision: {:.6}", report.delta_chr_f.precision);
|
||||
println!("delta_chr_f_recall: {:.6}", report.delta_chr_f.recall);
|
||||
println!("delta_chr_f_beta: {:.6}", report.delta_chr_f.beta);
|
||||
println!();
|
||||
|
||||
println!("Exact line match");
|
||||
println!("----------------");
|
||||
println!("true_positives: {}", report.exact_lines.true_positives);
|
||||
println!("false_positives: {}", report.exact_lines.false_positives);
|
||||
println!("false_negatives: {}", report.exact_lines.false_negatives);
|
||||
println!("precision: {:.6}", report.exact_lines.precision());
|
||||
println!("recall: {:.6}", report.exact_lines.recall());
|
||||
println!("f1: {:.6}", report.exact_lines.f1());
|
||||
println!("expected_changed_lines: {}", report.expected_changed_lines);
|
||||
println!("actual_changed_lines: {}", report.actual_changed_lines);
|
||||
println!();
|
||||
|
||||
println!("Patch structure");
|
||||
println!("---------------");
|
||||
println!("inserted_tokens: {}", report.token_changes.inserted_tokens);
|
||||
println!("deleted_tokens: {}", report.token_changes.deleted_tokens);
|
||||
println!(
|
||||
"isolated_whitespace_changes: {}",
|
||||
report.isolated_whitespace_changes
|
||||
);
|
||||
println!(
|
||||
"editable_region_correct: {}",
|
||||
report.editable_region_correct
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("Final text checks");
|
||||
println!("-----------------");
|
||||
println!(
|
||||
"expected_braces_disbalance: {}",
|
||||
report.expected_braces_disbalance
|
||||
);
|
||||
println!(
|
||||
"actual_braces_disbalance: {}",
|
||||
report.actual_braces_disbalance
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("Kept-rate breakdown");
|
||||
println!("-------------------");
|
||||
println!(
|
||||
"candidate_new_chars: {}",
|
||||
report.kept_rate.candidate_new_chars
|
||||
);
|
||||
println!(
|
||||
"reference_new_chars: {}",
|
||||
report.kept_rate.reference_new_chars
|
||||
);
|
||||
println!(
|
||||
"candidate_deleted_chars: {}",
|
||||
report.kept_rate.candidate_deleted_chars
|
||||
);
|
||||
println!(
|
||||
"reference_deleted_chars: {}",
|
||||
report.kept_rate.reference_deleted_chars
|
||||
);
|
||||
println!("kept_chars: {}", report.kept_rate.kept_chars);
|
||||
println!(
|
||||
"correctly_deleted_chars: {}",
|
||||
report.kept_rate.correctly_deleted_chars
|
||||
);
|
||||
println!("discarded_chars: {}", report.kept_rate.discarded_chars);
|
||||
println!("context_chars: {}", report.kept_rate.context_chars);
|
||||
println!();
|
||||
|
||||
print_kept_rate_explanation(&report.base, &report.actual, &report.expected);
|
||||
}
|
||||
|
||||
fn print_kept_rate_explanation(base: &str, actual: &str, expected: &str) {
|
||||
println!("Kept-rate explanation");
|
||||
println!("---------------------");
|
||||
println!("Legend: context = default, kept = green background, discarded = red background");
|
||||
println!();
|
||||
|
||||
let annotated = annotate_kept_rate_tokens(base, actual, expected);
|
||||
println!("Actual final text with token annotations:");
|
||||
println!("{}", render_annotated_tokens(&annotated));
|
||||
println!();
|
||||
}
|
||||
|
||||
fn render_annotated_tokens(tokens: &[edit_prediction_metrics::AnnotatedToken]) -> String {
|
||||
const RESET: &str = "\x1b[0m";
|
||||
const KEPT_STYLE: &str = "\x1b[30;42m";
|
||||
const DISCARDED_STYLE: &str = "\x1b[30;41m";
|
||||
|
||||
let mut rendered = String::new();
|
||||
for token in tokens {
|
||||
let style = match token.annotation {
|
||||
TokenAnnotation::Context => "",
|
||||
TokenAnnotation::Kept => KEPT_STYLE,
|
||||
TokenAnnotation::Discarded => DISCARDED_STYLE,
|
||||
};
|
||||
|
||||
if style.is_empty() {
|
||||
rendered.push_str(&visualize_whitespace(&token.token));
|
||||
} else {
|
||||
rendered.push_str(style);
|
||||
rendered.push_str(&visualize_whitespace(&token.token));
|
||||
rendered.push_str(RESET);
|
||||
}
|
||||
}
|
||||
rendered
|
||||
}
|
||||
|
||||
/// Returns `token` with whitespace made visible: a space becomes `·`, a tab
/// becomes `⇥`, and a newline becomes `↵` followed by the newline itself.
/// All other characters are copied unchanged.
fn visualize_whitespace(token: &str) -> String {
    // The output is at least as long as the input, so reserve that up front.
    let mut rendered = String::with_capacity(token.len());
    for ch in token.chars() {
        match ch {
            ' ' => rendered.push('·'),
            '\t' => rendered.push('⇥'),
            '\n' => rendered.push_str("↵\n"),
            _ => rendered.push(ch),
        }
    }
    rendered
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct JsonExample {
|
||||
prompt_inputs: PromptInputs,
|
||||
expected_patches: Vec<String>,
|
||||
predictions: Vec<Prediction>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PromptInputs {
|
||||
cursor_excerpt: String,
|
||||
excerpt_start_row: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Prediction {
|
||||
actual_patch: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ParsedHunk {
|
||||
old_start: u32,
|
||||
lines: Vec<HunkLine>,
|
||||
}
|
||||
|
||||
/// A single body line of a diff hunk, stored without its leading marker.
#[derive(Debug, Clone)]
enum HunkLine {
    /// Line present in both the old and the new text (` ` marker).
    Context(String),
    /// Line present only in the new text (`+` marker).
    Addition(String),
    /// Line present only in the old text (`-` marker).
    Deletion(String),
}
|
||||
|
||||
fn apply_patch_to_excerpt(
|
||||
base: &str,
|
||||
patch: &str,
|
||||
excerpt_start_row: u32,
|
||||
) -> Result<String, String> {
|
||||
let hunks = parse_diff_hunks(patch);
|
||||
|
||||
let result = try_apply_hunks(base, &hunks, excerpt_start_row);
|
||||
|
||||
// Predicted patches may use excerpt-relative line numbers instead of
|
||||
// file-global ones. When all hunks fall outside the excerpt window the
|
||||
// result is identical to the base text. Retry with a zero offset so the
|
||||
// line numbers are interpreted relative to the excerpt.
|
||||
if excerpt_start_row > 0 && !hunks.is_empty() {
|
||||
let should_retry = match &result {
|
||||
Ok(text) => text == base,
|
||||
Err(_) => true,
|
||||
};
|
||||
|
||||
if should_retry {
|
||||
let fallback = try_apply_hunks(base, &hunks, 0);
|
||||
if matches!(&fallback, Ok(text) if text != base) {
|
||||
return fallback;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn try_apply_hunks(
|
||||
base: &str,
|
||||
hunks: &[ParsedHunk],
|
||||
excerpt_start_row: u32,
|
||||
) -> Result<String, String> {
|
||||
let base_has_trailing_newline = base.ends_with('\n');
|
||||
let mut lines = split_preserving_final_empty_line(base);
|
||||
let original_line_count = lines.len() as u32;
|
||||
|
||||
let excerpt_end_row = excerpt_start_row + original_line_count;
|
||||
let mut line_delta: i64 = 0;
|
||||
|
||||
for hunk in hunks {
|
||||
let filtered = match filter_hunk_to_excerpt(hunk, excerpt_start_row, excerpt_end_row) {
|
||||
Some(filtered) => filtered,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let local_start = filtered.old_start.saturating_sub(excerpt_start_row) as i64 + line_delta;
|
||||
if local_start < 0 {
|
||||
return Err(format!(
|
||||
"patch application moved before excerpt start at source row {}",
|
||||
filtered.old_start
|
||||
));
|
||||
}
|
||||
let local_start = local_start as usize;
|
||||
|
||||
if local_start > lines.len() {
|
||||
return Err(format!(
|
||||
"patch application starts past excerpt end at local line {}",
|
||||
local_start + 1
|
||||
));
|
||||
}
|
||||
|
||||
let old_len = filtered
|
||||
.lines
|
||||
.iter()
|
||||
.filter(|line| !matches!(line, HunkLine::Addition(_)))
|
||||
.count();
|
||||
let new_len = filtered
|
||||
.lines
|
||||
.iter()
|
||||
.filter(|line| !matches!(line, HunkLine::Deletion(_)))
|
||||
.count();
|
||||
|
||||
let old_segment: Vec<&str> = filtered
|
||||
.lines
|
||||
.iter()
|
||||
.filter_map(|line| match line {
|
||||
HunkLine::Context(text) | HunkLine::Deletion(text) => Some(text.as_str()),
|
||||
HunkLine::Addition(_) => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let new_segment: Vec<String> = filtered
|
||||
.lines
|
||||
.iter()
|
||||
.filter_map(|line| match line {
|
||||
HunkLine::Context(text) | HunkLine::Addition(text) => Some(text.clone()),
|
||||
HunkLine::Deletion(_) => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if local_start + old_len > lines.len() {
|
||||
return Err(format!(
|
||||
"patch application exceeds excerpt bounds near source row {}",
|
||||
filtered.old_start
|
||||
));
|
||||
}
|
||||
|
||||
let current_segment: Vec<&str> = lines[local_start..local_start + old_len]
|
||||
.iter()
|
||||
.map(String::as_str)
|
||||
.collect();
|
||||
|
||||
if current_segment != old_segment {
|
||||
let mut details = String::new();
|
||||
let _ = write!(
|
||||
details,
|
||||
"patch context mismatch near source row {}: expected {:?}, found {:?}",
|
||||
filtered.old_start, old_segment, current_segment
|
||||
);
|
||||
return Err(details);
|
||||
}
|
||||
|
||||
lines.splice(local_start..local_start + old_len, new_segment);
|
||||
line_delta += new_len as i64 - old_len as i64;
|
||||
}
|
||||
|
||||
Ok(join_lines(&lines, base_has_trailing_newline))
|
||||
}
|
||||
|
||||
/// Splits `text` into lines, representing a trailing newline as a final empty
/// line.
///
/// The returned lines never contain their terminators (`str::lines` strips
/// both `\n` and `\r\n`). When `text` ends with `'\n'` one extra empty string
/// is appended, so joining the result with `"\n"` reproduces a
/// `\n`-terminated text exactly.
///
/// The sentinel is appended unconditionally. `str::lines` already yields an
/// empty last element for text that ends in a blank line (`"a\n\n"`), so
/// skipping the sentinel in that case would drop one newline on the way back
/// through a join.
fn split_preserving_final_empty_line(text: &str) -> Vec<String> {
    let mut lines: Vec<String> = text.lines().map(ToString::to_string).collect();
    if text.ends_with('\n') {
        lines.push(String::new());
    }
    lines
}
|
||||
|
||||
/// Joins `lines` with `'\n'` and makes the result's trailing newline agree
/// with `had_trailing_newline`.
///
/// An empty slice always yields an empty string. Otherwise a missing final
/// newline is appended when one is expected, and a single final newline is
/// removed when none is expected.
fn join_lines(lines: &[String], had_trailing_newline: bool) -> String {
    if lines.is_empty() {
        return String::new();
    }

    let mut text = lines.join("\n");
    match (had_trailing_newline, text.ends_with('\n')) {
        (true, false) => text.push('\n'),
        (false, true) => {
            text.pop();
        }
        _ => {}
    }
    text
}
|
||||
|
||||
fn filter_hunk_to_excerpt(
|
||||
hunk: &ParsedHunk,
|
||||
excerpt_start_row: u32,
|
||||
excerpt_end_row: u32,
|
||||
) -> Option<ParsedHunk> {
|
||||
let mut filtered_lines = Vec::new();
|
||||
let mut current_old_row = hunk.old_start.saturating_sub(1);
|
||||
let mut filtered_old_start = None;
|
||||
let mut has_overlap = false;
|
||||
|
||||
for line in &hunk.lines {
|
||||
match line {
|
||||
HunkLine::Context(text) => {
|
||||
let in_excerpt =
|
||||
current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
|
||||
if in_excerpt {
|
||||
filtered_old_start.get_or_insert(current_old_row);
|
||||
filtered_lines.push(HunkLine::Context(text.clone()));
|
||||
has_overlap = true;
|
||||
}
|
||||
current_old_row += 1;
|
||||
}
|
||||
HunkLine::Deletion(text) => {
|
||||
let in_excerpt =
|
||||
current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
|
||||
if in_excerpt {
|
||||
filtered_old_start.get_or_insert(current_old_row);
|
||||
filtered_lines.push(HunkLine::Deletion(text.clone()));
|
||||
has_overlap = true;
|
||||
}
|
||||
current_old_row += 1;
|
||||
}
|
||||
HunkLine::Addition(text) => {
|
||||
let insertion_in_excerpt =
|
||||
current_old_row >= excerpt_start_row && current_old_row <= excerpt_end_row;
|
||||
if insertion_in_excerpt {
|
||||
filtered_old_start.get_or_insert(current_old_row);
|
||||
filtered_lines.push(HunkLine::Addition(text.clone()));
|
||||
has_overlap = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has_overlap {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(ParsedHunk {
|
||||
old_start: filtered_old_start.unwrap_or(excerpt_start_row),
|
||||
lines: filtered_lines,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
|
||||
let mut hunks = Vec::new();
|
||||
let mut current_hunk: Option<ParsedHunk> = None;
|
||||
|
||||
for line in diff.lines() {
|
||||
if let Some((old_start, old_count, _new_start, _new_count)) = parse_hunk_header(line) {
|
||||
if let Some(hunk) = current_hunk.take() {
|
||||
hunks.push(hunk);
|
||||
}
|
||||
let _ = old_count;
|
||||
current_hunk = Some(ParsedHunk {
|
||||
old_start,
|
||||
lines: Vec::new(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(hunk) = current_hunk.as_mut() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Some(text) = line.strip_prefix('+') {
|
||||
if !line.starts_with("+++") {
|
||||
hunk.lines.push(HunkLine::Addition(text.to_string()));
|
||||
}
|
||||
} else if let Some(text) = line.strip_prefix('-') {
|
||||
if !line.starts_with("---") {
|
||||
hunk.lines.push(HunkLine::Deletion(text.to_string()));
|
||||
}
|
||||
} else if let Some(text) = line.strip_prefix(' ') {
|
||||
hunk.lines.push(HunkLine::Context(text.to_string()));
|
||||
} else if line.is_empty() {
|
||||
hunk.lines.push(HunkLine::Context(String::new()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(hunk) = current_hunk {
|
||||
hunks.push(hunk);
|
||||
}
|
||||
|
||||
hunks
|
||||
}
|
||||
|
||||
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
|
||||
let line = line.strip_prefix("@@ -")?;
|
||||
let (old_part, rest) = line.split_once(' ')?;
|
||||
let rest = rest.strip_prefix('+')?;
|
||||
let (new_part, _) = rest.split_once(" @@")?;
|
||||
|
||||
let (old_start, old_count) = parse_hunk_range(old_part)?;
|
||||
let (new_start, new_count) = parse_hunk_range(new_part)?;
|
||||
Some((old_start, old_count, new_start, new_count))
|
||||
}
|
||||
|
||||
/// Parses one side of a hunk header: `start,count` or just `start`.
///
/// Returns `(start, count)`; the count is 1 when it is omitted. Returns `None`
/// if either number is not a valid `u32`.
fn parse_hunk_range(part: &str) -> Option<(u32, u32)> {
    let (start_text, count_text) = match part.split_once(',') {
        Some((start, count)) => (start, Some(count)),
        None => (part, None),
    };

    let start = start_text.parse().ok()?;
    let count = match count_text {
        Some(count) => count.parse().ok()?,
        None => 1,
    };
    Some((start, count))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// With a zero start row the hunk line numbers address the text directly.
    #[test]
    fn applies_patch_in_file_mode() {
        let base = "fn main() {\n println!(\"hello\");\n}\n";
        let patch = "@@ -1,3 +1,3 @@\n fn main() {\n- println!(\"hello\");\n+ println!(\"world\");\n }\n";

        let patched = apply_patch_to_excerpt(base, patch, 0).unwrap();

        assert_eq!(patched, "fn main() {\n println!(\"world\");\n}\n");
    }

    /// The excerpt starts at row 1, so file line 2 is the excerpt's first line.
    #[test]
    fn applies_patch_in_json_excerpt_mode() {
        let base = "b\nc\nd\n";
        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";

        let patched = apply_patch_to_excerpt(base, patch, 1).unwrap();

        assert_eq!(patched, "x\ny\nd\n");
    }

    #[test]
    fn applies_patch_with_excerpt_relative_line_numbers() {
        let base = "a\nb\nc\nd\n";
        // Patch uses excerpt-relative line numbers (line 2 of excerpt)
        // even though the excerpt starts at file row 100.
        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";

        let patched = apply_patch_to_excerpt(base, patch, 100).unwrap();

        assert_eq!(patched, "a\nx\ny\nd\n");
    }

    #[test]
    fn prefers_file_global_line_numbers_over_excerpt_relative() {
        let base = "a\nb\nc\n";
        // Patch uses file-global line numbers: excerpt starts at row 5,
        // hunk targets line 6 (1-based) = row 5 (0-based) = first line.
        let patch = "@@ -6,2 +6,2 @@\n-a\n-b\n+x\n+y\n";

        let patched = apply_patch_to_excerpt(base, patch, 5).unwrap();

        assert_eq!(patched, "x\ny\nc\n");
    }
}
|
||||
|
|
@ -687,6 +687,35 @@ fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
/// Reconstruct old and new text from a unified diff.
|
||||
///
|
||||
/// Context and deletion lines form the old text; context and addition
|
||||
/// lines form the new text. Returns `(old_text, new_text)`.
|
||||
pub fn reconstruct_texts_from_diff(patch_str: &str) -> (String, String) {
|
||||
let patch = Patch::parse_unified_diff(patch_str);
|
||||
let mut old_lines: Vec<&str> = Vec::new();
|
||||
let mut new_lines: Vec<&str> = Vec::new();
|
||||
|
||||
for hunk in &patch.hunks {
|
||||
for line in &hunk.lines {
|
||||
match line {
|
||||
PatchLine::Context(content) => {
|
||||
old_lines.push(content);
|
||||
new_lines.push(content);
|
||||
}
|
||||
PatchLine::Deletion(content) => {
|
||||
old_lines.push(content);
|
||||
}
|
||||
PatchLine::Addition(content) => {
|
||||
new_lines.push(content);
|
||||
}
|
||||
PatchLine::Garbage(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(old_lines.join("\n"), new_lines.join("\n"))
|
||||
}
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct Patch {
|
||||
hunks: Vec<Hunk>,
|
||||
|
|
|
|||
|
|
@ -1,33 +1,158 @@
|
|||
fn char_class(character: char) -> u8 {
|
||||
if character.is_alphanumeric() || character == '_' {
|
||||
0
|
||||
use std::{iter::Peekable, str::CharIndices};
|
||||
|
||||
/// Coarse character category used to split text into tokens.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CharClass {
    /// Alphanumeric characters and `_`.
    Identifier,
    /// `\n` and `\r`.
    Newline,
    /// Whitespace other than newlines.
    Whitespace,
    /// Everything else.
    Punctuation,
}
|
||||
|
||||
/// Punctuation sequences that are emitted as a single token.
///
/// The list is searched front to back and the first entry that prefixes the
/// remaining text wins, so an operator must appear before any shorter
/// operator that is a prefix of it (`>>>=` before `>>>`, `..=` before `..`).
const MULTI_CHAR_PUNCTUATION: &[&str] = &[
    // Four characters.
    ">>>=",
    // Three characters.
    "<<=", ">>=", "...", "..=", "??=", "**=", ">>>",
    // Two characters.
    "::", "->", "=>", "==", "!=", "<=", ">=", "&&", "||", "<<", ">>", "..",
    "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", "++", "--", "**", "??",
    "?.", ":=", "<-", "//", "/*", "*/",
];
|
||||
|
||||
fn char_class(character: char) -> CharClass {
|
||||
if character == '\n' || character == '\r' {
|
||||
CharClass::Newline
|
||||
} else if character.is_whitespace() {
|
||||
1
|
||||
CharClass::Whitespace
|
||||
} else if character.is_alphanumeric() || character == '_' {
|
||||
CharClass::Identifier
|
||||
} else {
|
||||
2
|
||||
CharClass::Punctuation
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether a new identifier word starts at `current`.
///
/// A boundary exists only at an uppercase `current`, in two situations:
/// - it follows a lowercase letter or a digit (`camel|Case`, `Version2|Update`);
/// - it follows an uppercase letter and precedes a lowercase one, i.e. it is
///   the last capital of an acronym run (`HTTP|Server`).
fn is_identifier_boundary(previous: char, current: char, next: Option<char>) -> bool {
    if !current.is_uppercase() {
        return false;
    }
    if previous.is_lowercase() || previous.is_numeric() {
        return true;
    }
    previous.is_uppercase() && next.is_some_and(|following| following.is_lowercase())
}
|
||||
|
||||
fn push_identifier_tokens<'a>(identifier: &'a str, tokens: &mut Vec<&'a str>) {
|
||||
let characters: Vec<(usize, char)> = identifier.char_indices().collect();
|
||||
let mut segment_start = 0;
|
||||
let mut index = 0;
|
||||
|
||||
while index < characters.len() {
|
||||
let (byte_index, character) = characters[index];
|
||||
|
||||
if character == '_' {
|
||||
if segment_start < byte_index {
|
||||
tokens.push(&identifier[segment_start..byte_index]);
|
||||
}
|
||||
|
||||
let mut underscore_end = byte_index + character.len_utf8();
|
||||
index += 1;
|
||||
|
||||
while index < characters.len() && characters[index].1 == '_' {
|
||||
underscore_end = characters[index].0 + characters[index].1.len_utf8();
|
||||
index += 1;
|
||||
}
|
||||
|
||||
tokens.push(&identifier[byte_index..underscore_end]);
|
||||
segment_start = underscore_end;
|
||||
continue;
|
||||
}
|
||||
|
||||
if byte_index > segment_start {
|
||||
let previous = characters[index - 1].1;
|
||||
let next = characters.get(index + 1).map(|(_, character)| *character);
|
||||
|
||||
if is_identifier_boundary(previous, character, next) {
|
||||
tokens.push(&identifier[segment_start..byte_index]);
|
||||
segment_start = byte_index;
|
||||
}
|
||||
}
|
||||
|
||||
index += 1;
|
||||
}
|
||||
|
||||
if segment_start < identifier.len() {
|
||||
tokens.push(&identifier[segment_start..]);
|
||||
}
|
||||
}
|
||||
|
||||
fn push_punctuation_token<'a>(
|
||||
text: &'a str,
|
||||
start: usize,
|
||||
character: char,
|
||||
characters: &mut Peekable<CharIndices<'a>>,
|
||||
tokens: &mut Vec<&'a str>,
|
||||
) {
|
||||
let remaining = &text[start..];
|
||||
|
||||
for punctuation in MULTI_CHAR_PUNCTUATION {
|
||||
if remaining.starts_with(punctuation) {
|
||||
for _ in punctuation.chars().skip(1) {
|
||||
characters.next();
|
||||
}
|
||||
|
||||
tokens.push(&remaining[..punctuation.len()]);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let end = start + character.len_utf8();
|
||||
tokens.push(&text[start..end]);
|
||||
}
|
||||
|
||||
pub(crate) fn tokenize(text: &str) -> Vec<&str> {
|
||||
let mut tokens = Vec::new();
|
||||
let mut characters = text.char_indices().peekable();
|
||||
|
||||
while let Some((start, character)) = characters.next() {
|
||||
let class = char_class(character);
|
||||
if class == 2 {
|
||||
tokens.push(&text[start..start + character.len_utf8()]);
|
||||
continue;
|
||||
}
|
||||
match char_class(character) {
|
||||
CharClass::Identifier => {
|
||||
let mut end = start + character.len_utf8();
|
||||
|
||||
let mut end = start + character.len_utf8();
|
||||
while let Some(&(_, next_character)) = characters.peek() {
|
||||
if char_class(next_character) != class {
|
||||
break;
|
||||
while let Some(&(next_start, next_character)) = characters.peek() {
|
||||
if char_class(next_character) != CharClass::Identifier {
|
||||
break;
|
||||
}
|
||||
|
||||
end = next_start + next_character.len_utf8();
|
||||
characters.next();
|
||||
}
|
||||
|
||||
push_identifier_tokens(&text[start..end], &mut tokens);
|
||||
}
|
||||
CharClass::Newline => {
|
||||
let mut end = start + character.len_utf8();
|
||||
|
||||
while let Some(&(next_start, next_character)) = characters.peek() {
|
||||
if char_class(next_character) != CharClass::Newline {
|
||||
break;
|
||||
}
|
||||
|
||||
end = next_start + next_character.len_utf8();
|
||||
characters.next();
|
||||
}
|
||||
|
||||
tokens.push(&text[start..end]);
|
||||
}
|
||||
CharClass::Whitespace => {
|
||||
let mut end = start + character.len_utf8();
|
||||
|
||||
while let Some(&(next_start, next_character)) = characters.peek() {
|
||||
if char_class(next_character) != CharClass::Whitespace {
|
||||
break;
|
||||
}
|
||||
|
||||
end = next_start + next_character.len_utf8();
|
||||
characters.next();
|
||||
}
|
||||
|
||||
tokens.push(&text[start..end]);
|
||||
}
|
||||
CharClass::Punctuation => {
|
||||
push_punctuation_token(text, start, character, &mut characters, &mut tokens);
|
||||
}
|
||||
end += next_character.len_utf8();
|
||||
characters.next();
|
||||
}
|
||||
tokens.push(&text[start..end]);
|
||||
}
|
||||
|
||||
tokens
|
||||
|
|
@ -38,17 +163,58 @@ mod tests {
|
|||
use super::tokenize;
|
||||
|
||||
#[test]
|
||||
fn tokenizes_code_like_text() {
|
||||
fn tokenizes_code() {
|
||||
assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
|
||||
assert_eq!(
|
||||
tokenize("foo_bar123 + baz"),
|
||||
vec!["foo_bar123", " ", "+", " ", "baz"]
|
||||
vec!["foo", "_", "bar123", " ", "+", " ", "baz"]
|
||||
);
|
||||
assert_eq!(
|
||||
tokenize("print(\"hello\")"),
|
||||
vec!["print", "(", "\"", "hello", "\"", ")"]
|
||||
);
|
||||
assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
|
||||
assert_eq!(tokenize("hello_world"), vec!["hello", "_", "world"]);
|
||||
assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokenizes_identifier_case_styles() {
|
||||
assert_eq!(
|
||||
tokenize("camelCase PascalCase snake_case"),
|
||||
vec![
|
||||
"camel", "Case", " ", "Pascal", "Case", " ", "snake", "_", "case"
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
tokenize("myHTTPServer __private_value foo__bar"),
|
||||
vec![
|
||||
"my", "HTTP", "Server", " ", "__", "private", "_", "value", " ", "foo", "__", "bar"
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
tokenize("XMLHttpRequest Version2Update"),
|
||||
vec!["XML", "Http", "Request", " ", "Version2", "Update"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokenizes_grouped_punctuation() {
|
||||
assert_eq!(
|
||||
tokenize("a::b -> c != d ..= e"),
|
||||
vec![
|
||||
"a", "::", "b", " ", "->", " ", "c", " ", "!=", " ", "d", " ", "..=", " ", "e"
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
tokenize("foo?.bar ?? baz"),
|
||||
vec!["foo", "?.", "bar", " ", "??", " ", "baz"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokenize_whitespace_runs() {
|
||||
assert_eq!(tokenize(" "), vec![" "]);
|
||||
assert_eq!(tokenize(" \n foo"), vec![" ", "\n", " ", "foo"]);
|
||||
assert_eq!(tokenize("\r\n\nfoo"), vec!["\r\n\n", "foo"]);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ extend-exclude = [
|
|||
"crates/gpui_macos/src/dispatcher.rs",
|
||||
# Tests contain partially incomplete words (by design)
|
||||
"crates/edit_prediction_cli/src/split_commit.rs",
|
||||
"crates/edit_prediction_metrics/src/kept_rate.rs",
|
||||
# Eval examples contain intentionally partial words (e.g. "secur" for "secure")
|
||||
"crates/edit_prediction_cli/evals/",
|
||||
# Tests contain `baˇr` that cause `"ba" should be "by" or "be".`-like false-positives
|
||||
|
|
|
|||
Loading…
Reference in a new issue