ep: Add jitter to cursor position (#57597)

When generating a training or evaluation example with `ep split-commit`,
the cursor sampling logic becomes:

1. 80% chance of cursor being at the end of the source patch
2. 20% chance of cursor being at the beginning of the target patch
3. Independently of the choice above, 20% chance of adding a jitter offset (same line, ±5 columns for now)

Release Notes:

- N/A
This commit is contained in:
Oleksiy Syvokon 2026-05-24 19:43:53 +03:00 committed by GitHub
parent eb2223c080
commit 8bfe32010c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -27,7 +27,7 @@ use clap::Args;
use edit_prediction::example_spec::ExampleSpec; use edit_prediction::example_spec::ExampleSpec;
use rand::Rng; use rand::Rng;
use rand::SeedableRng; use rand::SeedableRng;
use serde::{Deserialize, Serialize}; use serde::Deserialize;
use similar::{DiffTag, TextDiff}; use similar::{DiffTag, TextDiff};
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::fs; use std::fs;
@ -74,11 +74,12 @@ pub struct AnnotatedCommit {
} }
/// Cursor position in a file. /// Cursor position in a file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CursorPosition { pub struct CursorPosition {
pub file: String, pub file: String,
pub line: usize, pub line: usize,
pub column: usize, pub column: usize,
pub line_length: usize,
} }
impl std::fmt::Display for CursorPosition { impl std::fmt::Display for CursorPosition {
@ -331,7 +332,7 @@ pub fn generate_evaluation_example_from_ordered_commit(
// Sample cursor position // Sample cursor position
let cursor = match cursor_opt { let cursor = match cursor_opt {
Some(c) => c, Some(c) => c,
None => sample_cursor_position(&patch, &split_commit) None => sample_cursor_position(&split_commit, rng.as_mut())
.context("failed to sample cursor position")?, .context("failed to sample cursor position")?,
}; };
@ -343,7 +344,8 @@ pub fn generate_evaluation_example_from_ordered_commit(
) )
.context("failed to generate cursor excerpt")?; .context("failed to generate cursor excerpt")?;
// Handle edge case where split_point == 0 // Where the source patch is empty, there's not enough info to make a
// meaningful prediction
if split == 0 { if split == 0 {
split_commit.target_patch = String::new(); split_commit.target_patch = String::new();
} }
@ -754,15 +756,12 @@ pub fn imitate_human_edits(
} }
// Calculate cursor position // Calculate cursor position
let cursor = CursorPosition { let line = if is_replacement {
file: tgt_edit_loc.filename.clone(), src_edit_loc.as_ref().unwrap().source_line_number
line: if is_replacement { } else {
src_edit_loc.as_ref().unwrap().source_line_number tgt_edit_loc.target_line_number
} else {
tgt_edit_loc.target_line_number
},
column: new_src.len() + 1,
}; };
let column = new_src.len() + 1;
// Add remainder of source if similar enough to target remainder // Add remainder of source if similar enough to target remainder
let remainder_src: String = (last_old_end..src_tokens.len()) let remainder_src: String = (last_old_end..src_tokens.len())
@ -785,6 +784,13 @@ pub fn imitate_human_edits(
return no_change; return no_change;
} }
let cursor = CursorPosition {
file: tgt_edit_loc.filename.clone(),
line,
column: column.min(new_src.len()),
line_length: new_src.len(),
};
// Build new source patch with the intermediate line // Build new source patch with the intermediate line
let mut new_src_patch = src_patch; let mut new_src_patch = src_patch;
if is_replacement { if is_replacement {
@ -860,16 +866,17 @@ pub fn imitate_human_edits(
fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> { fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
let loc = locate_edited_line(patch, -1)?; let loc = locate_edited_line(patch, -1)?;
let (line, col) = match &loc.patch_line { let (line, column, line_length) = match &loc.patch_line {
PatchLine::Addition(content) => (loc.target_line_number, content.len()), PatchLine::Addition(content) => (loc.target_line_number, content.len(), content.len()),
PatchLine::Deletion(_) => (loc.target_line_number, 1), PatchLine::Deletion(_) => (loc.target_line_number, 1, 1),
_ => return None, _ => return None,
}; };
Some(CursorPosition { Some(CursorPosition {
file: loc.filename, file: loc.filename,
line, line,
column: col, column,
line_length,
}) })
} }
@ -878,7 +885,7 @@ fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
let loc = locate_edited_line(patch, 0)?; let loc = locate_edited_line(patch, 0)?;
let hunk = patch.hunks.get(loc.hunk_index)?; let hunk = patch.hunks.get(loc.hunk_index)?;
let column = if loc.line_index_within_hunk > 0 { let line_length = if loc.line_index_within_hunk > 0 {
if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) { if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
let content = match prev_line { let content = match prev_line {
PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s, PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
@ -893,32 +900,57 @@ fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
}; };
let line = loc.target_line_number.saturating_sub(1).max(1); let line = loc.target_line_number.saturating_sub(1).max(1);
let column = line_length.saturating_sub(1);
Some(CursorPosition { Some(CursorPosition {
file: loc.filename, file: loc.filename,
line, line,
column, column,
line_length,
}) })
} }
/// Sample cursor position according to the following rules: /// Sample cursor position according to the following rules:
/// 1. 50% chance of cursor being at the end of the source patch /// 1. 80% chance of cursor being at the end of the source patch
/// 2. 50% chance of cursor being at the beginning of the target patch /// 2. 20% chance of cursor being at the beginning of the target patch
pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> { /// 3. 20% chance of adding a jitter offset
// Try end of history first pub fn sample_cursor_position(
split_commit: &SplitCommit,
rng: &mut dyn rand::RngCore,
) -> Option<CursorPosition> {
// End of history
let src_patch = Patch::parse_unified_diff(&split_commit.source_patch); let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
if let Some(cursor) = locate_end_of_last_edit(&src_patch) { let src_cursor = locate_end_of_last_edit(&src_patch);
return Some(cursor);
}
// Try beginning of target // Beginning of target
let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch); let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) { let tgt_cursor = locate_beginning_of_first_edit(&tgt_patch);
return Some(cursor);
// Randomly pick a cursor position
let prefer_source = rng.random_bool(0.8);
let mut cursor = if prefer_source {
src_cursor.or(tgt_cursor)
} else {
tgt_cursor.or(src_cursor)
};
// Possibly add jitter
let should_jitter = rng.random_bool(0.2);
if should_jitter {
if let Some(cursor) = cursor.as_mut() {
let col_offset = rng.random_range(1..=5);
if rng.random_bool(0.5) {
cursor.column = cursor
.column
.saturating_add(col_offset)
.min(cursor.line_length);
} else {
cursor.column = cursor.column.saturating_sub(col_offset);
}
}
} }
// Fallback: use the original patch cursor
locate_end_of_last_edit(patch)
} }
/// Get cursor excerpt from the patches. /// Get cursor excerpt from the patches.
@ -1230,6 +1262,7 @@ Date: Mon Jan 1 00:00:00 2024
file: "src/main.rs".to_string(), file: "src/main.rs".to_string(),
line: 42, line: 42,
column: 10, column: 10,
line_length: 80,
}; };
assert_eq!(cursor.to_string(), "src/main.rs:42:10"); assert_eq!(cursor.to_string(), "src/main.rs:42:10");
} }
@ -1760,6 +1793,7 @@ index 123..456 789
file: "test.md".to_string(), file: "test.md".to_string(),
line: 1, line: 1,
column: 1, // Byte index 1 is inside '第' (bytes 0..3) column: 1, // Byte index 1 is inside '第' (bytes 0..3)
line_length: 80,
}; };
let source_patch = r#"--- a/test.md let source_patch = r#"--- a/test.md