mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Changes to make ep split-commits work (#46160)
1. `apply_diff` will create a file if the diff says so (by starting with `--- /dev/null`)
2. Update examples format to match recent changes
3. `ep split-commits` can work with a stream of inputs and generate `n` samples per input
4. Unicode handling fixes

Release Notes:

- N/A
This commit is contained in:
parent
8829f1d1d0
commit
92b0f144c0
4 changed files with 451 additions and 123 deletions
|
|
@ -33,7 +33,9 @@ pub async fn apply_diff(
|
|||
let mut paths = Vec::new();
|
||||
for line in diff_str.lines() {
|
||||
if let DiffLine::OldPath { path } = DiffLine::parse(line) {
|
||||
paths.push(RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc());
|
||||
if path != "/dev/null" {
|
||||
paths.push(RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc());
|
||||
}
|
||||
}
|
||||
}
|
||||
worktree
|
||||
|
|
@ -55,17 +57,41 @@ pub async fn apply_diff(
|
|||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk { path, hunk } => {
|
||||
DiffEvent::Hunk {
|
||||
path,
|
||||
hunk,
|
||||
is_new_file,
|
||||
} => {
|
||||
let buffer = match current_file {
|
||||
None => {
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(path.as_ref(), cx)
|
||||
.context("no such path")?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
let buffer = if is_new_file {
|
||||
// New file - create it first, then open the buffer
|
||||
let worktree_id = worktree.read_with(cx, |wt, _| wt.id())?;
|
||||
let rel_path =
|
||||
RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?;
|
||||
let project_path = project::ProjectPath {
|
||||
worktree_id,
|
||||
path: rel_path.into_arc(),
|
||||
};
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_entry(project_path.clone(), false, cx)
|
||||
})?
|
||||
.await?;
|
||||
project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?
|
||||
} else {
|
||||
// Existing file - find and open it
|
||||
let project_path = project
|
||||
.update(cx, |project, cx| {
|
||||
project.find_project_path(path.as_ref(), cx)
|
||||
})?
|
||||
.context("no such path")?;
|
||||
project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?
|
||||
};
|
||||
included_files.insert(path.to_string(), buffer.clone());
|
||||
current_file = Some(buffer);
|
||||
current_file.as_ref().unwrap()
|
||||
|
|
@ -75,7 +101,7 @@ pub async fn apply_diff(
|
|||
|
||||
buffer.read_with(cx, |buffer, _| {
|
||||
edits.extend(
|
||||
resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice())
|
||||
resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice(), is_new_file)
|
||||
.with_context(|| format!("Diff:\n{diff_str}"))?,
|
||||
);
|
||||
anyhow::Ok(())
|
||||
|
|
@ -184,7 +210,11 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
|||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk { hunk, .. } => {
|
||||
DiffEvent::Hunk {
|
||||
hunk,
|
||||
path: _,
|
||||
is_new_file: _,
|
||||
} => {
|
||||
let hunk_offset = text
|
||||
.find(&hunk.context)
|
||||
.ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
|
||||
|
|
@ -210,7 +240,11 @@ pub fn edits_for_diff(content: &str, diff_str: &str) -> Result<Vec<(Range<usize>
|
|||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk { hunk, .. } => {
|
||||
DiffEvent::Hunk {
|
||||
hunk,
|
||||
path: _,
|
||||
is_new_file: _,
|
||||
} => {
|
||||
if hunk.context.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
|
@ -259,8 +293,14 @@ struct DiffParser<'a> {
|
|||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum DiffEvent<'a> {
|
||||
Hunk { path: Cow<'a, str>, hunk: Hunk },
|
||||
FileEnd { renamed_to: Option<Cow<'a, str>> },
|
||||
Hunk {
|
||||
path: Cow<'a, str>,
|
||||
hunk: Hunk,
|
||||
is_new_file: bool,
|
||||
},
|
||||
FileEnd {
|
||||
renamed_to: Option<Cow<'a, str>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
|
|
@ -305,9 +345,16 @@ impl<'a> DiffParser<'a> {
|
|||
if let Some(file) = &self.current_file
|
||||
&& !self.hunk.is_empty()
|
||||
{
|
||||
let is_new_file = file.old_path == "/dev/null";
|
||||
let path = if is_new_file {
|
||||
file.new_path.clone()
|
||||
} else {
|
||||
file.old_path.clone()
|
||||
};
|
||||
return Ok(Some(DiffEvent::Hunk {
|
||||
path: file.old_path.clone(),
|
||||
path,
|
||||
hunk: mem::take(&mut self.hunk),
|
||||
is_new_file,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
|
@ -315,7 +362,7 @@ impl<'a> DiffParser<'a> {
|
|||
if file_done {
|
||||
if let Some(PatchFile { old_path, new_path }) = self.current_file.take() {
|
||||
return Ok(Some(DiffEvent::FileEnd {
|
||||
renamed_to: if old_path != new_path {
|
||||
renamed_to: if old_path != new_path && old_path != "/dev/null" {
|
||||
Some(new_path)
|
||||
} else {
|
||||
None
|
||||
|
|
@ -397,8 +444,9 @@ fn resolve_hunk_edits_in_buffer(
|
|||
hunk: Hunk,
|
||||
buffer: &TextBufferSnapshot,
|
||||
ranges: &[Range<Anchor>],
|
||||
is_new_file: bool,
|
||||
) -> Result<impl Iterator<Item = (Range<Anchor>, Arc<str>)>, anyhow::Error> {
|
||||
let context_offset = if hunk.context.is_empty() {
|
||||
let context_offset = if is_new_file || hunk.context.is_empty() {
|
||||
Ok(0)
|
||||
} else {
|
||||
let mut offset = None;
|
||||
|
|
@ -775,7 +823,8 @@ mod tests {
|
|||
range: 4..4,
|
||||
text: "AND\n".into()
|
||||
}],
|
||||
}
|
||||
},
|
||||
is_new_file: false,
|
||||
},
|
||||
DiffEvent::FileEnd { renamed_to: None }
|
||||
],
|
||||
|
|
|
|||
|
|
@ -328,7 +328,9 @@ fn main() {
|
|||
return;
|
||||
}
|
||||
Command::SplitCommit(split_commit_args) => {
|
||||
if let Err(error) = split_commit::run_split_commit(split_commit_args) {
|
||||
if let Err(error) =
|
||||
split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
|
||||
{
|
||||
eprintln!("{error:#}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -150,7 +150,11 @@ impl ToString for Patch {
|
|||
let mut result = self.header.clone();
|
||||
for hunk in &self.hunks {
|
||||
let current_file = hunk.filename.clone();
|
||||
result.push_str(&format!("--- a/{}\n", current_file));
|
||||
if hunk.is_file_creation() {
|
||||
result.push_str("--- /dev/null\n");
|
||||
} else {
|
||||
result.push_str(&format!("--- a/{}\n", current_file));
|
||||
}
|
||||
result.push_str(&format!("+++ b/{}\n", current_file));
|
||||
result.push_str(&hunk.to_string());
|
||||
}
|
||||
|
|
@ -330,6 +334,11 @@ impl ToString for Hunk {
|
|||
}
|
||||
|
||||
impl Hunk {
|
||||
/// Returns true if this hunk represents a file creation (old side is empty).
|
||||
pub fn is_file_creation(&self) -> bool {
|
||||
self.old_start == 0 && self.old_count == 0
|
||||
}
|
||||
|
||||
/// Render the hunk header
|
||||
pub fn header_string(&self) -> String {
|
||||
format!(
|
||||
|
|
@ -1459,4 +1468,31 @@ mod tests {
|
|||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_creation_diff_header() {
|
||||
// When old_start and old_count are both 0, the file is being created,
|
||||
// so the --- line should be /dev/null instead of a/filename
|
||||
let patch = Patch::parse_unified_diff(indoc! {"
|
||||
--- a/new_file.rs
|
||||
+++ b/new_file.rs
|
||||
@@ -0,0 +1,3 @@
|
||||
+fn main() {
|
||||
+ println!(\"hello\");
|
||||
+}
|
||||
"});
|
||||
|
||||
let actual = patch.to_string();
|
||||
assert_eq!(
|
||||
actual,
|
||||
indoc! {"
|
||||
--- /dev/null
|
||||
+++ b/new_file.rs
|
||||
@@ -0,0 +1,3 @@
|
||||
+fn main() {
|
||||
+ println!(\"hello\");
|
||||
+}
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,31 +5,37 @@
|
|||
//!
|
||||
//! TODO: Port Python code to generate chronologically-ordered commits
|
||||
use crate::reorder_patch::{Patch, PatchLine, extract_edits, locate_edited_line};
|
||||
|
||||
/// Find the largest valid UTF-8 char boundary at or before `index` in `s`.
|
||||
fn floor_char_boundary(s: &str, index: usize) -> usize {
|
||||
if index >= s.len() {
|
||||
s.len()
|
||||
} else if s.is_char_boundary(index) {
|
||||
index
|
||||
} else {
|
||||
// Find the nearest valid character boundary at or before index
|
||||
(0..index)
|
||||
.rev()
|
||||
.find(|&i| s.is_char_boundary(i))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
use anyhow::{Context as _, Result};
|
||||
use clap::Args;
|
||||
use edit_prediction::example_spec::ExampleSpec;
|
||||
use rand::Rng;
|
||||
use rand::SeedableRng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use similar::{DiffTag, TextDiff};
|
||||
use std::collections::BTreeSet;
|
||||
use std::fs;
|
||||
use std::io::{self, Read};
|
||||
use std::io::{self, Write};
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// `ep split-commit` CLI args.
|
||||
#[derive(Debug, Args, Clone)]
|
||||
pub struct SplitCommitArgs {
|
||||
/// Path to the commit file (use "-" for stdin)
|
||||
#[arg(long, short = 'c')]
|
||||
pub commit: String,
|
||||
|
||||
/// Repository URL
|
||||
#[arg(long, short = 'r', default_value_t = String::new())]
|
||||
pub repository_url: String,
|
||||
|
||||
/// Commit hash
|
||||
#[arg(long, default_value_t = String::new())]
|
||||
pub commit_hash: String,
|
||||
|
||||
/// Split point (float 0.0-1.0 for fraction, or integer for index)
|
||||
#[arg(long, short = 's')]
|
||||
pub split_point: Option<String>,
|
||||
|
|
@ -41,6 +47,28 @@ pub struct SplitCommitArgs {
|
|||
/// Pretty-print JSON output
|
||||
#[arg(long, short = 'p')]
|
||||
pub pretty: bool,
|
||||
|
||||
/// Number of samples to generate per commit (samples random split points)
|
||||
#[arg(long, short = 'n')]
|
||||
pub num_samples: Option<usize>,
|
||||
}
|
||||
|
||||
/// Input format for annotated commits (JSON Lines).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct AnnotatedCommit {
|
||||
/// Repository path (e.g., "repos/zed")
|
||||
pub repo: String,
|
||||
/// Repository URL (e.g., "https://github.com/zed-industries/zed")
|
||||
pub repo_url: String,
|
||||
/// Commit SHA
|
||||
pub commit_sha: String,
|
||||
/// Chronologically reordered commit diff
|
||||
pub reordered_commit: String,
|
||||
/// Original commit diff
|
||||
pub original_commit: String,
|
||||
/// Whether diff stats match between original and reordered
|
||||
pub diff_stats_match: bool,
|
||||
}
|
||||
|
||||
/// Cursor position in a file.
|
||||
|
|
@ -64,21 +92,6 @@ pub struct SplitCommit {
|
|||
pub target_patch: String,
|
||||
}
|
||||
|
||||
/// The evaluation case structure that will be serialized to JSON.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EvaluationCase {
|
||||
pub repository_url: String,
|
||||
pub commit: String,
|
||||
pub edit_history: Vec<String>,
|
||||
pub cursor_position: String,
|
||||
pub cursor_excerpt: String,
|
||||
pub expected_hunks: Vec<String>,
|
||||
pub expected_patch: String,
|
||||
pub allowed_patch: String,
|
||||
pub expected_context_excerpts: Vec<String>,
|
||||
pub extra: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Split point specification for evaluation generation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SplitPoint {
|
||||
|
|
@ -98,38 +111,122 @@ fn parse_split_point(value: &str) -> Option<SplitPoint> {
|
|||
|
||||
/// Entry point for the `ep split-commit` subcommand.
|
||||
///
|
||||
/// This runs synchronously and prints a single JSON object to stdout.
|
||||
pub fn run_split_commit(args: &SplitCommitArgs) -> Result<()> {
|
||||
let commit = if args.commit == "-" {
|
||||
let mut content = String::new();
|
||||
io::stdin()
|
||||
.read_to_string(&mut content)
|
||||
.context("failed to read commit diff from stdin")?;
|
||||
content
|
||||
/// This runs synchronously and outputs JSON Lines (one output per input line).
|
||||
pub fn run_split_commit(
|
||||
args: &SplitCommitArgs,
|
||||
inputs: &[PathBuf],
|
||||
output_path: Option<&PathBuf>,
|
||||
) -> Result<()> {
|
||||
use std::collections::HashSet;
|
||||
use std::io::BufRead;
|
||||
|
||||
let stdin_path = PathBuf::from("-");
|
||||
let inputs = if inputs.is_empty() {
|
||||
std::slice::from_ref(&stdin_path)
|
||||
} else {
|
||||
fs::read_to_string(&args.commit)
|
||||
.with_context(|| format!("failed to read commit diff from {}", args.commit))?
|
||||
inputs
|
||||
};
|
||||
|
||||
let split_point = args.split_point.as_deref().and_then(parse_split_point);
|
||||
let mut output_lines = Vec::new();
|
||||
|
||||
let case = generate_evaluation_example_from_ordered_commit(
|
||||
&commit,
|
||||
&args.repository_url,
|
||||
&args.commit_hash,
|
||||
split_point,
|
||||
args.seed,
|
||||
)
|
||||
.context("failed to generate evaluation example")?;
|
||||
for input_path in inputs {
|
||||
let input: Box<dyn BufRead> = if input_path.as_os_str() == "-" {
|
||||
Box::new(io::BufReader::new(io::stdin()))
|
||||
} else {
|
||||
let file = fs::File::open(input_path)
|
||||
.with_context(|| format!("failed to open input file {}", input_path.display()))?;
|
||||
Box::new(io::BufReader::new(file))
|
||||
};
|
||||
|
||||
let json = if args.pretty {
|
||||
serde_json::to_string_pretty(&case)
|
||||
} else {
|
||||
serde_json::to_string(&case)
|
||||
for (line_num, line_result) in input.lines().enumerate() {
|
||||
let line =
|
||||
line_result.with_context(|| format!("failed to read line {}", line_num + 1))?;
|
||||
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let annotated: AnnotatedCommit = serde_json::from_str(&line)
|
||||
.with_context(|| format!("failed to parse JSON at line {}", line_num + 1))?;
|
||||
|
||||
// Generate multiple samples if num_samples is set
|
||||
if let Some(num_samples) = args.num_samples {
|
||||
let mut seen_samples: HashSet<String> = HashSet::new();
|
||||
let base_seed = args.seed.unwrap_or_else(|| rand::random());
|
||||
|
||||
for sample_idx in 0..num_samples {
|
||||
let sample_seed = base_seed.wrapping_add(sample_idx as u64);
|
||||
|
||||
let case = generate_evaluation_example_from_ordered_commit(
|
||||
&annotated.reordered_commit,
|
||||
&annotated.repo_url,
|
||||
&annotated.commit_sha,
|
||||
None, // Use random split point for multi-sample mode
|
||||
Some(sample_seed),
|
||||
Some(sample_idx),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to generate evaluation example for commit {} at line {} (sample {})",
|
||||
annotated.commit_sha,
|
||||
line_num + 1,
|
||||
sample_idx
|
||||
)
|
||||
})?;
|
||||
|
||||
let json = if args.pretty {
|
||||
serde_json::to_string_pretty(&case)
|
||||
} else {
|
||||
serde_json::to_string(&case)
|
||||
}
|
||||
.context("failed to serialize evaluation case as JSON")?;
|
||||
|
||||
// Only add unique samples (different split points may produce same result)
|
||||
if seen_samples.insert(json.clone()) {
|
||||
output_lines.push(json);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let case = generate_evaluation_example_from_ordered_commit(
|
||||
&annotated.reordered_commit,
|
||||
&annotated.repo_url,
|
||||
&annotated.commit_sha,
|
||||
split_point.clone(),
|
||||
args.seed,
|
||||
None,
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to generate evaluation example for commit {} at line {}",
|
||||
annotated.commit_sha,
|
||||
line_num + 1
|
||||
)
|
||||
})?;
|
||||
|
||||
let json = if args.pretty {
|
||||
serde_json::to_string_pretty(&case)
|
||||
} else {
|
||||
serde_json::to_string(&case)
|
||||
}
|
||||
.context("failed to serialize evaluation case as JSON")?;
|
||||
|
||||
output_lines.push(json);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let output_content = output_lines.join("\n") + if output_lines.is_empty() { "" } else { "\n" };
|
||||
|
||||
if let Some(path) = output_path {
|
||||
fs::write(path, &output_content)
|
||||
.with_context(|| format!("failed to write output to {}", path.display()))?;
|
||||
} else {
|
||||
io::stdout()
|
||||
.write_all(output_content.as_bytes())
|
||||
.context("failed to write to stdout")?;
|
||||
}
|
||||
.context("failed to serialize evaluation case as JSON")?;
|
||||
|
||||
println!("{json}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -141,13 +238,15 @@ pub fn run_split_commit(args: &SplitCommitArgs) -> Result<()> {
|
|||
/// * `commit_hash` - Hash of the commit
|
||||
/// * `split_point` - Point at which the commit will be split (None for random)
|
||||
/// * `seed` - Optional seed for randomness
|
||||
/// * `sample_num` - Optional sample number for generating unique names
|
||||
pub fn generate_evaluation_example_from_ordered_commit(
|
||||
commit: &str,
|
||||
repository_url: &str,
|
||||
commit_hash: &str,
|
||||
split_point: Option<SplitPoint>,
|
||||
seed: Option<u64>,
|
||||
) -> Result<EvaluationCase> {
|
||||
sample_num: Option<usize>,
|
||||
) -> Result<ExampleSpec> {
|
||||
let mut rng: Box<dyn rand::RngCore> = match seed {
|
||||
Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
|
||||
None => Box::new(rand::rngs::ThreadRng::default()),
|
||||
|
|
@ -222,17 +321,29 @@ pub fn generate_evaluation_example_from_ordered_commit(
|
|||
split_commit.target_patch = String::new();
|
||||
}
|
||||
|
||||
Ok(EvaluationCase {
|
||||
let repo_name = repository_url
|
||||
.trim_end_matches('/')
|
||||
.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("unknown");
|
||||
let short_sha = &commit_hash[..commit_hash.len().min(8)];
|
||||
let name = match sample_num {
|
||||
Some(n) => format!("{}-{}-{}", repo_name, short_sha, n),
|
||||
None => format!("{}-{}", repo_name, short_sha),
|
||||
};
|
||||
|
||||
Ok(ExampleSpec {
|
||||
name,
|
||||
repository_url: repository_url.to_string(),
|
||||
commit: format!("{}~1", commit_hash),
|
||||
edit_history: vec![split_commit.source_patch.clone()],
|
||||
cursor_position: cursor.to_string(),
|
||||
cursor_excerpt,
|
||||
expected_hunks: vec![split_commit.target_patch.clone()],
|
||||
expected_patch: split_commit.target_patch.clone(),
|
||||
allowed_patch: split_commit.target_patch,
|
||||
expected_context_excerpts: vec![],
|
||||
extra: serde_json::json!({}),
|
||||
revision: format!("{}~1", commit_hash),
|
||||
edit_history: split_commit.source_patch.clone(),
|
||||
// cursor_position: cursor.to_string(),
|
||||
cursor_path: Path::new(&cursor.file).into(),
|
||||
cursor_position: cursor_excerpt,
|
||||
expected_patches: vec![split_commit.target_patch],
|
||||
tags: vec![],
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -591,11 +702,13 @@ pub fn imitate_human_edits(
|
|||
// Split within this replace operation
|
||||
let offset = split_index - edit_index;
|
||||
if offset < ins.len() {
|
||||
new_src.push_str(&ins[..offset]);
|
||||
let safe_offset = floor_char_boundary(&ins, offset);
|
||||
new_src.push_str(&ins[..safe_offset]);
|
||||
} else {
|
||||
new_src.push_str(&ins);
|
||||
let del_offset = offset - ins.len();
|
||||
new_src.push_str(&del[..del_offset.min(del.len())]);
|
||||
let safe_del_offset = floor_char_boundary(&del, del_offset.min(del.len()));
|
||||
new_src.push_str(&del[..safe_del_offset]);
|
||||
}
|
||||
split_found = true;
|
||||
last_old_end = op.old_range().end;
|
||||
|
|
@ -610,7 +723,8 @@ pub fn imitate_human_edits(
|
|||
let repl: String = op.new_range().map(|i| tgt_tokens[i].as_str()).collect();
|
||||
if edit_index + repl.len() >= split_index {
|
||||
let offset = split_index - edit_index;
|
||||
new_src.push_str(&repl[..offset]);
|
||||
let safe_offset = floor_char_boundary(&repl, offset);
|
||||
new_src.push_str(&repl[..safe_offset]);
|
||||
split_found = true;
|
||||
break;
|
||||
} else {
|
||||
|
|
@ -622,9 +736,10 @@ pub fn imitate_human_edits(
|
|||
let repl: String = op.old_range().map(|i| src_tokens[i].as_str()).collect();
|
||||
if edit_index + repl.len() >= split_index {
|
||||
let offset = split_index - edit_index;
|
||||
new_src.push_str(&repl[..offset]);
|
||||
let safe_offset = floor_char_boundary(&repl, offset);
|
||||
new_src.push_str(&repl[..safe_offset]);
|
||||
split_found = true;
|
||||
last_old_end = op.old_range().start + offset.min(op.old_range().len());
|
||||
last_old_end = op.old_range().start + safe_offset.min(op.old_range().len());
|
||||
break;
|
||||
} else {
|
||||
edit_index += repl.len();
|
||||
|
|
@ -685,15 +800,25 @@ pub fn imitate_human_edits(
|
|||
}
|
||||
} else {
|
||||
// For pure insertions, we need to add or modify a hunk
|
||||
if let Some(hunk) = new_src_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
|
||||
// Insert the partial line at the same position as target
|
||||
hunk.lines.insert(
|
||||
tgt_edit_loc.line_index_within_hunk,
|
||||
PatchLine::Addition(new_src.clone()),
|
||||
);
|
||||
hunk.new_count += 1;
|
||||
} else if new_src_patch.hunks.is_empty() {
|
||||
// Source patch is empty, create a new hunk based on target
|
||||
// Check if the source hunk exists AND has enough lines for the target's line index
|
||||
let can_insert_in_existing_hunk = new_src_patch
|
||||
.hunks
|
||||
.get(tgt_edit_loc.hunk_index)
|
||||
.map_or(false, |hunk| {
|
||||
tgt_edit_loc.line_index_within_hunk <= hunk.lines.len()
|
||||
});
|
||||
|
||||
if can_insert_in_existing_hunk {
|
||||
if let Some(hunk) = new_src_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
|
||||
// Insert the partial line at the same position as target
|
||||
hunk.lines.insert(
|
||||
tgt_edit_loc.line_index_within_hunk,
|
||||
PatchLine::Addition(new_src.clone()),
|
||||
);
|
||||
hunk.new_count += 1;
|
||||
}
|
||||
} else {
|
||||
// Source patch is empty or has incompatible hunk structure, create a new hunk based on target
|
||||
if let Some(tgt_hunk) = tgt_patch.hunks.get(tgt_edit_loc.hunk_index) {
|
||||
let mut new_hunk = tgt_hunk.clone();
|
||||
// Replace the full addition with the partial one
|
||||
|
|
@ -827,6 +952,18 @@ pub fn get_cursor_excerpt(
|
|||
_ => {}
|
||||
}
|
||||
}
|
||||
// If hunk only has deletions (file deletion), include deletion lines
|
||||
if excerpt_lines.is_empty() {
|
||||
excerpt_first_line = hunk.old_start as usize;
|
||||
for line in &hunk.lines {
|
||||
match line {
|
||||
PatchLine::Deletion(s) => {
|
||||
excerpt_lines.push(s.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -834,19 +971,68 @@ pub fn get_cursor_excerpt(
|
|||
// Search in target patch if not found
|
||||
if excerpt_lines.is_empty() {
|
||||
let tgt = Patch::parse_unified_diff(target_patch);
|
||||
if let Some(loc) = locate_edited_line(&tgt, 0) {
|
||||
if loc.filename == cursor.file {
|
||||
if let Some(hunk) = tgt.hunks.get(loc.hunk_index) {
|
||||
excerpt_first_line = hunk.new_start as usize;
|
||||
// Search all hunks for the cursor file, not just the first edit's hunk
|
||||
for hunk in &tgt.hunks {
|
||||
if hunk.filename == cursor.file {
|
||||
excerpt_first_line = hunk.new_start as usize;
|
||||
// First try to collect deletions and context (what exists before edits)
|
||||
for line in &hunk.lines {
|
||||
match line {
|
||||
PatchLine::Deletion(s) | PatchLine::Context(s) => {
|
||||
excerpt_lines.push(s.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// If hunk only has additions (no deletions/context), include all lines
|
||||
// This handles cases like adding to an empty file or section
|
||||
if excerpt_lines.is_empty() {
|
||||
for line in &hunk.lines {
|
||||
match line {
|
||||
PatchLine::Deletion(s) | PatchLine::Context(s) => {
|
||||
PatchLine::Addition(s)
|
||||
| PatchLine::Deletion(s)
|
||||
| PatchLine::Context(s) => {
|
||||
excerpt_lines.push(s.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !excerpt_lines.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also search source patch hunks if still not found (for fallback cursor case)
|
||||
if excerpt_lines.is_empty() {
|
||||
for hunk in &src.hunks {
|
||||
if hunk.filename == cursor.file {
|
||||
excerpt_first_line = hunk.new_start as usize;
|
||||
for line in &hunk.lines {
|
||||
match line {
|
||||
PatchLine::Addition(s) | PatchLine::Context(s) => {
|
||||
excerpt_lines.push(s.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
// If hunk only has deletions, include deletion lines
|
||||
if excerpt_lines.is_empty() {
|
||||
excerpt_first_line = hunk.old_start as usize;
|
||||
for line in &hunk.lines {
|
||||
match line {
|
||||
PatchLine::Deletion(s) => {
|
||||
excerpt_lines.push(s.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !excerpt_lines.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -860,6 +1046,16 @@ pub fn get_cursor_excerpt(
|
|||
let line_num = excerpt_first_line + i;
|
||||
if line_num == cursor.line {
|
||||
let col = cursor.column.min(line.len());
|
||||
// Ensure we split at a valid UTF-8 character boundary
|
||||
let col = if line.is_char_boundary(col) {
|
||||
col
|
||||
} else {
|
||||
// Find the nearest valid character boundary
|
||||
(0..=col)
|
||||
.rev()
|
||||
.find(|&i| line.is_char_boundary(i))
|
||||
.unwrap_or(0)
|
||||
};
|
||||
let (before, after) = line.split_at(col);
|
||||
*line = format!("{}<|user_cursor|>{}", before, after);
|
||||
break;
|
||||
|
|
@ -871,6 +1067,10 @@ pub fn get_cursor_excerpt(
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use edit_prediction::example_spec::ExampleSpec;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
|
@ -979,12 +1179,13 @@ Date: Mon Jan 1 00:00:00 2024
|
|||
"abc123",
|
||||
Some(SplitPoint::Fraction(0.5)),
|
||||
Some(42),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let case = result.unwrap();
|
||||
assert_eq!(case.repository_url, "https://github.com/test/repo");
|
||||
assert_eq!(case.commit, "abc123~1");
|
||||
assert_eq!(case.revision, "abc123~1");
|
||||
assert!(!case.edit_history.is_empty());
|
||||
}
|
||||
|
||||
|
|
@ -1009,6 +1210,7 @@ Date: Mon Jan 1 00:00:00 2024
|
|||
"abc123",
|
||||
Some(SplitPoint::Fraction(0.5)),
|
||||
Some(12345),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1018,12 +1220,13 @@ Date: Mon Jan 1 00:00:00 2024
|
|||
"abc123",
|
||||
Some(SplitPoint::Fraction(0.5)),
|
||||
Some(12345),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Results should be identical
|
||||
assert_eq!(result1.edit_history, result2.edit_history);
|
||||
assert_eq!(result1.expected_patch, result2.expected_patch);
|
||||
assert_eq!(result1.expected_patches, result2.expected_patches);
|
||||
assert_eq!(result1.cursor_position, result2.cursor_position);
|
||||
}
|
||||
|
||||
|
|
@ -1085,13 +1288,14 @@ Date: Mon Jan 1 00:00:00 2024
|
|||
"hash",
|
||||
Some(SplitPoint::Fraction(0.2)),
|
||||
Some(1),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let case = result.unwrap();
|
||||
|
||||
// Source should have some edits
|
||||
let src_patch = Patch::parse_unified_diff(&case.edit_history[0]);
|
||||
let src_patch = Patch::parse_unified_diff(&case.edit_history);
|
||||
assert!(src_patch.stats().added > 0);
|
||||
}
|
||||
|
||||
|
|
@ -1118,12 +1322,13 @@ Date: Mon Jan 1 00:00:00 2024
|
|||
"hash",
|
||||
Some(SplitPoint::Index(2)),
|
||||
Some(1),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let case = result.unwrap();
|
||||
|
||||
let src_patch = Patch::parse_unified_diff(&case.edit_history[0]);
|
||||
let src_patch = Patch::parse_unified_diff(&case.edit_history);
|
||||
// Pure insertion adds a partial line, so we expect 3 (2 original + 1 partial)
|
||||
assert_eq!(src_patch.stats().added, 3);
|
||||
}
|
||||
|
|
@ -1148,37 +1353,39 @@ Date: Mon Jan 1 00:00:00 2024
|
|||
"hash",
|
||||
Some(SplitPoint::Fraction(0.5)),
|
||||
Some(42),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Cursor excerpt should contain the cursor marker
|
||||
assert!(
|
||||
result.cursor_excerpt.contains("<|user_cursor|>"),
|
||||
result.cursor_position.contains("<|user_cursor|>"),
|
||||
"Cursor excerpt should contain marker: {}",
|
||||
result.cursor_excerpt
|
||||
result.cursor_position
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluation_case_json_serialization() {
|
||||
let case = EvaluationCase {
|
||||
let case = ExampleSpec {
|
||||
name: "test-abc123".to_string(),
|
||||
repository_url: "https://github.com/test/repo".to_string(),
|
||||
commit: "abc123~1".to_string(),
|
||||
edit_history: vec!["patch1".to_string()],
|
||||
cursor_position: "file.rs:10:5".to_string(),
|
||||
cursor_excerpt: "some code<|user_cursor|>".to_string(),
|
||||
expected_hunks: vec!["hunk1".to_string()],
|
||||
expected_patch: "patch".to_string(),
|
||||
allowed_patch: "patch".to_string(),
|
||||
expected_context_excerpts: vec![],
|
||||
extra: serde_json::json!({}),
|
||||
revision: "abc123~1".to_string(),
|
||||
edit_history: "patch1".to_string(),
|
||||
// cursor_position: "file.rs:10:5".to_string(),
|
||||
cursor_path: Path::new("file.rs").into(),
|
||||
cursor_position: "some code<|user_cursor|>".to_string(),
|
||||
expected_patches: vec!["patch".to_string()],
|
||||
tags: vec![],
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&case).unwrap();
|
||||
let deserialized: EvaluationCase = serde_json::from_str(&json).unwrap();
|
||||
let deserialized: ExampleSpec = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(case.repository_url, deserialized.repository_url);
|
||||
assert_eq!(case.commit, deserialized.commit);
|
||||
assert_eq!(case.revision, deserialized.revision);
|
||||
assert_eq!(case.cursor_position, deserialized.cursor_position);
|
||||
}
|
||||
|
||||
|
|
@ -1192,6 +1399,7 @@ Date: Mon Jan 1 00:00:00 2024
|
|||
"hash",
|
||||
Some(SplitPoint::Fraction(0.5)),
|
||||
Some(1),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
|
|
@ -1224,6 +1432,7 @@ index 123..456 789
|
|||
"hash",
|
||||
Some(SplitPoint::Index(1)),
|
||||
Some(1),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1231,8 +1440,8 @@ index 123..456 789
|
|||
|
||||
// The edit history should contain the group header (// lines)
|
||||
// but not the commit metadata
|
||||
assert!(!case.edit_history[0].contains("Author:"));
|
||||
assert!(!case.edit_history[0].contains("Date:"));
|
||||
assert!(!case.edit_history.contains("Author:"));
|
||||
assert!(!case.edit_history.contains("Date:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1462,4 +1671,36 @@ index 123..456 789
|
|||
"At least one seed should produce a partial intermediate state"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cursor_excerpt_with_multibyte_utf8() {
|
||||
// Test that cursor excerpt handles multi-byte UTF-8 characters correctly
|
||||
// The Chinese character '第' is 3 bytes (0..3)
|
||||
let cursor = CursorPosition {
|
||||
file: "test.md".to_string(),
|
||||
line: 1,
|
||||
column: 1, // Byte index 1 is inside '第' (bytes 0..3)
|
||||
};
|
||||
|
||||
let source_patch = r#"--- a/test.md
|
||||
+++ b/test.md
|
||||
@@ -1,1 +1,1 @@
|
||||
+第 14 章 Flask 工作原理与机制解析**
|
||||
"#;
|
||||
|
||||
let target_patch = "";
|
||||
|
||||
// This should not panic even though column=1 is not a char boundary
|
||||
let result = get_cursor_excerpt(&cursor, source_patch, target_patch);
|
||||
|
||||
// The function should handle the invalid byte index gracefully
|
||||
if let Some(excerpt) = result {
|
||||
assert!(
|
||||
excerpt.contains("<|user_cursor|>"),
|
||||
"Cursor excerpt should contain marker"
|
||||
);
|
||||
// The marker should be placed at a valid character boundary
|
||||
// (either at the start or after '第')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue