mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
ep: Add jitter to cursor position (#57597)
When generating a training or evaluation example with `ep split-commit`, the cursor sampling logic becomes: 1. 80% chance of cursor being at the end of the source patch 2. 20% chance of cursor being at the beginning of the target patch 3. 20% chance of adding a jitter offset (same line, ±5 columns for now) Release Notes: - N/A
This commit is contained in:
parent
eb2223c080
commit
8bfe32010c
1 changed files with 63 additions and 29 deletions
|
|
@ -27,7 +27,7 @@ use clap::Args;
|
||||||
use edit_prediction::example_spec::ExampleSpec;
|
use edit_prediction::example_spec::ExampleSpec;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::Deserialize;
|
||||||
use similar::{DiffTag, TextDiff};
|
use similar::{DiffTag, TextDiff};
|
||||||
use std::collections::BTreeSet;
|
use std::collections::BTreeSet;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
@ -74,11 +74,12 @@ pub struct AnnotatedCommit {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Cursor position in a file.
|
/// Cursor position in a file.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CursorPosition {
|
pub struct CursorPosition {
|
||||||
pub file: String,
|
pub file: String,
|
||||||
pub line: usize,
|
pub line: usize,
|
||||||
pub column: usize,
|
pub column: usize,
|
||||||
|
pub line_length: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for CursorPosition {
|
impl std::fmt::Display for CursorPosition {
|
||||||
|
|
@ -331,7 +332,7 @@ pub fn generate_evaluation_example_from_ordered_commit(
|
||||||
// Sample cursor position
|
// Sample cursor position
|
||||||
let cursor = match cursor_opt {
|
let cursor = match cursor_opt {
|
||||||
Some(c) => c,
|
Some(c) => c,
|
||||||
None => sample_cursor_position(&patch, &split_commit)
|
None => sample_cursor_position(&split_commit, rng.as_mut())
|
||||||
.context("failed to sample cursor position")?,
|
.context("failed to sample cursor position")?,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -343,7 +344,8 @@ pub fn generate_evaluation_example_from_ordered_commit(
|
||||||
)
|
)
|
||||||
.context("failed to generate cursor excerpt")?;
|
.context("failed to generate cursor excerpt")?;
|
||||||
|
|
||||||
// Handle edge case where split_point == 0
|
// Where the source patch is empty, there's not enough info to make a
|
||||||
|
// meaningful prediction
|
||||||
if split == 0 {
|
if split == 0 {
|
||||||
split_commit.target_patch = String::new();
|
split_commit.target_patch = String::new();
|
||||||
}
|
}
|
||||||
|
|
@ -754,15 +756,12 @@ pub fn imitate_human_edits(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate cursor position
|
// Calculate cursor position
|
||||||
let cursor = CursorPosition {
|
let line = if is_replacement {
|
||||||
file: tgt_edit_loc.filename.clone(),
|
src_edit_loc.as_ref().unwrap().source_line_number
|
||||||
line: if is_replacement {
|
} else {
|
||||||
src_edit_loc.as_ref().unwrap().source_line_number
|
tgt_edit_loc.target_line_number
|
||||||
} else {
|
|
||||||
tgt_edit_loc.target_line_number
|
|
||||||
},
|
|
||||||
column: new_src.len() + 1,
|
|
||||||
};
|
};
|
||||||
|
let column = new_src.len() + 1;
|
||||||
|
|
||||||
// Add remainder of source if similar enough to target remainder
|
// Add remainder of source if similar enough to target remainder
|
||||||
let remainder_src: String = (last_old_end..src_tokens.len())
|
let remainder_src: String = (last_old_end..src_tokens.len())
|
||||||
|
|
@ -785,6 +784,13 @@ pub fn imitate_human_edits(
|
||||||
return no_change;
|
return no_change;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let cursor = CursorPosition {
|
||||||
|
file: tgt_edit_loc.filename.clone(),
|
||||||
|
line,
|
||||||
|
column: column.min(new_src.len()),
|
||||||
|
line_length: new_src.len(),
|
||||||
|
};
|
||||||
|
|
||||||
// Build new source patch with the intermediate line
|
// Build new source patch with the intermediate line
|
||||||
let mut new_src_patch = src_patch;
|
let mut new_src_patch = src_patch;
|
||||||
if is_replacement {
|
if is_replacement {
|
||||||
|
|
@ -860,16 +866,17 @@ pub fn imitate_human_edits(
|
||||||
fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
|
fn locate_end_of_last_edit(patch: &Patch) -> Option<CursorPosition> {
|
||||||
let loc = locate_edited_line(patch, -1)?;
|
let loc = locate_edited_line(patch, -1)?;
|
||||||
|
|
||||||
let (line, col) = match &loc.patch_line {
|
let (line, column, line_length) = match &loc.patch_line {
|
||||||
PatchLine::Addition(content) => (loc.target_line_number, content.len()),
|
PatchLine::Addition(content) => (loc.target_line_number, content.len(), content.len()),
|
||||||
PatchLine::Deletion(_) => (loc.target_line_number, 1),
|
PatchLine::Deletion(_) => (loc.target_line_number, 1, 1),
|
||||||
_ => return None,
|
_ => return None,
|
||||||
};
|
};
|
||||||
|
|
||||||
Some(CursorPosition {
|
Some(CursorPosition {
|
||||||
file: loc.filename,
|
file: loc.filename,
|
||||||
line,
|
line,
|
||||||
column: col,
|
column,
|
||||||
|
line_length,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -878,7 +885,7 @@ fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
|
||||||
let loc = locate_edited_line(patch, 0)?;
|
let loc = locate_edited_line(patch, 0)?;
|
||||||
|
|
||||||
let hunk = patch.hunks.get(loc.hunk_index)?;
|
let hunk = patch.hunks.get(loc.hunk_index)?;
|
||||||
let column = if loc.line_index_within_hunk > 0 {
|
let line_length = if loc.line_index_within_hunk > 0 {
|
||||||
if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
|
if let Some(prev_line) = hunk.lines.get(loc.line_index_within_hunk - 1) {
|
||||||
let content = match prev_line {
|
let content = match prev_line {
|
||||||
PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
|
PatchLine::Context(s) | PatchLine::Addition(s) | PatchLine::Deletion(s) => s,
|
||||||
|
|
@ -893,32 +900,57 @@ fn locate_beginning_of_first_edit(patch: &Patch) -> Option<CursorPosition> {
|
||||||
};
|
};
|
||||||
|
|
||||||
let line = loc.target_line_number.saturating_sub(1).max(1);
|
let line = loc.target_line_number.saturating_sub(1).max(1);
|
||||||
|
let column = line_length.saturating_sub(1);
|
||||||
|
|
||||||
Some(CursorPosition {
|
Some(CursorPosition {
|
||||||
file: loc.filename,
|
file: loc.filename,
|
||||||
line,
|
line,
|
||||||
column,
|
column,
|
||||||
|
line_length,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sample cursor position according to the following rules:
|
/// Sample cursor position according to the following rules:
|
||||||
/// 1. 50% chance of cursor being at the end of the source patch
|
/// 1. 80% chance of cursor being at the end of the source patch
|
||||||
/// 2. 50% chance of cursor being at the beginning of the target patch
|
/// 2. 20% chance of cursor being at the beginning of the target patch
|
||||||
pub fn sample_cursor_position(patch: &Patch, split_commit: &SplitCommit) -> Option<CursorPosition> {
|
/// 3. 20% chance of adding a jitter offset
|
||||||
// Try end of history first
|
pub fn sample_cursor_position(
|
||||||
|
split_commit: &SplitCommit,
|
||||||
|
rng: &mut dyn rand::RngCore,
|
||||||
|
) -> Option<CursorPosition> {
|
||||||
|
// End of history
|
||||||
let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
|
let src_patch = Patch::parse_unified_diff(&split_commit.source_patch);
|
||||||
if let Some(cursor) = locate_end_of_last_edit(&src_patch) {
|
let src_cursor = locate_end_of_last_edit(&src_patch);
|
||||||
return Some(cursor);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try beginning of target
|
// Beginning of target
|
||||||
let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
|
let tgt_patch = Patch::parse_unified_diff(&split_commit.target_patch);
|
||||||
if let Some(cursor) = locate_beginning_of_first_edit(&tgt_patch) {
|
let tgt_cursor = locate_beginning_of_first_edit(&tgt_patch);
|
||||||
return Some(cursor);
|
|
||||||
|
// Randomly pick a cursor position
|
||||||
|
let prefer_source = rng.random_bool(0.8);
|
||||||
|
let mut cursor = if prefer_source {
|
||||||
|
src_cursor.or(tgt_cursor)
|
||||||
|
} else {
|
||||||
|
tgt_cursor.or(src_cursor)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Possible add jitter
|
||||||
|
let should_jitter = rng.random_bool(0.2);
|
||||||
|
if should_jitter {
|
||||||
|
if let Some(cursor) = cursor.as_mut() {
|
||||||
|
let col_offset = rng.random_range(1..=5);
|
||||||
|
if rng.random_bool(0.5) {
|
||||||
|
cursor.column = cursor
|
||||||
|
.column
|
||||||
|
.saturating_add(col_offset)
|
||||||
|
.min(cursor.line_length);
|
||||||
|
} else {
|
||||||
|
cursor.column = cursor.column.saturating_sub(col_offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback: use the original patch
|
cursor
|
||||||
locate_end_of_last_edit(patch)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get cursor excerpt from the patches.
|
/// Get cursor excerpt from the patches.
|
||||||
|
|
@ -1230,6 +1262,7 @@ Date: Mon Jan 1 00:00:00 2024
|
||||||
file: "src/main.rs".to_string(),
|
file: "src/main.rs".to_string(),
|
||||||
line: 42,
|
line: 42,
|
||||||
column: 10,
|
column: 10,
|
||||||
|
line_length: 80,
|
||||||
};
|
};
|
||||||
assert_eq!(cursor.to_string(), "src/main.rs:42:10");
|
assert_eq!(cursor.to_string(), "src/main.rs:42:10");
|
||||||
}
|
}
|
||||||
|
|
@ -1760,6 +1793,7 @@ index 123..456 789
|
||||||
file: "test.md".to_string(),
|
file: "test.md".to_string(),
|
||||||
line: 1,
|
line: 1,
|
||||||
column: 1, // Byte index 1 is inside '第' (bytes 0..3)
|
column: 1, // Byte index 1 is inside '第' (bytes 0..3)
|
||||||
|
line_length: 80,
|
||||||
};
|
};
|
||||||
|
|
||||||
let source_patch = r#"--- a/test.md
|
let source_patch = r#"--- a/test.md
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue