mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
ep: Move scores aggegation to edit_prediction_metrics (#55609)
This way, it can be shared with Python bindings. Release Notes: - N/A
This commit is contained in:
parent
7de96710e2
commit
3730621906
8 changed files with 672 additions and 566 deletions
|
|
@ -1,5 +1,4 @@
|
|||
use crate::PredictionProvider;
|
||||
use crate::metrics::ClassificationMetrics;
|
||||
use crate::paths::WORKTREES_DIR;
|
||||
use crate::qa::QaResult;
|
||||
use anyhow::{Context as _, Result};
|
||||
|
|
@ -149,74 +148,7 @@ where
|
|||
Ok(opt.unwrap_or_default())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleScore {
|
||||
pub delta_chr_f: f32,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_true_positives: usize,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_false_positives: usize,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_false_negatives: usize,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_precision: f64,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_recall: f64,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_beta: f64,
|
||||
pub braces_disbalance: usize,
|
||||
#[serde(default)]
|
||||
pub exact_lines_tp: usize,
|
||||
#[serde(default)]
|
||||
pub exact_lines_fp: usize,
|
||||
#[serde(default)]
|
||||
pub exact_lines_fn: usize,
|
||||
#[serde(default)]
|
||||
pub reversal_ratio: f32,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_distance: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_exact_match: Option<bool>,
|
||||
pub wrong_editable_region: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub has_isolated_whitespace_changes: bool,
|
||||
#[serde(default)]
|
||||
pub inserted_tokens: usize,
|
||||
#[serde(default)]
|
||||
pub deleted_tokens: usize,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub kept_rate: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub recall_rate: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub kept_chars: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub correctly_deleted_chars: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub discarded_chars: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cumulative_logprob: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub avg_logprob: Option<f64>,
|
||||
}
|
||||
|
||||
impl ExampleScore {
|
||||
pub fn delta_chr_f_counts(&self) -> ClassificationMetrics {
|
||||
ClassificationMetrics {
|
||||
true_positives: self.delta_chr_f_true_positives,
|
||||
false_positives: self.delta_chr_f_false_positives,
|
||||
false_negatives: self.delta_chr_f_false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exact_lines_counts(&self) -> ClassificationMetrics {
|
||||
ClassificationMetrics {
|
||||
true_positives: self.exact_lines_tp,
|
||||
false_positives: self.exact_lines_fp,
|
||||
false_negatives: self.exact_lines_fn,
|
||||
}
|
||||
}
|
||||
}
|
||||
pub type ExampleScore = edit_prediction_metrics::PredictionScore;
|
||||
|
||||
impl Example {
|
||||
pub fn repo_name(&self) -> Result<RepoName<'_>> {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ mod qa;
|
|||
mod reorder_patch;
|
||||
mod repair;
|
||||
mod retrieve_context;
|
||||
mod reversal_tracking;
|
||||
mod score;
|
||||
mod split_commit;
|
||||
mod split_dataset;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
#![allow(unused_imports)]
|
||||
|
||||
use crate::example::ActualCursor;
|
||||
|
||||
pub use edit_prediction_metrics::ClassificationMetrics;
|
||||
pub use edit_prediction_metrics::Counts;
|
||||
pub use edit_prediction_metrics::DeltaChrFMetrics;
|
||||
|
|
@ -14,11 +12,5 @@ pub use edit_prediction_metrics::delta_chr_f;
|
|||
pub use edit_prediction_metrics::delta_chr_f_beta;
|
||||
pub use edit_prediction_metrics::exact_lines_match;
|
||||
pub use edit_prediction_metrics::extract_changed_lines_from_diff;
|
||||
pub use edit_prediction_metrics::has_isolated_whitespace_changes;
|
||||
pub use edit_prediction_metrics::is_editable_region_correct;
|
||||
|
||||
pub fn has_isolated_whitespace_changes(patch_str: &str, cursor: Option<&ActualCursor>) -> bool {
|
||||
edit_prediction_metrics::has_isolated_whitespace_changes(
|
||||
patch_str,
|
||||
cursor.map(|cursor| cursor.row),
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
use std::path::Path;
|
||||
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
pub fn compute_prediction_reversal_ratio(
|
||||
prompt_inputs: &ZetaPromptInput,
|
||||
predicted_content: &str,
|
||||
cursor_path: &Path,
|
||||
) -> f32 {
|
||||
edit_prediction_metrics::compute_prediction_reversal_ratio_from_history(
|
||||
prompt_inputs.cursor_excerpt.as_ref(),
|
||||
&prompt_inputs.events,
|
||||
prompt_inputs.excerpt_start_row,
|
||||
predicted_content,
|
||||
cursor_path,
|
||||
)
|
||||
}
|
||||
|
|
@ -1,22 +1,21 @@
|
|||
use crate::{
|
||||
PredictArgs, PredictionProvider,
|
||||
example::{ActualCursor, Example, ExampleScore},
|
||||
example::Example,
|
||||
format_prompt::TeacherPrompt,
|
||||
headless::EpAppState,
|
||||
metrics,
|
||||
parse_output::parse_prediction_output,
|
||||
predict::run_prediction,
|
||||
progress::{ExampleProgress, Step},
|
||||
reversal_tracking,
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use edit_prediction_metrics::{
|
||||
ActualPredictionCursor, PredictionReversalContext, PredictionScoringInput,
|
||||
};
|
||||
use gpui::AsyncApp;
|
||||
use serde::Serialize;
|
||||
use std::fs::File;
|
||||
use std::io::BufWriter;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use zeta_prompt::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
|
||||
|
||||
pub async fn run_scoring(
|
||||
example: &mut Example,
|
||||
|
|
@ -37,18 +36,6 @@ pub async fn run_scoring(
|
|||
let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
|
||||
let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
|
||||
|
||||
let expected_texts: Vec<String> = expected_patches_with_cursors
|
||||
.iter()
|
||||
.map(|(patch, _)| {
|
||||
apply_diff_to_string(patch, original_text)
|
||||
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
|
||||
// The actual_cursor_offset from Teacher is relative to the editable region, while the expected
|
||||
// cursor from the patch is relative to the hunk. We need to apply the patch to the editable
|
||||
// region to find where the hunk matched, then compute the expected cursor position.
|
||||
let old_editable_region = if let Some(p) = example.prompt.as_ref() {
|
||||
if matches!(
|
||||
p.provider,
|
||||
|
|
@ -65,33 +52,12 @@ pub async fn run_scoring(
|
|||
None
|
||||
};
|
||||
|
||||
let zero_scores = ExampleScore {
|
||||
delta_chr_f: 0.0,
|
||||
delta_chr_f_true_positives: 0,
|
||||
delta_chr_f_false_positives: 0,
|
||||
delta_chr_f_false_negatives: 0,
|
||||
delta_chr_f_precision: 0.0,
|
||||
delta_chr_f_recall: 0.0,
|
||||
delta_chr_f_beta: metrics::delta_chr_f_beta(),
|
||||
braces_disbalance: 0,
|
||||
exact_lines_tp: 0,
|
||||
exact_lines_fp: 0,
|
||||
exact_lines_fn: 0,
|
||||
reversal_ratio: 0.0,
|
||||
cursor_distance: None,
|
||||
cursor_exact_match: None,
|
||||
wrong_editable_region: None,
|
||||
has_isolated_whitespace_changes: false,
|
||||
inserted_tokens: 0,
|
||||
deleted_tokens: 0,
|
||||
kept_rate: None,
|
||||
recall_rate: None,
|
||||
kept_chars: None,
|
||||
correctly_deleted_chars: None,
|
||||
discarded_chars: None,
|
||||
cumulative_logprob: None,
|
||||
avg_logprob: None,
|
||||
};
|
||||
let prepared_expected_patches = edit_prediction_metrics::prepare_expected_patches(
|
||||
&expected_patches_with_cursors,
|
||||
original_text,
|
||||
old_editable_region.as_deref(),
|
||||
)
|
||||
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))?;
|
||||
|
||||
let cursor_path = example.spec.cursor_path.as_ref();
|
||||
|
||||
|
|
@ -104,162 +70,36 @@ pub async fn run_scoring(
|
|||
.map(|(patch, _)| patch)
|
||||
});
|
||||
|
||||
let Some(actual_patch) = actual_patch else {
|
||||
scores.push(zero_scores.clone());
|
||||
continue;
|
||||
};
|
||||
let actual_cursor =
|
||||
prediction
|
||||
.actual_cursor
|
||||
.as_ref()
|
||||
.map(|cursor| ActualPredictionCursor {
|
||||
row: cursor.row,
|
||||
editable_region_offset: cursor.editable_region_offset,
|
||||
});
|
||||
|
||||
let token_changes = metrics::count_patch_token_changes(&actual_patch);
|
||||
|
||||
let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
|
||||
Ok(text) => text,
|
||||
Err(_) => {
|
||||
let mut s = zero_scores.clone();
|
||||
s.inserted_tokens = token_changes.inserted_tokens;
|
||||
s.deleted_tokens = token_changes.deleted_tokens;
|
||||
scores.push(s);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default();
|
||||
let mut best_expected_cursor: Option<usize> = None;
|
||||
let mut best_patch_idx: Option<usize> = None;
|
||||
let mut best_expected_text: Option<&str> = None;
|
||||
|
||||
for (idx, expected) in expected_texts.iter().enumerate() {
|
||||
let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text);
|
||||
if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
|
||||
best_delta_chr_f_metrics = delta_chr_f_metrics;
|
||||
best_patch_idx = Some(idx);
|
||||
best_expected_text = Some(expected);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(idx) = best_patch_idx {
|
||||
// Get the raw cursor offset from the expected patch (relative to hunk new text)
|
||||
let expected_cursor_in_patch = expected_patches_with_cursors
|
||||
.get(idx)
|
||||
.and_then(|(_, cursor)| *cursor);
|
||||
|
||||
// For Teacher prompts, we need to apply the patch to the editable region
|
||||
// to find where the hunk matched, then compute the actual cursor position
|
||||
if let (Some(editable_region), Some(cursor_in_patch)) =
|
||||
(&old_editable_region, expected_cursor_in_patch)
|
||||
{
|
||||
let (patch, _) = &expected_patches_with_cursors[idx];
|
||||
if let Ok((_, hunk_offset)) =
|
||||
apply_diff_to_string_with_hunk_offset(patch, editable_region)
|
||||
{
|
||||
let hunk_start = hunk_offset.unwrap_or(0);
|
||||
best_expected_cursor = Some(hunk_start + cursor_in_patch);
|
||||
}
|
||||
} else {
|
||||
// For non-Teacher prompts or if we can't compute, use raw offset
|
||||
best_expected_cursor = expected_cursor_in_patch;
|
||||
}
|
||||
}
|
||||
|
||||
let disbalance_before = metrics::braces_disbalance(&original_text);
|
||||
let disbalance_after = metrics::braces_disbalance(&actual_text);
|
||||
let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
|
||||
|
||||
// Compute exact lines match against best matching expected patch
|
||||
let best_exact_lines = expected_patches_with_cursors
|
||||
.iter()
|
||||
.map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
|
||||
.max_by_key(|m| m.true_positives)
|
||||
.unwrap_or_default();
|
||||
|
||||
// Compute reversal ratio
|
||||
let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
|
||||
prompt_inputs,
|
||||
&actual_text,
|
||||
cursor_path,
|
||||
);
|
||||
|
||||
// Compute cursor position metrics
|
||||
let (cursor_distance, cursor_exact_match) =
|
||||
compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref());
|
||||
|
||||
// Compute approximation of editable region correctness
|
||||
let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
|
||||
|
||||
// Check for isolated whitespace changes.
|
||||
let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes(
|
||||
&actual_patch,
|
||||
prediction.actual_cursor.as_ref(),
|
||||
);
|
||||
|
||||
let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) =
|
||||
best_expected_text
|
||||
.map(|reference_text| {
|
||||
let result =
|
||||
metrics::compute_kept_rate(original_text, &actual_text, reference_text);
|
||||
(
|
||||
Some(result.kept_rate),
|
||||
Some(result.recall_rate),
|
||||
Some(result.kept_chars),
|
||||
Some(result.correctly_deleted_chars),
|
||||
Some(result.discarded_chars),
|
||||
)
|
||||
})
|
||||
.unwrap_or((None, None, None, None, None));
|
||||
|
||||
scores.push(ExampleScore {
|
||||
delta_chr_f: best_delta_chr_f_metrics.score as f32,
|
||||
delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
|
||||
delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives,
|
||||
delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives,
|
||||
delta_chr_f_precision: best_delta_chr_f_metrics.precision,
|
||||
delta_chr_f_recall: best_delta_chr_f_metrics.recall,
|
||||
delta_chr_f_beta: best_delta_chr_f_metrics.beta,
|
||||
braces_disbalance,
|
||||
exact_lines_tp: best_exact_lines.true_positives,
|
||||
exact_lines_fp: best_exact_lines.false_positives,
|
||||
exact_lines_fn: best_exact_lines.false_negatives,
|
||||
reversal_ratio,
|
||||
cursor_distance,
|
||||
cursor_exact_match,
|
||||
wrong_editable_region,
|
||||
has_isolated_whitespace_changes,
|
||||
inserted_tokens: token_changes.inserted_tokens,
|
||||
deleted_tokens: token_changes.deleted_tokens,
|
||||
kept_rate,
|
||||
recall_rate,
|
||||
kept_chars,
|
||||
correctly_deleted_chars,
|
||||
discarded_chars,
|
||||
cumulative_logprob: prediction.cumulative_logprob,
|
||||
avg_logprob: prediction.avg_logprob,
|
||||
});
|
||||
scores.push(edit_prediction_metrics::score_prediction(
|
||||
PredictionScoringInput {
|
||||
original_text,
|
||||
expected_patches: &prepared_expected_patches,
|
||||
actual_patch: actual_patch.as_deref(),
|
||||
actual_cursor,
|
||||
reversal_context: Some(PredictionReversalContext {
|
||||
edit_history: &prompt_inputs.events,
|
||||
excerpt_start_row: prompt_inputs.excerpt_start_row,
|
||||
cursor_path,
|
||||
}),
|
||||
cumulative_logprob: prediction.cumulative_logprob,
|
||||
avg_logprob: prediction.avg_logprob,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
example.score = scores;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_cursor_metrics(
|
||||
expected_cursor_editable_region_offset: Option<usize>,
|
||||
actual_cursor: Option<&ActualCursor>,
|
||||
) -> (Option<usize>, Option<bool>) {
|
||||
match (expected_cursor_editable_region_offset, actual_cursor) {
|
||||
(Some(expected), Some(actual)) => {
|
||||
let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
|
||||
let exact_match = distance == 0;
|
||||
(Some(distance), Some(exact_match))
|
||||
}
|
||||
(None, None) => {
|
||||
// Neither has cursor position - skip cursor scoring
|
||||
(None, None)
|
||||
}
|
||||
(Some(_), None) | (None, Some(_)) => {
|
||||
// Only one has cursor position - count as miss
|
||||
(None, Some(false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_report(examples: &[Example], verbose: bool) {
|
||||
const MAX_EXAMPLES_DEFAULT: usize = 20;
|
||||
use crate::metrics::ClassificationMetrics;
|
||||
|
|
@ -633,286 +473,27 @@ fn truncate_name(name: &str, max_len: usize) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SummaryJson {
|
||||
pub total_examples: usize,
|
||||
pub avg_delta_chr_f: f32,
|
||||
pub delta_chr_f_beta: f64,
|
||||
pub delta_chr_f_true_positives: usize,
|
||||
pub delta_chr_f_false_positives: usize,
|
||||
pub delta_chr_f_false_negatives: usize,
|
||||
pub delta_chr_f_precision: f64,
|
||||
pub delta_chr_f_recall: f64,
|
||||
pub avg_braces_disbalance: f32,
|
||||
pub exact_lines_true_positives: usize,
|
||||
pub exact_lines_false_positives: usize,
|
||||
pub exact_lines_false_negatives: usize,
|
||||
pub exact_lines_precision: f64,
|
||||
pub exact_lines_recall: f64,
|
||||
pub exact_lines_f1: f64,
|
||||
pub avg_reversal_ratio: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub qa_avg_reverts_edits: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub qa_avg_confidence: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_exact_match_rate: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_avg_distance: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_total_evaluated: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub wrong_editable_region_rate: Option<f32>,
|
||||
pub isolated_whitespace_rate: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avg_kept_rate: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avg_recall_rate: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_kept_chars: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_correctly_deleted_chars: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_discarded_chars: Option<usize>,
|
||||
}
|
||||
pub type SummaryJson = edit_prediction_metrics::SummaryJson;
|
||||
|
||||
pub fn compute_summary(examples: &[Example]) -> SummaryJson {
|
||||
use crate::metrics::ClassificationMetrics;
|
||||
edit_prediction_metrics::compute_summary(examples.iter().flat_map(|example| {
|
||||
example
|
||||
.score
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(move |(score_idx, score)| {
|
||||
let qa = example
|
||||
.qa
|
||||
.get(score_idx)
|
||||
.and_then(|qa| qa.as_ref())
|
||||
.map(|qa| edit_prediction_metrics::QaSummaryData {
|
||||
reverts_edits: qa.reverts_edits,
|
||||
confidence: qa.confidence,
|
||||
});
|
||||
|
||||
let mut all_delta_chr_f_scores = Vec::new();
|
||||
let mut all_reversal_ratios = Vec::new();
|
||||
let mut braces_disbalance_sum: usize = 0;
|
||||
let mut total_delta_chr_f = ClassificationMetrics::default();
|
||||
let mut total_delta_chr_f_precision = 0.0;
|
||||
let mut total_delta_chr_f_recall = 0.0;
|
||||
let mut delta_chr_f_beta = 0.0;
|
||||
let mut total_exact_lines = ClassificationMetrics::default();
|
||||
let mut total_scores: usize = 0;
|
||||
let mut qa_reverts_count: usize = 0;
|
||||
let mut qa_reverts_total: usize = 0;
|
||||
let mut qa_confidence_sum: u64 = 0;
|
||||
let mut qa_confidence_count: usize = 0;
|
||||
let mut cursor_exact_matches: usize = 0;
|
||||
let mut cursor_total: usize = 0;
|
||||
let mut cursor_distance_sum: usize = 0;
|
||||
let mut cursor_distance_count: usize = 0;
|
||||
let mut wrong_editable_region_count: usize = 0;
|
||||
let mut wrong_editable_region_total: usize = 0;
|
||||
let mut isolated_whitespace_count: usize = 0;
|
||||
let mut kept_rate_sum: f64 = 0.0;
|
||||
let mut kept_rate_count: usize = 0;
|
||||
let mut kept_chars_total: usize = 0;
|
||||
let mut kept_chars_count: usize = 0;
|
||||
let mut correctly_deleted_chars_total: usize = 0;
|
||||
let mut correctly_deleted_chars_count: usize = 0;
|
||||
let mut discarded_chars_total: usize = 0;
|
||||
let mut discarded_chars_count: usize = 0;
|
||||
let mut recall_rate_sum: f64 = 0.0;
|
||||
let mut recall_rate_count: usize = 0;
|
||||
|
||||
for example in examples {
|
||||
for (score_idx, score) in example.score.iter().enumerate() {
|
||||
all_delta_chr_f_scores.push(score.delta_chr_f);
|
||||
all_reversal_ratios.push(score.reversal_ratio);
|
||||
total_scores += 1;
|
||||
braces_disbalance_sum += score.braces_disbalance;
|
||||
total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
|
||||
total_delta_chr_f_precision += score.delta_chr_f_precision;
|
||||
total_delta_chr_f_recall += score.delta_chr_f_recall;
|
||||
delta_chr_f_beta = score.delta_chr_f_beta;
|
||||
total_exact_lines.accumulate(&score.exact_lines_counts());
|
||||
|
||||
// Accumulate QA metrics
|
||||
if let Some(Some(qa)) = example.qa.get(score_idx) {
|
||||
if let Some(reverts) = qa.reverts_edits {
|
||||
qa_reverts_total += 1;
|
||||
if reverts {
|
||||
qa_reverts_count += 1;
|
||||
}
|
||||
}
|
||||
if let Some(conf) = qa.confidence {
|
||||
qa_confidence_sum += conf as u64;
|
||||
qa_confidence_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate wrong editable region metrics
|
||||
if let Some(wrong) = score.wrong_editable_region {
|
||||
wrong_editable_region_total += 1;
|
||||
if wrong {
|
||||
wrong_editable_region_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate isolated whitespace metrics
|
||||
if score.has_isolated_whitespace_changes {
|
||||
isolated_whitespace_count += 1;
|
||||
}
|
||||
|
||||
// Accumulate kept and recall rate metrics
|
||||
if let Some(kr) = score.kept_rate {
|
||||
kept_rate_sum += kr;
|
||||
kept_rate_count += 1;
|
||||
}
|
||||
if let Some(kept_chars) = score.kept_chars {
|
||||
kept_chars_total += kept_chars;
|
||||
kept_chars_count += 1;
|
||||
}
|
||||
if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
|
||||
correctly_deleted_chars_total += correctly_deleted_chars;
|
||||
correctly_deleted_chars_count += 1;
|
||||
}
|
||||
if let Some(discarded_chars) = score.discarded_chars {
|
||||
discarded_chars_total += discarded_chars;
|
||||
discarded_chars_count += 1;
|
||||
}
|
||||
if let Some(rr) = score.recall_rate {
|
||||
recall_rate_sum += rr;
|
||||
recall_rate_count += 1;
|
||||
}
|
||||
|
||||
// Accumulate cursor metrics
|
||||
if let Some(exact_match) = score.cursor_exact_match {
|
||||
cursor_total += 1;
|
||||
if exact_match {
|
||||
cursor_exact_matches += 1;
|
||||
}
|
||||
}
|
||||
if let Some(dist) = score.cursor_distance {
|
||||
cursor_distance_sum += dist;
|
||||
cursor_distance_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
|
||||
};
|
||||
|
||||
let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
|
||||
};
|
||||
|
||||
let avg_braces_disbalance = if total_scores == 0 {
|
||||
0.0
|
||||
} else {
|
||||
braces_disbalance_sum as f32 / total_scores as f32
|
||||
};
|
||||
|
||||
let qa_avg_reverts_edits = if qa_reverts_total > 0 {
|
||||
Some(qa_reverts_count as f32 / qa_reverts_total as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let qa_avg_confidence = if qa_confidence_count > 0 {
|
||||
Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cursor_exact_match_rate = if cursor_total > 0 {
|
||||
Some(cursor_exact_matches as f32 / cursor_total as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cursor_avg_distance = if cursor_distance_count > 0 {
|
||||
Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cursor_total_evaluated = if cursor_total > 0 {
|
||||
Some(cursor_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
|
||||
Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let isolated_whitespace_rate = if total_scores > 0 {
|
||||
Some(isolated_whitespace_count as f32 / total_scores as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let avg_kept_rate = if kept_rate_count > 0 {
|
||||
Some(kept_rate_sum / kept_rate_count as f64)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let avg_recall_rate = if recall_rate_count > 0 {
|
||||
Some(recall_rate_sum / recall_rate_count as f64)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let total_kept_chars = if kept_chars_count > 0 {
|
||||
Some(kept_chars_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 {
|
||||
Some(correctly_deleted_chars_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let total_discarded_chars = if discarded_chars_count > 0 {
|
||||
Some(discarded_chars_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
SummaryJson {
|
||||
total_examples: total_scores,
|
||||
avg_delta_chr_f,
|
||||
delta_chr_f_beta,
|
||||
delta_chr_f_true_positives: total_delta_chr_f.true_positives,
|
||||
delta_chr_f_false_positives: total_delta_chr_f.false_positives,
|
||||
delta_chr_f_false_negatives: total_delta_chr_f.false_negatives,
|
||||
delta_chr_f_precision: if total_scores == 0 {
|
||||
0.0
|
||||
} else {
|
||||
total_delta_chr_f_precision / total_scores as f64
|
||||
},
|
||||
delta_chr_f_recall: if total_scores == 0 {
|
||||
0.0
|
||||
} else {
|
||||
total_delta_chr_f_recall / total_scores as f64
|
||||
},
|
||||
avg_braces_disbalance,
|
||||
exact_lines_true_positives: total_exact_lines.true_positives,
|
||||
exact_lines_false_positives: total_exact_lines.false_positives,
|
||||
exact_lines_false_negatives: total_exact_lines.false_negatives,
|
||||
exact_lines_precision: total_exact_lines.precision(),
|
||||
exact_lines_recall: total_exact_lines.recall(),
|
||||
exact_lines_f1: total_exact_lines.f1(),
|
||||
avg_reversal_ratio,
|
||||
qa_avg_reverts_edits,
|
||||
qa_avg_confidence,
|
||||
cursor_exact_match_rate,
|
||||
cursor_avg_distance,
|
||||
cursor_total_evaluated,
|
||||
wrong_editable_region_rate,
|
||||
isolated_whitespace_rate,
|
||||
avg_kept_rate,
|
||||
avg_recall_rate,
|
||||
total_kept_chars,
|
||||
total_correctly_deleted_chars,
|
||||
total_discarded_chars,
|
||||
}
|
||||
edit_prediction_metrics::PredictionSummaryInput { score, qa }
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
mod kept_rate;
|
||||
mod patch_metrics;
|
||||
mod prediction_score;
|
||||
mod reversal;
|
||||
mod summary;
|
||||
mod tokenize;
|
||||
mod tree_sitter;
|
||||
|
||||
|
|
@ -22,5 +24,10 @@ pub use patch_metrics::extract_changed_lines_from_diff;
|
|||
pub use patch_metrics::has_isolated_whitespace_changes;
|
||||
pub use patch_metrics::is_editable_region_correct;
|
||||
pub use patch_metrics::reconstruct_texts_from_diff;
|
||||
pub use prediction_score::{
|
||||
ActualPredictionCursor, PredictionReversalContext, PredictionScore, PredictionScoringInput,
|
||||
PrepareExpectedPatchError, PreparedExpectedPatch, prepare_expected_patches, score_prediction,
|
||||
};
|
||||
pub use reversal::compute_prediction_reversal_ratio_from_history;
|
||||
pub use summary::{PredictionSummaryInput, QaSummaryData, SummaryJson, compute_summary};
|
||||
pub use tree_sitter::count_tree_sitter_errors;
|
||||
|
|
|
|||
319
crates/edit_prediction_metrics/src/prediction_score.rs
Normal file
319
crates/edit_prediction_metrics/src/prediction_score.rs
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use zeta_prompt::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
|
||||
|
||||
use crate::patch_metrics::{
|
||||
ClassificationMetrics, DeltaChrFMetrics, braces_disbalance, count_patch_token_changes,
|
||||
delta_chr_f, delta_chr_f_beta, exact_lines_match, has_isolated_whitespace_changes,
|
||||
is_editable_region_correct,
|
||||
};
|
||||
use crate::reversal::compute_prediction_reversal_ratio_from_history;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PredictionScore {
|
||||
pub delta_chr_f: f32,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_true_positives: usize,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_false_positives: usize,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_false_negatives: usize,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_precision: f64,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_recall: f64,
|
||||
#[serde(default)]
|
||||
pub delta_chr_f_beta: f64,
|
||||
pub braces_disbalance: usize,
|
||||
#[serde(default)]
|
||||
pub exact_lines_tp: usize,
|
||||
#[serde(default)]
|
||||
pub exact_lines_fp: usize,
|
||||
#[serde(default)]
|
||||
pub exact_lines_fn: usize,
|
||||
#[serde(default)]
|
||||
pub reversal_ratio: f32,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_distance: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_exact_match: Option<bool>,
|
||||
pub wrong_editable_region: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub has_isolated_whitespace_changes: bool,
|
||||
#[serde(default)]
|
||||
pub inserted_tokens: usize,
|
||||
#[serde(default)]
|
||||
pub deleted_tokens: usize,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub kept_rate: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub recall_rate: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub kept_chars: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub correctly_deleted_chars: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub discarded_chars: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cumulative_logprob: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub avg_logprob: Option<f64>,
|
||||
}
|
||||
|
||||
impl PredictionScore {
|
||||
pub fn zero() -> Self {
|
||||
Self {
|
||||
delta_chr_f: 0.0,
|
||||
delta_chr_f_true_positives: 0,
|
||||
delta_chr_f_false_positives: 0,
|
||||
delta_chr_f_false_negatives: 0,
|
||||
delta_chr_f_precision: 0.0,
|
||||
delta_chr_f_recall: 0.0,
|
||||
delta_chr_f_beta: delta_chr_f_beta(),
|
||||
braces_disbalance: 0,
|
||||
exact_lines_tp: 0,
|
||||
exact_lines_fp: 0,
|
||||
exact_lines_fn: 0,
|
||||
reversal_ratio: 0.0,
|
||||
cursor_distance: None,
|
||||
cursor_exact_match: None,
|
||||
wrong_editable_region: None,
|
||||
has_isolated_whitespace_changes: false,
|
||||
inserted_tokens: 0,
|
||||
deleted_tokens: 0,
|
||||
kept_rate: None,
|
||||
recall_rate: None,
|
||||
kept_chars: None,
|
||||
correctly_deleted_chars: None,
|
||||
discarded_chars: None,
|
||||
cumulative_logprob: None,
|
||||
avg_logprob: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delta_chr_f_counts(&self) -> ClassificationMetrics {
|
||||
ClassificationMetrics {
|
||||
true_positives: self.delta_chr_f_true_positives,
|
||||
false_positives: self.delta_chr_f_false_positives,
|
||||
false_negatives: self.delta_chr_f_false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exact_lines_counts(&self) -> ClassificationMetrics {
|
||||
ClassificationMetrics {
|
||||
true_positives: self.exact_lines_tp,
|
||||
false_positives: self.exact_lines_fp,
|
||||
false_negatives: self.exact_lines_fn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PredictionScore {
|
||||
fn default() -> Self {
|
||||
Self::zero()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PreparedExpectedPatch {
|
||||
pub patch: String,
|
||||
pub text: String,
|
||||
pub cursor_editable_region_offset: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PrepareExpectedPatchError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for PrepareExpectedPatchError {
|
||||
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.message.fmt(formatter)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for PrepareExpectedPatchError {}
|
||||
|
||||
pub fn prepare_expected_patches(
|
||||
expected_patches_with_cursors: &[(String, Option<usize>)],
|
||||
original_text: &str,
|
||||
old_editable_region: Option<&str>,
|
||||
) -> Result<Vec<PreparedExpectedPatch>, PrepareExpectedPatchError> {
|
||||
expected_patches_with_cursors
|
||||
.iter()
|
||||
.map(|(patch, cursor_in_patch)| {
|
||||
let text = apply_diff_to_string(patch, original_text).map_err(|error| {
|
||||
PrepareExpectedPatchError {
|
||||
message: error.to_string(),
|
||||
}
|
||||
})?;
|
||||
let cursor_editable_region_offset =
|
||||
if let (Some(editable_region), Some(cursor_in_patch)) =
|
||||
(old_editable_region, *cursor_in_patch)
|
||||
{
|
||||
match apply_diff_to_string_with_hunk_offset(patch, editable_region) {
|
||||
Ok((_, hunk_offset)) => Some(hunk_offset.unwrap_or(0) + cursor_in_patch),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
*cursor_in_patch
|
||||
};
|
||||
|
||||
Ok(PreparedExpectedPatch {
|
||||
patch: patch.clone(),
|
||||
text,
|
||||
cursor_editable_region_offset,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ActualPredictionCursor {
|
||||
pub row: u32,
|
||||
pub editable_region_offset: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct PredictionReversalContext<'a> {
|
||||
pub edit_history: &'a [Arc<zeta_prompt::Event>],
|
||||
pub excerpt_start_row: Option<u32>,
|
||||
pub cursor_path: &'a Path,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct PredictionScoringInput<'a> {
|
||||
pub original_text: &'a str,
|
||||
pub expected_patches: &'a [PreparedExpectedPatch],
|
||||
pub actual_patch: Option<&'a str>,
|
||||
pub actual_cursor: Option<ActualPredictionCursor>,
|
||||
pub reversal_context: Option<PredictionReversalContext<'a>>,
|
||||
pub cumulative_logprob: Option<f64>,
|
||||
pub avg_logprob: Option<f64>,
|
||||
}
|
||||
|
||||
pub fn score_prediction(input: PredictionScoringInput<'_>) -> PredictionScore {
|
||||
let Some(actual_patch) = input.actual_patch else {
|
||||
return PredictionScore::zero();
|
||||
};
|
||||
|
||||
let token_changes = count_patch_token_changes(actual_patch);
|
||||
|
||||
let actual_text = match apply_diff_to_string(actual_patch, input.original_text) {
|
||||
Ok(text) => text,
|
||||
Err(_) => {
|
||||
let mut score = PredictionScore::zero();
|
||||
score.inserted_tokens = token_changes.inserted_tokens;
|
||||
score.deleted_tokens = token_changes.deleted_tokens;
|
||||
return score;
|
||||
}
|
||||
};
|
||||
|
||||
let mut best_delta_chr_f_metrics = DeltaChrFMetrics::default();
|
||||
let mut best_expected_cursor = None;
|
||||
let mut best_expected_text = None;
|
||||
|
||||
for expected in input.expected_patches {
|
||||
let delta_chr_f_metrics = delta_chr_f(input.original_text, &expected.text, &actual_text);
|
||||
if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
|
||||
best_delta_chr_f_metrics = delta_chr_f_metrics;
|
||||
best_expected_cursor = expected.cursor_editable_region_offset;
|
||||
best_expected_text = Some(expected.text.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
let disbalance_before = braces_disbalance(input.original_text);
|
||||
let disbalance_after = braces_disbalance(&actual_text);
|
||||
let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
|
||||
|
||||
let best_exact_lines = input
|
||||
.expected_patches
|
||||
.iter()
|
||||
.map(|expected| exact_lines_match(&expected.patch, actual_patch))
|
||||
.max_by_key(|metrics| metrics.true_positives)
|
||||
.unwrap_or_default();
|
||||
|
||||
let reversal_ratio = input
|
||||
.reversal_context
|
||||
.map(|context| {
|
||||
compute_prediction_reversal_ratio_from_history(
|
||||
input.original_text,
|
||||
context.edit_history,
|
||||
context.excerpt_start_row,
|
||||
&actual_text,
|
||||
context.cursor_path,
|
||||
)
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let (cursor_distance, cursor_exact_match) =
|
||||
compute_cursor_metrics(best_expected_cursor, input.actual_cursor);
|
||||
|
||||
let wrong_editable_region = Some(!is_editable_region_correct(actual_patch));
|
||||
let has_isolated_whitespace_changes =
|
||||
has_isolated_whitespace_changes(actual_patch, input.actual_cursor.map(|cursor| cursor.row));
|
||||
|
||||
let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) =
|
||||
best_expected_text
|
||||
.map(|reference_text| {
|
||||
let result = crate::kept_rate::compute_kept_rate(
|
||||
input.original_text,
|
||||
&actual_text,
|
||||
reference_text,
|
||||
);
|
||||
(
|
||||
Some(result.kept_rate),
|
||||
Some(result.recall_rate),
|
||||
Some(result.kept_chars),
|
||||
Some(result.correctly_deleted_chars),
|
||||
Some(result.discarded_chars),
|
||||
)
|
||||
})
|
||||
.unwrap_or((None, None, None, None, None));
|
||||
|
||||
PredictionScore {
|
||||
delta_chr_f: best_delta_chr_f_metrics.score as f32,
|
||||
delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
|
||||
delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives,
|
||||
delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives,
|
||||
delta_chr_f_precision: best_delta_chr_f_metrics.precision,
|
||||
delta_chr_f_recall: best_delta_chr_f_metrics.recall,
|
||||
delta_chr_f_beta: best_delta_chr_f_metrics.beta,
|
||||
braces_disbalance,
|
||||
exact_lines_tp: best_exact_lines.true_positives,
|
||||
exact_lines_fp: best_exact_lines.false_positives,
|
||||
exact_lines_fn: best_exact_lines.false_negatives,
|
||||
reversal_ratio,
|
||||
cursor_distance,
|
||||
cursor_exact_match,
|
||||
wrong_editable_region,
|
||||
has_isolated_whitespace_changes,
|
||||
inserted_tokens: token_changes.inserted_tokens,
|
||||
deleted_tokens: token_changes.deleted_tokens,
|
||||
kept_rate,
|
||||
recall_rate,
|
||||
kept_chars,
|
||||
correctly_deleted_chars,
|
||||
discarded_chars,
|
||||
cumulative_logprob: input.cumulative_logprob,
|
||||
avg_logprob: input.avg_logprob,
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_cursor_metrics(
|
||||
expected_cursor_editable_region_offset: Option<usize>,
|
||||
actual_cursor: Option<ActualPredictionCursor>,
|
||||
) -> (Option<usize>, Option<bool>) {
|
||||
match (expected_cursor_editable_region_offset, actual_cursor) {
|
||||
(Some(expected), Some(actual)) => {
|
||||
let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
|
||||
let exact_match = distance == 0;
|
||||
(Some(distance), Some(exact_match))
|
||||
}
|
||||
(None, None) => (None, None),
|
||||
(Some(_), None) | (None, Some(_)) => (None, Some(false)),
|
||||
}
|
||||
}
|
||||
293
crates/edit_prediction_metrics/src/summary.rs
Normal file
293
crates/edit_prediction_metrics/src/summary.rs
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
use serde::Serialize;
|
||||
|
||||
use crate::patch_metrics::ClassificationMetrics;
|
||||
use crate::prediction_score::PredictionScore;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct QaSummaryData {
|
||||
pub reverts_edits: Option<bool>,
|
||||
pub confidence: Option<u8>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct PredictionSummaryInput<'a> {
|
||||
pub score: &'a PredictionScore,
|
||||
pub qa: Option<QaSummaryData>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct SummaryJson {
|
||||
pub total_examples: usize,
|
||||
pub avg_delta_chr_f: f32,
|
||||
pub delta_chr_f_beta: f64,
|
||||
pub delta_chr_f_true_positives: usize,
|
||||
pub delta_chr_f_false_positives: usize,
|
||||
pub delta_chr_f_false_negatives: usize,
|
||||
pub delta_chr_f_precision: f64,
|
||||
pub delta_chr_f_recall: f64,
|
||||
pub avg_braces_disbalance: f32,
|
||||
pub exact_lines_true_positives: usize,
|
||||
pub exact_lines_false_positives: usize,
|
||||
pub exact_lines_false_negatives: usize,
|
||||
pub exact_lines_precision: f64,
|
||||
pub exact_lines_recall: f64,
|
||||
pub exact_lines_f1: f64,
|
||||
pub avg_reversal_ratio: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub qa_avg_reverts_edits: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub qa_avg_confidence: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_exact_match_rate: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_avg_distance: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cursor_total_evaluated: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub wrong_editable_region_rate: Option<f32>,
|
||||
pub isolated_whitespace_rate: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avg_kept_rate: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub avg_recall_rate: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_kept_chars: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_correctly_deleted_chars: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_discarded_chars: Option<usize>,
|
||||
}
|
||||
|
||||
pub fn compute_summary<'a>(
|
||||
predictions: impl IntoIterator<Item = PredictionSummaryInput<'a>>,
|
||||
) -> SummaryJson {
|
||||
let mut all_delta_chr_f_scores = Vec::new();
|
||||
let mut all_reversal_ratios = Vec::new();
|
||||
let mut braces_disbalance_sum: usize = 0;
|
||||
let mut total_delta_chr_f = ClassificationMetrics::default();
|
||||
let mut total_delta_chr_f_precision = 0.0;
|
||||
let mut total_delta_chr_f_recall = 0.0;
|
||||
let mut delta_chr_f_beta = 0.0;
|
||||
let mut total_exact_lines = ClassificationMetrics::default();
|
||||
let mut total_scores: usize = 0;
|
||||
let mut qa_reverts_count: usize = 0;
|
||||
let mut qa_reverts_total: usize = 0;
|
||||
let mut qa_confidence_sum: u64 = 0;
|
||||
let mut qa_confidence_count: usize = 0;
|
||||
let mut cursor_exact_matches: usize = 0;
|
||||
let mut cursor_total: usize = 0;
|
||||
let mut cursor_distance_sum: usize = 0;
|
||||
let mut cursor_distance_count: usize = 0;
|
||||
let mut wrong_editable_region_count: usize = 0;
|
||||
let mut wrong_editable_region_total: usize = 0;
|
||||
let mut isolated_whitespace_count: usize = 0;
|
||||
let mut kept_rate_sum: f64 = 0.0;
|
||||
let mut kept_rate_count: usize = 0;
|
||||
let mut kept_chars_total: usize = 0;
|
||||
let mut kept_chars_count: usize = 0;
|
||||
let mut correctly_deleted_chars_total: usize = 0;
|
||||
let mut correctly_deleted_chars_count: usize = 0;
|
||||
let mut discarded_chars_total: usize = 0;
|
||||
let mut discarded_chars_count: usize = 0;
|
||||
let mut recall_rate_sum: f64 = 0.0;
|
||||
let mut recall_rate_count: usize = 0;
|
||||
|
||||
for prediction in predictions {
|
||||
let score = prediction.score;
|
||||
|
||||
all_delta_chr_f_scores.push(score.delta_chr_f);
|
||||
all_reversal_ratios.push(score.reversal_ratio);
|
||||
total_scores += 1;
|
||||
braces_disbalance_sum += score.braces_disbalance;
|
||||
total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
|
||||
total_delta_chr_f_precision += score.delta_chr_f_precision;
|
||||
total_delta_chr_f_recall += score.delta_chr_f_recall;
|
||||
delta_chr_f_beta = score.delta_chr_f_beta;
|
||||
total_exact_lines.accumulate(&score.exact_lines_counts());
|
||||
|
||||
if let Some(qa) = prediction.qa {
|
||||
if let Some(reverts) = qa.reverts_edits {
|
||||
qa_reverts_total += 1;
|
||||
if reverts {
|
||||
qa_reverts_count += 1;
|
||||
}
|
||||
}
|
||||
if let Some(confidence) = qa.confidence {
|
||||
qa_confidence_sum += confidence as u64;
|
||||
qa_confidence_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(wrong) = score.wrong_editable_region {
|
||||
wrong_editable_region_total += 1;
|
||||
if wrong {
|
||||
wrong_editable_region_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if score.has_isolated_whitespace_changes {
|
||||
isolated_whitespace_count += 1;
|
||||
}
|
||||
|
||||
if let Some(kept_rate) = score.kept_rate {
|
||||
kept_rate_sum += kept_rate;
|
||||
kept_rate_count += 1;
|
||||
}
|
||||
if let Some(kept_chars) = score.kept_chars {
|
||||
kept_chars_total += kept_chars;
|
||||
kept_chars_count += 1;
|
||||
}
|
||||
if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
|
||||
correctly_deleted_chars_total += correctly_deleted_chars;
|
||||
correctly_deleted_chars_count += 1;
|
||||
}
|
||||
if let Some(discarded_chars) = score.discarded_chars {
|
||||
discarded_chars_total += discarded_chars;
|
||||
discarded_chars_count += 1;
|
||||
}
|
||||
if let Some(recall_rate) = score.recall_rate {
|
||||
recall_rate_sum += recall_rate;
|
||||
recall_rate_count += 1;
|
||||
}
|
||||
|
||||
if let Some(exact_match) = score.cursor_exact_match {
|
||||
cursor_total += 1;
|
||||
if exact_match {
|
||||
cursor_exact_matches += 1;
|
||||
}
|
||||
}
|
||||
if let Some(distance) = score.cursor_distance {
|
||||
cursor_distance_sum += distance;
|
||||
cursor_distance_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
|
||||
};
|
||||
|
||||
let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
|
||||
};
|
||||
|
||||
let avg_braces_disbalance = if total_scores == 0 {
|
||||
0.0
|
||||
} else {
|
||||
braces_disbalance_sum as f32 / total_scores as f32
|
||||
};
|
||||
|
||||
let qa_avg_reverts_edits = if qa_reverts_total > 0 {
|
||||
Some(qa_reverts_count as f32 / qa_reverts_total as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let qa_avg_confidence = if qa_confidence_count > 0 {
|
||||
Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cursor_exact_match_rate = if cursor_total > 0 {
|
||||
Some(cursor_exact_matches as f32 / cursor_total as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cursor_avg_distance = if cursor_distance_count > 0 {
|
||||
Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cursor_total_evaluated = if cursor_total > 0 {
|
||||
Some(cursor_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
|
||||
Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let isolated_whitespace_rate = if total_scores > 0 {
|
||||
Some(isolated_whitespace_count as f32 / total_scores as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let avg_kept_rate = if kept_rate_count > 0 {
|
||||
Some(kept_rate_sum / kept_rate_count as f64)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let avg_recall_rate = if recall_rate_count > 0 {
|
||||
Some(recall_rate_sum / recall_rate_count as f64)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let total_kept_chars = if kept_chars_count > 0 {
|
||||
Some(kept_chars_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 {
|
||||
Some(correctly_deleted_chars_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let total_discarded_chars = if discarded_chars_count > 0 {
|
||||
Some(discarded_chars_total)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
SummaryJson {
|
||||
total_examples: total_scores,
|
||||
avg_delta_chr_f,
|
||||
delta_chr_f_beta,
|
||||
delta_chr_f_true_positives: total_delta_chr_f.true_positives,
|
||||
delta_chr_f_false_positives: total_delta_chr_f.false_positives,
|
||||
delta_chr_f_false_negatives: total_delta_chr_f.false_negatives,
|
||||
delta_chr_f_precision: if total_scores == 0 {
|
||||
0.0
|
||||
} else {
|
||||
total_delta_chr_f_precision / total_scores as f64
|
||||
},
|
||||
delta_chr_f_recall: if total_scores == 0 {
|
||||
0.0
|
||||
} else {
|
||||
total_delta_chr_f_recall / total_scores as f64
|
||||
},
|
||||
avg_braces_disbalance,
|
||||
exact_lines_true_positives: total_exact_lines.true_positives,
|
||||
exact_lines_false_positives: total_exact_lines.false_positives,
|
||||
exact_lines_false_negatives: total_exact_lines.false_negatives,
|
||||
exact_lines_precision: total_exact_lines.precision(),
|
||||
exact_lines_recall: total_exact_lines.recall(),
|
||||
exact_lines_f1: total_exact_lines.f1(),
|
||||
avg_reversal_ratio,
|
||||
qa_avg_reverts_edits,
|
||||
qa_avg_confidence,
|
||||
cursor_exact_match_rate,
|
||||
cursor_avg_distance,
|
||||
cursor_total_evaluated,
|
||||
wrong_editable_region_rate,
|
||||
isolated_whitespace_rate,
|
||||
avg_kept_rate,
|
||||
avg_recall_rate,
|
||||
total_kept_chars,
|
||||
total_correctly_deleted_chars,
|
||||
total_discarded_chars,
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue