From 373062190669c1b7f72262c0435f7df0b0d553e3 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Mon, 4 May 2026 12:53:36 +0300 Subject: [PATCH] ep: Move scores aggregation to edit_prediction_metrics (#55609) This way, it can be shared with Python bindings. Release Notes: - N/A --- crates/edit_prediction_cli/src/example.rs | 70 +-- crates/edit_prediction_cli/src/main.rs | 1 - crates/edit_prediction_cli/src/metrics.rs | 10 +- .../src/reversal_tracking.rs | 17 - crates/edit_prediction_cli/src/score.rs | 521 ++---------------- .../src/edit_prediction_metrics.rs | 7 + .../src/prediction_score.rs | 319 +++++++++++ crates/edit_prediction_metrics/src/summary.rs | 293 ++++++++++ 8 files changed, 672 insertions(+), 566 deletions(-) delete mode 100644 crates/edit_prediction_cli/src/reversal_tracking.rs create mode 100644 crates/edit_prediction_metrics/src/prediction_score.rs create mode 100644 crates/edit_prediction_metrics/src/summary.rs diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 516f77ce2cb..0b5a75260fc 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -1,5 +1,4 @@ use crate::PredictionProvider; -use crate::metrics::ClassificationMetrics; use crate::paths::WORKTREES_DIR; use crate::qa::QaResult; use anyhow::{Context as _, Result}; @@ -149,74 +148,7 @@ where Ok(opt.unwrap_or_default()) } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ExampleScore { - pub delta_chr_f: f32, - #[serde(default)] - pub delta_chr_f_true_positives: usize, - #[serde(default)] - pub delta_chr_f_false_positives: usize, - #[serde(default)] - pub delta_chr_f_false_negatives: usize, - #[serde(default)] - pub delta_chr_f_precision: f64, - #[serde(default)] - pub delta_chr_f_recall: f64, - #[serde(default)] - pub delta_chr_f_beta: f64, - pub braces_disbalance: usize, - #[serde(default)] - pub exact_lines_tp: usize, - #[serde(default)] - pub exact_lines_fp: usize, - 
#[serde(default)] - pub exact_lines_fn: usize, - #[serde(default)] - pub reversal_ratio: f32, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub cursor_distance: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub cursor_exact_match: Option, - pub wrong_editable_region: Option, - #[serde(default)] - pub has_isolated_whitespace_changes: bool, - #[serde(default)] - pub inserted_tokens: usize, - #[serde(default)] - pub deleted_tokens: usize, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub kept_rate: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub recall_rate: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub kept_chars: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub correctly_deleted_chars: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub discarded_chars: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub cumulative_logprob: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub avg_logprob: Option, -} - -impl ExampleScore { - pub fn delta_chr_f_counts(&self) -> ClassificationMetrics { - ClassificationMetrics { - true_positives: self.delta_chr_f_true_positives, - false_positives: self.delta_chr_f_false_positives, - false_negatives: self.delta_chr_f_false_negatives, - } - } - - pub fn exact_lines_counts(&self) -> ClassificationMetrics { - ClassificationMetrics { - true_positives: self.exact_lines_tp, - false_positives: self.exact_lines_fp, - false_negatives: self.exact_lines_fn, - } - } -} +pub type ExampleScore = edit_prediction_metrics::PredictionScore; impl Example { pub fn repo_name(&self) -> Result> { diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 0ab16690e6c..e15a65a5980 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -19,7 +19,6 @@ mod qa; mod 
reorder_patch; mod repair; mod retrieve_context; -mod reversal_tracking; mod score; mod split_commit; mod split_dataset; diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index 916d1498e6e..4bb8f22e2de 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,7 +1,5 @@ #![allow(unused_imports)] -use crate::example::ActualCursor; - pub use edit_prediction_metrics::ClassificationMetrics; pub use edit_prediction_metrics::Counts; pub use edit_prediction_metrics::DeltaChrFMetrics; @@ -14,11 +12,5 @@ pub use edit_prediction_metrics::delta_chr_f; pub use edit_prediction_metrics::delta_chr_f_beta; pub use edit_prediction_metrics::exact_lines_match; pub use edit_prediction_metrics::extract_changed_lines_from_diff; +pub use edit_prediction_metrics::has_isolated_whitespace_changes; pub use edit_prediction_metrics::is_editable_region_correct; - -pub fn has_isolated_whitespace_changes(patch_str: &str, cursor: Option<&ActualCursor>) -> bool { - edit_prediction_metrics::has_isolated_whitespace_changes( - patch_str, - cursor.map(|cursor| cursor.row), - ) -} diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs deleted file mode 100644 index 58d52ed84e6..00000000000 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ /dev/null @@ -1,17 +0,0 @@ -use std::path::Path; - -use zeta_prompt::ZetaPromptInput; - -pub fn compute_prediction_reversal_ratio( - prompt_inputs: &ZetaPromptInput, - predicted_content: &str, - cursor_path: &Path, -) -> f32 { - edit_prediction_metrics::compute_prediction_reversal_ratio_from_history( - prompt_inputs.cursor_excerpt.as_ref(), - &prompt_inputs.events, - prompt_inputs.excerpt_start_row, - predicted_content, - cursor_path, - ) -} diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 5e7721e84f7..48ce081f429 100644 --- 
a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -1,22 +1,21 @@ use crate::{ PredictArgs, PredictionProvider, - example::{ActualCursor, Example, ExampleScore}, + example::Example, format_prompt::TeacherPrompt, headless::EpAppState, - metrics, parse_output::parse_prediction_output, predict::run_prediction, progress::{ExampleProgress, Step}, - reversal_tracking, }; use anyhow::Context as _; +use edit_prediction_metrics::{ + ActualPredictionCursor, PredictionReversalContext, PredictionScoringInput, +}; use gpui::AsyncApp; -use serde::Serialize; use std::fs::File; use std::io::BufWriter; use std::path::Path; use std::sync::Arc; -use zeta_prompt::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset}; pub async fn run_scoring( example: &mut Example, @@ -37,18 +36,6 @@ pub async fn run_scoring( let original_text: &str = prompt_inputs.cursor_excerpt.as_ref(); let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions(); - let expected_texts: Vec = expected_patches_with_cursors - .iter() - .map(|(patch, _)| { - apply_diff_to_string(patch, original_text) - .with_context(|| format!("Expected patch did not apply for {}", example.spec.name)) - }) - .collect::, _>>()?; - - // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets. - // The actual_cursor_offset from Teacher is relative to the editable region, while the expected - // cursor from the patch is relative to the hunk. We need to apply the patch to the editable - // region to find where the hunk matched, then compute the expected cursor position. 
let old_editable_region = if let Some(p) = example.prompt.as_ref() { if matches!( p.provider, @@ -65,33 +52,12 @@ pub async fn run_scoring( None }; - let zero_scores = ExampleScore { - delta_chr_f: 0.0, - delta_chr_f_true_positives: 0, - delta_chr_f_false_positives: 0, - delta_chr_f_false_negatives: 0, - delta_chr_f_precision: 0.0, - delta_chr_f_recall: 0.0, - delta_chr_f_beta: metrics::delta_chr_f_beta(), - braces_disbalance: 0, - exact_lines_tp: 0, - exact_lines_fp: 0, - exact_lines_fn: 0, - reversal_ratio: 0.0, - cursor_distance: None, - cursor_exact_match: None, - wrong_editable_region: None, - has_isolated_whitespace_changes: false, - inserted_tokens: 0, - deleted_tokens: 0, - kept_rate: None, - recall_rate: None, - kept_chars: None, - correctly_deleted_chars: None, - discarded_chars: None, - cumulative_logprob: None, - avg_logprob: None, - }; + let prepared_expected_patches = edit_prediction_metrics::prepare_expected_patches( + &expected_patches_with_cursors, + original_text, + old_editable_region.as_deref(), + ) + .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))?; let cursor_path = example.spec.cursor_path.as_ref(); @@ -104,162 +70,36 @@ pub async fn run_scoring( .map(|(patch, _)| patch) }); - let Some(actual_patch) = actual_patch else { - scores.push(zero_scores.clone()); - continue; - }; + let actual_cursor = + prediction + .actual_cursor + .as_ref() + .map(|cursor| ActualPredictionCursor { + row: cursor.row, + editable_region_offset: cursor.editable_region_offset, + }); - let token_changes = metrics::count_patch_token_changes(&actual_patch); - - let actual_text = match apply_diff_to_string(&actual_patch, original_text) { - Ok(text) => text, - Err(_) => { - let mut s = zero_scores.clone(); - s.inserted_tokens = token_changes.inserted_tokens; - s.deleted_tokens = token_changes.deleted_tokens; - scores.push(s); - continue; - } - }; - - let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default(); - let mut 
best_expected_cursor: Option = None; - let mut best_patch_idx: Option = None; - let mut best_expected_text: Option<&str> = None; - - for (idx, expected) in expected_texts.iter().enumerate() { - let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text); - if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score { - best_delta_chr_f_metrics = delta_chr_f_metrics; - best_patch_idx = Some(idx); - best_expected_text = Some(expected); - } - } - - if let Some(idx) = best_patch_idx { - // Get the raw cursor offset from the expected patch (relative to hunk new text) - let expected_cursor_in_patch = expected_patches_with_cursors - .get(idx) - .and_then(|(_, cursor)| *cursor); - - // For Teacher prompts, we need to apply the patch to the editable region - // to find where the hunk matched, then compute the actual cursor position - if let (Some(editable_region), Some(cursor_in_patch)) = - (&old_editable_region, expected_cursor_in_patch) - { - let (patch, _) = &expected_patches_with_cursors[idx]; - if let Ok((_, hunk_offset)) = - apply_diff_to_string_with_hunk_offset(patch, editable_region) - { - let hunk_start = hunk_offset.unwrap_or(0); - best_expected_cursor = Some(hunk_start + cursor_in_patch); - } - } else { - // For non-Teacher prompts or if we can't compute, use raw offset - best_expected_cursor = expected_cursor_in_patch; - } - } - - let disbalance_before = metrics::braces_disbalance(&original_text); - let disbalance_after = metrics::braces_disbalance(&actual_text); - let braces_disbalance = disbalance_after.saturating_sub(disbalance_before); - - // Compute exact lines match against best matching expected patch - let best_exact_lines = expected_patches_with_cursors - .iter() - .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch)) - .max_by_key(|m| m.true_positives) - .unwrap_or_default(); - - // Compute reversal ratio - let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio( - 
prompt_inputs, - &actual_text, - cursor_path, - ); - - // Compute cursor position metrics - let (cursor_distance, cursor_exact_match) = - compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref()); - - // Compute approximation of editable region correctness - let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch)); - - // Check for isolated whitespace changes. - let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes( - &actual_patch, - prediction.actual_cursor.as_ref(), - ); - - let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) = - best_expected_text - .map(|reference_text| { - let result = - metrics::compute_kept_rate(original_text, &actual_text, reference_text); - ( - Some(result.kept_rate), - Some(result.recall_rate), - Some(result.kept_chars), - Some(result.correctly_deleted_chars), - Some(result.discarded_chars), - ) - }) - .unwrap_or((None, None, None, None, None)); - - scores.push(ExampleScore { - delta_chr_f: best_delta_chr_f_metrics.score as f32, - delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives, - delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives, - delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives, - delta_chr_f_precision: best_delta_chr_f_metrics.precision, - delta_chr_f_recall: best_delta_chr_f_metrics.recall, - delta_chr_f_beta: best_delta_chr_f_metrics.beta, - braces_disbalance, - exact_lines_tp: best_exact_lines.true_positives, - exact_lines_fp: best_exact_lines.false_positives, - exact_lines_fn: best_exact_lines.false_negatives, - reversal_ratio, - cursor_distance, - cursor_exact_match, - wrong_editable_region, - has_isolated_whitespace_changes, - inserted_tokens: token_changes.inserted_tokens, - deleted_tokens: token_changes.deleted_tokens, - kept_rate, - recall_rate, - kept_chars, - correctly_deleted_chars, - discarded_chars, - cumulative_logprob: 
prediction.cumulative_logprob, - avg_logprob: prediction.avg_logprob, - }); + scores.push(edit_prediction_metrics::score_prediction( + PredictionScoringInput { + original_text, + expected_patches: &prepared_expected_patches, + actual_patch: actual_patch.as_deref(), + actual_cursor, + reversal_context: Some(PredictionReversalContext { + edit_history: &prompt_inputs.events, + excerpt_start_row: prompt_inputs.excerpt_start_row, + cursor_path, + }), + cumulative_logprob: prediction.cumulative_logprob, + avg_logprob: prediction.avg_logprob, + }, + )); } example.score = scores; Ok(()) } -fn compute_cursor_metrics( - expected_cursor_editable_region_offset: Option, - actual_cursor: Option<&ActualCursor>, -) -> (Option, Option) { - match (expected_cursor_editable_region_offset, actual_cursor) { - (Some(expected), Some(actual)) => { - let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default()); - let exact_match = distance == 0; - (Some(distance), Some(exact_match)) - } - (None, None) => { - // Neither has cursor position - skip cursor scoring - (None, None) - } - (Some(_), None) | (None, Some(_)) => { - // Only one has cursor position - count as miss - (None, Some(false)) - } - } -} - pub fn print_report(examples: &[Example], verbose: bool) { const MAX_EXAMPLES_DEFAULT: usize = 20; use crate::metrics::ClassificationMetrics; @@ -633,286 +473,27 @@ fn truncate_name(name: &str, max_len: usize) -> String { } } -#[derive(Serialize)] -pub struct SummaryJson { - pub total_examples: usize, - pub avg_delta_chr_f: f32, - pub delta_chr_f_beta: f64, - pub delta_chr_f_true_positives: usize, - pub delta_chr_f_false_positives: usize, - pub delta_chr_f_false_negatives: usize, - pub delta_chr_f_precision: f64, - pub delta_chr_f_recall: f64, - pub avg_braces_disbalance: f32, - pub exact_lines_true_positives: usize, - pub exact_lines_false_positives: usize, - pub exact_lines_false_negatives: usize, - pub exact_lines_precision: f64, - pub exact_lines_recall: f64, - pub 
exact_lines_f1: f64, - pub avg_reversal_ratio: f32, - #[serde(skip_serializing_if = "Option::is_none")] - pub qa_avg_reverts_edits: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub qa_avg_confidence: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub cursor_exact_match_rate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub cursor_avg_distance: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub cursor_total_evaluated: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub wrong_editable_region_rate: Option, - pub isolated_whitespace_rate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub avg_kept_rate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub avg_recall_rate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub total_kept_chars: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub total_correctly_deleted_chars: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub total_discarded_chars: Option, -} +pub type SummaryJson = edit_prediction_metrics::SummaryJson; pub fn compute_summary(examples: &[Example]) -> SummaryJson { - use crate::metrics::ClassificationMetrics; + edit_prediction_metrics::compute_summary(examples.iter().flat_map(|example| { + example + .score + .iter() + .enumerate() + .map(move |(score_idx, score)| { + let qa = example + .qa + .get(score_idx) + .and_then(|qa| qa.as_ref()) + .map(|qa| edit_prediction_metrics::QaSummaryData { + reverts_edits: qa.reverts_edits, + confidence: qa.confidence, + }); - let mut all_delta_chr_f_scores = Vec::new(); - let mut all_reversal_ratios = Vec::new(); - let mut braces_disbalance_sum: usize = 0; - let mut total_delta_chr_f = ClassificationMetrics::default(); - let mut total_delta_chr_f_precision = 0.0; - let mut total_delta_chr_f_recall = 0.0; - let mut delta_chr_f_beta = 0.0; - let mut total_exact_lines = ClassificationMetrics::default(); - let mut total_scores: 
usize = 0; - let mut qa_reverts_count: usize = 0; - let mut qa_reverts_total: usize = 0; - let mut qa_confidence_sum: u64 = 0; - let mut qa_confidence_count: usize = 0; - let mut cursor_exact_matches: usize = 0; - let mut cursor_total: usize = 0; - let mut cursor_distance_sum: usize = 0; - let mut cursor_distance_count: usize = 0; - let mut wrong_editable_region_count: usize = 0; - let mut wrong_editable_region_total: usize = 0; - let mut isolated_whitespace_count: usize = 0; - let mut kept_rate_sum: f64 = 0.0; - let mut kept_rate_count: usize = 0; - let mut kept_chars_total: usize = 0; - let mut kept_chars_count: usize = 0; - let mut correctly_deleted_chars_total: usize = 0; - let mut correctly_deleted_chars_count: usize = 0; - let mut discarded_chars_total: usize = 0; - let mut discarded_chars_count: usize = 0; - let mut recall_rate_sum: f64 = 0.0; - let mut recall_rate_count: usize = 0; - - for example in examples { - for (score_idx, score) in example.score.iter().enumerate() { - all_delta_chr_f_scores.push(score.delta_chr_f); - all_reversal_ratios.push(score.reversal_ratio); - total_scores += 1; - braces_disbalance_sum += score.braces_disbalance; - total_delta_chr_f.accumulate(&score.delta_chr_f_counts()); - total_delta_chr_f_precision += score.delta_chr_f_precision; - total_delta_chr_f_recall += score.delta_chr_f_recall; - delta_chr_f_beta = score.delta_chr_f_beta; - total_exact_lines.accumulate(&score.exact_lines_counts()); - - // Accumulate QA metrics - if let Some(Some(qa)) = example.qa.get(score_idx) { - if let Some(reverts) = qa.reverts_edits { - qa_reverts_total += 1; - if reverts { - qa_reverts_count += 1; - } - } - if let Some(conf) = qa.confidence { - qa_confidence_sum += conf as u64; - qa_confidence_count += 1; - } - } - - // Accumulate wrong editable region metrics - if let Some(wrong) = score.wrong_editable_region { - wrong_editable_region_total += 1; - if wrong { - wrong_editable_region_count += 1; - } - } - - // Accumulate isolated whitespace 
metrics - if score.has_isolated_whitespace_changes { - isolated_whitespace_count += 1; - } - - // Accumulate kept and recall rate metrics - if let Some(kr) = score.kept_rate { - kept_rate_sum += kr; - kept_rate_count += 1; - } - if let Some(kept_chars) = score.kept_chars { - kept_chars_total += kept_chars; - kept_chars_count += 1; - } - if let Some(correctly_deleted_chars) = score.correctly_deleted_chars { - correctly_deleted_chars_total += correctly_deleted_chars; - correctly_deleted_chars_count += 1; - } - if let Some(discarded_chars) = score.discarded_chars { - discarded_chars_total += discarded_chars; - discarded_chars_count += 1; - } - if let Some(rr) = score.recall_rate { - recall_rate_sum += rr; - recall_rate_count += 1; - } - - // Accumulate cursor metrics - if let Some(exact_match) = score.cursor_exact_match { - cursor_total += 1; - if exact_match { - cursor_exact_matches += 1; - } - } - if let Some(dist) = score.cursor_distance { - cursor_distance_sum += dist; - cursor_distance_count += 1; - } - } - } - - let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() { - 0.0 - } else { - all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32 - }; - - let avg_reversal_ratio = if all_reversal_ratios.is_empty() { - 0.0 - } else { - all_reversal_ratios.iter().sum::() / all_reversal_ratios.len() as f32 - }; - - let avg_braces_disbalance = if total_scores == 0 { - 0.0 - } else { - braces_disbalance_sum as f32 / total_scores as f32 - }; - - let qa_avg_reverts_edits = if qa_reverts_total > 0 { - Some(qa_reverts_count as f32 / qa_reverts_total as f32) - } else { - None - }; - - let qa_avg_confidence = if qa_confidence_count > 0 { - Some(qa_confidence_sum as f32 / qa_confidence_count as f32) - } else { - None - }; - - let cursor_exact_match_rate = if cursor_total > 0 { - Some(cursor_exact_matches as f32 / cursor_total as f32) - } else { - None - }; - - let cursor_avg_distance = if cursor_distance_count > 0 { - Some(cursor_distance_sum as f32 / 
cursor_distance_count as f32) - } else { - None - }; - - let cursor_total_evaluated = if cursor_total > 0 { - Some(cursor_total) - } else { - None - }; - - let wrong_editable_region_rate = if wrong_editable_region_total > 0 { - Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32) - } else { - None - }; - - let isolated_whitespace_rate = if total_scores > 0 { - Some(isolated_whitespace_count as f32 / total_scores as f32) - } else { - None - }; - - let avg_kept_rate = if kept_rate_count > 0 { - Some(kept_rate_sum / kept_rate_count as f64) - } else { - None - }; - - let avg_recall_rate = if recall_rate_count > 0 { - Some(recall_rate_sum / recall_rate_count as f64) - } else { - None - }; - - let total_kept_chars = if kept_chars_count > 0 { - Some(kept_chars_total) - } else { - None - }; - - let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 { - Some(correctly_deleted_chars_total) - } else { - None - }; - - let total_discarded_chars = if discarded_chars_count > 0 { - Some(discarded_chars_total) - } else { - None - }; - - SummaryJson { - total_examples: total_scores, - avg_delta_chr_f, - delta_chr_f_beta, - delta_chr_f_true_positives: total_delta_chr_f.true_positives, - delta_chr_f_false_positives: total_delta_chr_f.false_positives, - delta_chr_f_false_negatives: total_delta_chr_f.false_negatives, - delta_chr_f_precision: if total_scores == 0 { - 0.0 - } else { - total_delta_chr_f_precision / total_scores as f64 - }, - delta_chr_f_recall: if total_scores == 0 { - 0.0 - } else { - total_delta_chr_f_recall / total_scores as f64 - }, - avg_braces_disbalance, - exact_lines_true_positives: total_exact_lines.true_positives, - exact_lines_false_positives: total_exact_lines.false_positives, - exact_lines_false_negatives: total_exact_lines.false_negatives, - exact_lines_precision: total_exact_lines.precision(), - exact_lines_recall: total_exact_lines.recall(), - exact_lines_f1: total_exact_lines.f1(), - avg_reversal_ratio, - 
qa_avg_reverts_edits, - qa_avg_confidence, - cursor_exact_match_rate, - cursor_avg_distance, - cursor_total_evaluated, - wrong_editable_region_rate, - isolated_whitespace_rate, - avg_kept_rate, - avg_recall_rate, - total_kept_chars, - total_correctly_deleted_chars, - total_discarded_chars, - } + edit_prediction_metrics::PredictionSummaryInput { score, qa } + }) + })) } pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> { diff --git a/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs b/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs index 3afe02fd083..74ad639b7e9 100644 --- a/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs +++ b/crates/edit_prediction_metrics/src/edit_prediction_metrics.rs @@ -1,6 +1,8 @@ mod kept_rate; mod patch_metrics; +mod prediction_score; mod reversal; +mod summary; mod tokenize; mod tree_sitter; @@ -22,5 +24,10 @@ pub use patch_metrics::extract_changed_lines_from_diff; pub use patch_metrics::has_isolated_whitespace_changes; pub use patch_metrics::is_editable_region_correct; pub use patch_metrics::reconstruct_texts_from_diff; +pub use prediction_score::{ + ActualPredictionCursor, PredictionReversalContext, PredictionScore, PredictionScoringInput, + PrepareExpectedPatchError, PreparedExpectedPatch, prepare_expected_patches, score_prediction, +}; pub use reversal::compute_prediction_reversal_ratio_from_history; +pub use summary::{PredictionSummaryInput, QaSummaryData, SummaryJson, compute_summary}; pub use tree_sitter::count_tree_sitter_errors; diff --git a/crates/edit_prediction_metrics/src/prediction_score.rs b/crates/edit_prediction_metrics/src/prediction_score.rs new file mode 100644 index 00000000000..55c1d828762 --- /dev/null +++ b/crates/edit_prediction_metrics/src/prediction_score.rs @@ -0,0 +1,319 @@ +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::fmt; +use std::path::Path; +use std::sync::Arc; +use zeta_prompt::udiff::{apply_diff_to_string, 
apply_diff_to_string_with_hunk_offset}; + +use crate::patch_metrics::{ + ClassificationMetrics, DeltaChrFMetrics, braces_disbalance, count_patch_token_changes, + delta_chr_f, delta_chr_f_beta, exact_lines_match, has_isolated_whitespace_changes, + is_editable_region_correct, +}; +use crate::reversal::compute_prediction_reversal_ratio_from_history; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PredictionScore { + pub delta_chr_f: f32, + #[serde(default)] + pub delta_chr_f_true_positives: usize, + #[serde(default)] + pub delta_chr_f_false_positives: usize, + #[serde(default)] + pub delta_chr_f_false_negatives: usize, + #[serde(default)] + pub delta_chr_f_precision: f64, + #[serde(default)] + pub delta_chr_f_recall: f64, + #[serde(default)] + pub delta_chr_f_beta: f64, + pub braces_disbalance: usize, + #[serde(default)] + pub exact_lines_tp: usize, + #[serde(default)] + pub exact_lines_fp: usize, + #[serde(default)] + pub exact_lines_fn: usize, + #[serde(default)] + pub reversal_ratio: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cursor_distance: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cursor_exact_match: Option, + pub wrong_editable_region: Option, + #[serde(default)] + pub has_isolated_whitespace_changes: bool, + #[serde(default)] + pub inserted_tokens: usize, + #[serde(default)] + pub deleted_tokens: usize, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub kept_rate: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub recall_rate: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub kept_chars: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub correctly_deleted_chars: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub discarded_chars: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cumulative_logprob: Option, + #[serde(default, skip_serializing_if = 
"Option::is_none")] + pub avg_logprob: Option, +} + +impl PredictionScore { + pub fn zero() -> Self { + Self { + delta_chr_f: 0.0, + delta_chr_f_true_positives: 0, + delta_chr_f_false_positives: 0, + delta_chr_f_false_negatives: 0, + delta_chr_f_precision: 0.0, + delta_chr_f_recall: 0.0, + delta_chr_f_beta: delta_chr_f_beta(), + braces_disbalance: 0, + exact_lines_tp: 0, + exact_lines_fp: 0, + exact_lines_fn: 0, + reversal_ratio: 0.0, + cursor_distance: None, + cursor_exact_match: None, + wrong_editable_region: None, + has_isolated_whitespace_changes: false, + inserted_tokens: 0, + deleted_tokens: 0, + kept_rate: None, + recall_rate: None, + kept_chars: None, + correctly_deleted_chars: None, + discarded_chars: None, + cumulative_logprob: None, + avg_logprob: None, + } + } + + pub fn delta_chr_f_counts(&self) -> ClassificationMetrics { + ClassificationMetrics { + true_positives: self.delta_chr_f_true_positives, + false_positives: self.delta_chr_f_false_positives, + false_negatives: self.delta_chr_f_false_negatives, + } + } + + pub fn exact_lines_counts(&self) -> ClassificationMetrics { + ClassificationMetrics { + true_positives: self.exact_lines_tp, + false_positives: self.exact_lines_fp, + false_negatives: self.exact_lines_fn, + } + } +} + +impl Default for PredictionScore { + fn default() -> Self { + Self::zero() + } +} + +#[derive(Clone, Debug)] +pub struct PreparedExpectedPatch { + pub patch: String, + pub text: String, + pub cursor_editable_region_offset: Option, +} + +#[derive(Clone, Debug)] +pub struct PrepareExpectedPatchError { + message: String, +} + +impl fmt::Display for PrepareExpectedPatchError { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + self.message.fmt(formatter) + } +} + +impl Error for PrepareExpectedPatchError {} + +pub fn prepare_expected_patches( + expected_patches_with_cursors: &[(String, Option)], + original_text: &str, + old_editable_region: Option<&str>, +) -> Result, PrepareExpectedPatchError> { + 
expected_patches_with_cursors + .iter() + .map(|(patch, cursor_in_patch)| { + let text = apply_diff_to_string(patch, original_text).map_err(|error| { + PrepareExpectedPatchError { + message: error.to_string(), + } + })?; + let cursor_editable_region_offset = + if let (Some(editable_region), Some(cursor_in_patch)) = + (old_editable_region, *cursor_in_patch) + { + match apply_diff_to_string_with_hunk_offset(patch, editable_region) { + Ok((_, hunk_offset)) => Some(hunk_offset.unwrap_or(0) + cursor_in_patch), + Err(_) => None, + } + } else { + *cursor_in_patch + }; + + Ok(PreparedExpectedPatch { + patch: patch.clone(), + text, + cursor_editable_region_offset, + }) + }) + .collect() +} + +/** Cursor data for the actual prediction: `row` is passed to has_isolated_whitespace_changes and `editable_region_offset` is compared against the expected cursor offset in compute_cursor_metrics. */#[derive(Clone, Copy, Debug)] +pub struct ActualPredictionCursor { + pub row: u32, + pub editable_region_offset: Option, +} + +/** Borrowed inputs that score_prediction forwards to compute_prediction_reversal_ratio_from_history; when no context is supplied the reversal ratio is reported as 0.0. */#[derive(Clone, Copy, Debug)] +pub struct PredictionReversalContext<'a> { + pub edit_history: &'a [Arc], + pub excerpt_start_row: Option, + pub cursor_path: &'a Path, +} + +/** Inputs for scoring one prediction with score_prediction. The two logprob fields are copied into the resulting score unchanged. NOTE(review): several generic type arguments (bare `Option`, `Arc`) appear to have been lost in this copy of the patch; confirm the field types against the real source. */#[derive(Clone, Copy, Debug)] +pub struct PredictionScoringInput<'a> { + pub original_text: &'a str, + pub expected_patches: &'a [PreparedExpectedPatch], + pub actual_patch: Option<&'a str>, + pub actual_cursor: Option, + pub reversal_context: Option>, + pub cumulative_logprob: Option, + pub avg_logprob: Option, +} + +/** Scores one prediction against the prepared expected patches. Returns PredictionScore::zero() when input.actual_patch is None. If the actual patch cannot be applied to the original text, returns a zero score with only inserted_tokens and deleted_tokens filled in. Otherwise the delta-chrF metrics, the expected cursor offset and the kept-rate reference text all come from the expected patch with the strictly highest delta_chr_f score. */pub fn score_prediction(input: PredictionScoringInput<'_>) -> PredictionScore { + let Some(actual_patch) = input.actual_patch else { + return PredictionScore::zero(); + }; + + let token_changes = count_patch_token_changes(actual_patch); + + let actual_text = match apply_diff_to_string(actual_patch, input.original_text) { + Ok(text) => text, + Err(_) => { + let mut score = PredictionScore::zero(); + score.inserted_tokens = token_changes.inserted_tokens; + score.deleted_tokens = token_changes.deleted_tokens; + return score; + } + }; + +/* Best-so-far starts from the default metrics, and the loop below only replaces it on a strictly greater score, so the cursor and text stay None unless some expected patch beats the default score. */ let mut best_delta_chr_f_metrics = DeltaChrFMetrics::default(); + let mut best_expected_cursor = None; + let mut best_expected_text = None; + + 
for expected in input.expected_patches { + let delta_chr_f_metrics = delta_chr_f(input.original_text, &expected.text, &actual_text); + if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score { + best_delta_chr_f_metrics = delta_chr_f_metrics; + best_expected_cursor = expected.cursor_editable_region_offset; + best_expected_text = Some(expected.text.as_str()); + } + } + +/* Report only the brace imbalance added by the prediction: the value for the original text is subtracted, saturating at zero. */ let disbalance_before = braces_disbalance(input.original_text); + let disbalance_after = braces_disbalance(&actual_text); + let braces_disbalance = disbalance_after.saturating_sub(disbalance_before); + +/* Chosen independently of the delta-chrF winner: the expected patch with the most exact-line true positives, or the default metrics when there are no expected patches. */ let best_exact_lines = input + .expected_patches + .iter() + .map(|expected| exact_lines_match(&expected.patch, actual_patch)) + .max_by_key(|metrics| metrics.true_positives) + .unwrap_or_default(); + +/* Without a reversal context the ratio falls back to 0.0. */ let reversal_ratio = input + .reversal_context + .map(|context| { + compute_prediction_reversal_ratio_from_history( + input.original_text, + context.edit_history, + context.excerpt_start_row, + &actual_text, + context.cursor_path, + ) + }) + .unwrap_or(0.0); + + let (cursor_distance, cursor_exact_match) = + compute_cursor_metrics(best_expected_cursor, input.actual_cursor); + + let wrong_editable_region = Some(!is_editable_region_correct(actual_patch)); + let has_isolated_whitespace_changes = + has_isolated_whitespace_changes(actual_patch, input.actual_cursor.map(|cursor| cursor.row)); + +/* Kept-rate metrics use the text of the delta-chrF winner as reference; all five stay None when no expected patch was selected. */ let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) = + best_expected_text + .map(|reference_text| { + let result = crate::kept_rate::compute_kept_rate( + input.original_text, + &actual_text, + reference_text, + ); + ( + Some(result.kept_rate), + Some(result.recall_rate), + Some(result.kept_chars), + Some(result.correctly_deleted_chars), + Some(result.discarded_chars), + ) + }) + .unwrap_or((None, None, None, None, None)); + + PredictionScore { + delta_chr_f: best_delta_chr_f_metrics.score as f32, + delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives, + 
delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives, + delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives, + delta_chr_f_precision: best_delta_chr_f_metrics.precision, + delta_chr_f_recall: best_delta_chr_f_metrics.recall, + delta_chr_f_beta: best_delta_chr_f_metrics.beta, + braces_disbalance, + exact_lines_tp: best_exact_lines.true_positives, + exact_lines_fp: best_exact_lines.false_positives, + exact_lines_fn: best_exact_lines.false_negatives, + reversal_ratio, + cursor_distance, + cursor_exact_match, + wrong_editable_region, + has_isolated_whitespace_changes, + inserted_tokens: token_changes.inserted_tokens, + deleted_tokens: token_changes.deleted_tokens, + kept_rate, + recall_rate, + kept_chars, + correctly_deleted_chars, + discarded_chars, + cumulative_logprob: input.cumulative_logprob, + avg_logprob: input.avg_logprob, + } +} + +/** Compares the expected cursor offset with the actual cursor and returns (distance, exact_match). Both present: the absolute difference of the offsets and whether it is zero; an actual cursor without an editable_region_offset falls back to the default value via unwrap_or_default. Both absent: (None, None). Exactly one present: no distance and an exact match of false. */fn compute_cursor_metrics( + expected_cursor_editable_region_offset: Option, + actual_cursor: Option, +) -> (Option, Option) { + match (expected_cursor_editable_region_offset, actual_cursor) { + (Some(expected), Some(actual)) => { + let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default()); + let exact_match = distance == 0; + (Some(distance), Some(exact_match)) + } + (None, None) => (None, None), + (Some(_), None) | (None, Some(_)) => (None, Some(false)), + } +} diff --git a/crates/edit_prediction_metrics/src/summary.rs b/crates/edit_prediction_metrics/src/summary.rs new file mode 100644 index 00000000000..249ae185755 --- /dev/null +++ b/crates/edit_prediction_metrics/src/summary.rs @@ -0,0 +1,293 @@ +use serde::Serialize; + +use crate::patch_metrics::ClassificationMetrics; +use crate::prediction_score::PredictionScore; + +/** Optional QA results for one prediction; compute_summary counts each field only when it is present. */#[derive(Clone, Copy, Debug, Default)] +pub struct QaSummaryData { + pub reverts_edits: Option, + pub confidence: Option, +} + +/** One item consumed by compute_summary: a borrowed prediction score plus optional QA data. */#[derive(Clone, Copy, Debug)] +pub struct PredictionSummaryInput<'a> { + pub score: &'a PredictionScore, + pub 
qa: Option, +} + +/** Serializable aggregate returned by compute_summary. Fields tagged with skip_serializing_if are left out of the serialized output when they are None; isolated_whitespace_rate is the one optional field without that attribute. NOTE(review): the generic arguments of the Option fields appear to have been lost in this copy of the patch; confirm the types against the real source. */#[derive(Clone, Debug, Serialize)] +pub struct SummaryJson { + pub total_examples: usize, + pub avg_delta_chr_f: f32, + pub delta_chr_f_beta: f64, + pub delta_chr_f_true_positives: usize, + pub delta_chr_f_false_positives: usize, + pub delta_chr_f_false_negatives: usize, + pub delta_chr_f_precision: f64, + pub delta_chr_f_recall: f64, + pub avg_braces_disbalance: f32, + pub exact_lines_true_positives: usize, + pub exact_lines_false_positives: usize, + pub exact_lines_false_negatives: usize, + pub exact_lines_precision: f64, + pub exact_lines_recall: f64, + pub exact_lines_f1: f64, + pub avg_reversal_ratio: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub qa_avg_reverts_edits: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub qa_avg_confidence: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor_exact_match_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor_avg_distance: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor_total_evaluated: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub wrong_editable_region_rate: Option, + pub isolated_whitespace_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub avg_kept_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub avg_recall_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_kept_chars: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_correctly_deleted_chars: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_discarded_chars: Option, +} + +/** Aggregates per-prediction scores and optional QA data into a SummaryJson. With no predictions the plain averages computed in this function are 0.0 and every Option-typed rate or total is None. Each optional metric keeps its own counter so that it is averaged only over the predictions that actually carry it. */pub fn compute_summary<'a>( + predictions: impl IntoIterator>, +) -> SummaryJson { + let mut all_delta_chr_f_scores = Vec::new(); + let mut all_reversal_ratios = Vec::new(); + let mut braces_disbalance_sum: usize = 0; + let mut total_delta_chr_f = ClassificationMetrics::default(); + let mut total_delta_chr_f_precision = 0.0; + let mut 
total_delta_chr_f_recall = 0.0; + let mut delta_chr_f_beta = 0.0; + let mut total_exact_lines = ClassificationMetrics::default(); + let mut total_scores: usize = 0; + let mut qa_reverts_count: usize = 0; + let mut qa_reverts_total: usize = 0; + let mut qa_confidence_sum: u64 = 0; + let mut qa_confidence_count: usize = 0; + let mut cursor_exact_matches: usize = 0; + let mut cursor_total: usize = 0; + let mut cursor_distance_sum: usize = 0; + let mut cursor_distance_count: usize = 0; + let mut wrong_editable_region_count: usize = 0; + let mut wrong_editable_region_total: usize = 0; + let mut isolated_whitespace_count: usize = 0; + let mut kept_rate_sum: f64 = 0.0; + let mut kept_rate_count: usize = 0; + let mut kept_chars_total: usize = 0; + let mut kept_chars_count: usize = 0; + let mut correctly_deleted_chars_total: usize = 0; + let mut correctly_deleted_chars_count: usize = 0; + let mut discarded_chars_total: usize = 0; + let mut discarded_chars_count: usize = 0; + let mut recall_rate_sum: f64 = 0.0; + let mut recall_rate_count: usize = 0; + + for prediction in predictions { + let score = prediction.score; + + all_delta_chr_f_scores.push(score.delta_chr_f); + all_reversal_ratios.push(score.reversal_ratio); + total_scores += 1; + braces_disbalance_sum += score.braces_disbalance; + total_delta_chr_f.accumulate(&score.delta_chr_f_counts()); + total_delta_chr_f_precision += score.delta_chr_f_precision; + total_delta_chr_f_recall += score.delta_chr_f_recall; +/* Not averaged: the value from the last prediction seen is the one reported. */ delta_chr_f_beta = score.delta_chr_f_beta; + total_exact_lines.accumulate(&score.exact_lines_counts()); + + if let Some(qa) = prediction.qa { + if let Some(reverts) = qa.reverts_edits { + qa_reverts_total += 1; + if reverts { + qa_reverts_count += 1; + } + } + if let Some(confidence) = qa.confidence { + qa_confidence_sum += confidence as u64; + qa_confidence_count += 1; + } + } + + if let Some(wrong) = score.wrong_editable_region { + wrong_editable_region_total += 1; + if wrong { + wrong_editable_region_count += 
1; + } + } + + if score.has_isolated_whitespace_changes { + isolated_whitespace_count += 1; + } + + if let Some(kept_rate) = score.kept_rate { + kept_rate_sum += kept_rate; + kept_rate_count += 1; + } + if let Some(kept_chars) = score.kept_chars { + kept_chars_total += kept_chars; + kept_chars_count += 1; + } + if let Some(correctly_deleted_chars) = score.correctly_deleted_chars { + correctly_deleted_chars_total += correctly_deleted_chars; + correctly_deleted_chars_count += 1; + } + if let Some(discarded_chars) = score.discarded_chars { + discarded_chars_total += discarded_chars; + discarded_chars_count += 1; + } + if let Some(recall_rate) = score.recall_rate { + recall_rate_sum += recall_rate; + recall_rate_count += 1; + } + +/* cursor_total counts predictions that have a cursor_exact_match value; distances are tracked separately because a score can carry an exact-match flag without a distance. */ if let Some(exact_match) = score.cursor_exact_match { + cursor_total += 1; + if exact_match { + cursor_exact_matches += 1; + } + } + if let Some(distance) = score.cursor_distance { + cursor_distance_sum += distance; + cursor_distance_count += 1; + } + } + + let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() { + 0.0 + } else { + all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32 + }; + + let avg_reversal_ratio = if all_reversal_ratios.is_empty() { + 0.0 + } else { + all_reversal_ratios.iter().sum::() / all_reversal_ratios.len() as f32 + }; + + let avg_braces_disbalance = if total_scores == 0 { + 0.0 + } else { + braces_disbalance_sum as f32 / total_scores as f32 + }; + + let qa_avg_reverts_edits = if qa_reverts_total > 0 { + Some(qa_reverts_count as f32 / qa_reverts_total as f32) + } else { + None + }; + + let qa_avg_confidence = if qa_confidence_count > 0 { + Some(qa_confidence_sum as f32 / qa_confidence_count as f32) + } else { + None + }; + + let cursor_exact_match_rate = if cursor_total > 0 { + Some(cursor_exact_matches as f32 / cursor_total as f32) + } else { + None + }; + + let cursor_avg_distance = if cursor_distance_count > 0 { + Some(cursor_distance_sum as f32 / cursor_distance_count as f32) + } else 
{ + None + }; + + let cursor_total_evaluated = if cursor_total > 0 { + Some(cursor_total) + } else { + None + }; + + let wrong_editable_region_rate = if wrong_editable_region_total > 0 { + Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32) + } else { + None + }; + +/* Divided by total_scores rather than a dedicated counter: has_isolated_whitespace_changes is read as a plain bool on every score. */ let isolated_whitespace_rate = if total_scores > 0 { + Some(isolated_whitespace_count as f32 / total_scores as f32) + } else { + None + }; + + let avg_kept_rate = if kept_rate_count > 0 { + Some(kept_rate_sum / kept_rate_count as f64) + } else { + None + }; + + let avg_recall_rate = if recall_rate_count > 0 { + Some(recall_rate_sum / recall_rate_count as f64) + } else { + None + }; + + let total_kept_chars = if kept_chars_count > 0 { + Some(kept_chars_total) + } else { + None + }; + + let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 { + Some(correctly_deleted_chars_total) + } else { + None + }; + + let total_discarded_chars = if discarded_chars_count > 0 { + Some(discarded_chars_total) + } else { + None + }; + + SummaryJson { + total_examples: total_scores, + avg_delta_chr_f, + delta_chr_f_beta, + delta_chr_f_true_positives: total_delta_chr_f.true_positives, + delta_chr_f_false_positives: total_delta_chr_f.false_positives, + delta_chr_f_false_negatives: total_delta_chr_f.false_negatives, +/* delta-chrF precision and recall are means of the per-prediction values, whereas the exact-lines precision, recall and f1 below are derived from the accumulated counts. */ delta_chr_f_precision: if total_scores == 0 { + 0.0 + } else { + total_delta_chr_f_precision / total_scores as f64 + }, + delta_chr_f_recall: if total_scores == 0 { + 0.0 + } else { + total_delta_chr_f_recall / total_scores as f64 + }, + avg_braces_disbalance, + exact_lines_true_positives: total_exact_lines.true_positives, + exact_lines_false_positives: total_exact_lines.false_positives, + exact_lines_false_negatives: total_exact_lines.false_negatives, + exact_lines_precision: total_exact_lines.precision(), + exact_lines_recall: total_exact_lines.recall(), + exact_lines_f1: total_exact_lines.f1(), + avg_reversal_ratio, + qa_avg_reverts_edits, + qa_avg_confidence, + 
cursor_exact_match_rate, + cursor_avg_distance, + cursor_total_evaluated, + wrong_editable_region_rate, + isolated_whitespace_rate, + avg_kept_rate, + avg_recall_rate, + total_kept_chars, + total_correctly_deleted_chars, + total_discarded_chars, + } +}