ep: Add jitter to cursor position (#57597)

When generating a training or evaluation example with `ep split-commit`,
the cursor sampling logic becomes:

1. 80% chance of cursor being at the end of the source patch
2. 20% chance of cursor being at the beginning of the target patch
3. Independently of the choice above, 20% chance of adding a jitter offset (same line, ±5 columns for now)

Release Notes:

- N/A
This commit is contained in:
Oleksiy Syvokon 2026-05-24 19:43:53 +03:00 committed by GitHub
parent eb2223c080
commit 8bfe32010c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -27,7 +27,7 @@ use clap::Args;
use edit_prediction::example_spec::ExampleSpec;
use rand::Rng;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use similar::{DiffTag, TextDiff};
use std::collections::BTreeSet;
use std::fs;
@ -74,11 +74,12 @@ pub struct AnnotatedCommit {
}
/// Cursor position in a file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CursorPosition {
pub file: String,
pub line: usize,
pub column: usize,
pub line_length: usize,
}
impl std::fmt::Display for CursorPosition {
@ -331,7 +332,7 @@ pub fn generate_evaluation_example_from_ordered_commit(
// Sample cursor position
let cursor = match cursor_opt {
Some(c) => c,
None => sample_cursor_position(&patch, &split_commit)
None => sample_cursor_position(&split_commit, rng.as_mut())
.context("failed to sample cursor position")?,
};
@ -343,7 +344,8 @@ pub fn generate_evaluation_example_from_ordered_commit(
)
.context("failed to generate cursor excerpt")?;
// Handle edge case where split_point == 0
// Where the source patch is empty, there's not enough info to make a
// meaningful prediction
if split == 0 {
split_commit.target_patch = String::new();
}
@ -754,15 +756,12 @@ pub fn imitate_human_edits(
}
// Calculate cursor position
let cursor = CursorPosition {
file: tgt_edit_loc.filename.clone(),
line: if is_replacement {
src_edit_loc.as_ref().unwrap().source_line_number
} else {
tgt_edit_loc.target_line_number
},
column: new_src.len() + 1,
let line = if is_replacement {
src_edit_loc.as_ref().unwrap().source_line_number
} else {
tgt_edit_loc.target_line_number
};
let column = new_src.len() + 1;
// Add remainder of source if similar enough to target remainder
let remainder_src: String = (last_old_end..src_tokens.len())
@ -785,6 +784,13 @@ pub fn imitate_human_edits(
return no_change;
}
let cursor = CursorPosition {
file: tgt_edit_loc.filename.clone(),
line,
column: column.min(new_src.len()),
line_length: new_src.len(),
};
// Build new source patch with the intermediate line
let mut new_src_patch = src_patch;
if is_replacement {
@ -860,16 +866,17 @@ pub fn imitate_human_edits(
fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
let loc = locate_edited_line(patch, -1)?;
let (line, col) = match &loc.patch_line {
PatchLine::Addition(content) => (loc.target_line_number, content.len()),
PatchLine::Deletion(_) => (loc.target_line_number, 1),
let (line, column, line_length) = match &loc.patch_line {
PatchLine::Addition(content) => (loc.target_line_number, content.len(), content.len()),
PatchLine::Deletion(_) => (loc.target_line_number, 1, 1),
_ => return None,
};
Some(CursorPosition {
file: loc.filename,
line,
column: col,
column,
line_length,
})
}
@ -878,7 +885,7 @@ fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
let loc = locate_edited_line(patch, 0)?;
let hunk = patch.hunks.get(loc.hunk_index)?;
let column = if loc.line_index_within_hunk > 0 {
let line_length = if loc.line_index_within_hunk > 0 {
if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
let content = match prev_line {
PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
@ -893,32 +900,57 @@ fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
};
let line = loc.target_line_number.saturating_sub(1).max(1);
let column = line_length.saturating_sub(1);
Some(CursorPosition {
file: loc.filename,
line,
column,
line_length,
})
}
/// Sample cursor position according to the following rules:
/// 1. 50% chance of cursor being at the end of the source patch
/// 2. 50% chance of cursor being at the beginning of the target patch
pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
// Try end of history first
/// 1. 80% chance of cursor being at the end of the source patch
/// 2. 20% chance of cursor being at the beginning of the target patch
/// 3. Independently of the choice above, 20% chance of adding a jitter offset
///    (same line; column shifted by 1 to 5 in either direction)
pub fn sample_cursor_position(
split_commit: &SplitCommit,
rng: &mut dyn rand::RngCore,
) -> Option<CursorPosition> {
// End of history
let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
return Some(cursor);
}
let src_cursor = locate_end_of_last_edit(&src_patch);
// Try beginning of target
// Beginning of target
let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
return Some(cursor);
let tgt_cursor = locate_beginning_of_first_edit(&tgt_patch);
// Randomly pick a cursor position
let prefer_source = rng.random_bool(0.8);
let mut cursor = if prefer_source {
src_cursor.or(tgt_cursor)
} else {
tgt_cursor.or(src_cursor)
};
// Possibly add jitter
let should_jitter = rng.random_bool(0.2);
if should_jitter {
if let Some(cursor) = cursor.as_mut() {
let col_offset = rng.random_range(1..=5);
if rng.random_bool(0.5) {
cursor.column = cursor
.column
.saturating_add(col_offset)
.min(cursor.line_length);
} else {
cursor.column = cursor.column.saturating_sub(col_offset);
}
}
}
// Fallback: use the original patch
locate_end_of_last_edit(patch)
cursor
}
/// Get cursor excerpt from the patches.
@ -1230,6 +1262,7 @@ Date: Mon Jan 1 00:00:00 2024
file: "src/main.rs".to_string(),
line: 42,
column: 10,
line_length: 80,
};
assert_eq!(cursor.to_string(), "src/main.rs:42:10");
}
@ -1760,6 +1793,7 @@ index 123..456 789
file: "test.md".to_string(),
line: 1,
column: 1, // Byte index 1 is inside '第' (bytes 0..3)
line_length: 80,
};
let source_patch = r#"--- a/test.md