diff --git a/crates/edit_prediction_cli/src/split_commit.rs b/crates/edit_prediction_cli/src/split_commit.rs index b81dc21b4da..20c23d7a023 100644 --- a/crates/edit_prediction_cli/src/split_commit.rs +++ b/crates/edit_prediction_cli/src/split_commit.rs @@ -394,7 +394,12 @@ pub fn generate_evaluation_example_from_ordered_commit( pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) { let patch = Patch::parse_unified_diff(commit); let source_edits: BTreeSet = (0..split_pos).collect(); - let (source, target) = extract_edits(&patch, &source_edits); + let (source, mut target) = extract_edits(&patch, &source_edits); + if !target.hunks.is_empty() { + if let Some(header) = header_for_edit(&patch, split_pos) { + target.header = header; + } + } let mut source_str = source.to_string(); let target_str = target.to_string(); @@ -419,6 +424,47 @@ pub fn split_ordered_commit(commit: &str, split_pos: usize) -> (String, String) (source_str, target_str) } +fn header_for_edit(patch: &Patch, edit_index: usize) -> Option { + let edit_index = edit_index.try_into().ok()?; + let edit_location = locate_edited_line(patch, edit_index)?; + header_for_hunk(patch, edit_location.hunk_index) +} + +fn header_for_hunk(patch: &Patch, hunk_index: usize) -> Option { + for hunk in patch.hunks.get(..hunk_index)?.iter().rev() { + let mut header_lines = Vec::new(); + for line in hunk.lines.iter().rev() { + let PatchLine::Garbage(line) = line else { + break; + }; + if line.trim().is_empty() && header_lines.is_empty() { + continue; + } + if !line.starts_with("//") { + break; + } + header_lines.push(line.as_str()); + } + if !header_lines.is_empty() { + return Some(render_reversed_header_lines(header_lines)); + } + } + + let header_lines = patch + .header + .lines() + .rev() + .skip_while(|line| line.trim().is_empty()) + .take_while(|line| line.starts_with("//")) + .collect::>(); + (!header_lines.is_empty()).then(|| render_reversed_header_lines(header_lines)) +} + +fn render_reversed_header_lines(mut lines: Vec<&str>) -> String { + lines.reverse(); + lines.join("\n") + "\n" +} + /// Calculate the weight for a split position based on the character at that position. /// /// Higher weights indicate more natural pause points (e.g., after punctuation, @@ -1179,6 +1225,55 @@ mod tests { assert_eq!(tgt_patch.stats().added, 1); } + #[test] + fn test_split_ordered_commit_target_header_continues_current_group() { + let commit = r#"//////////////////////////////////////////////////////////////////////////////// +// Update dependency version +//////////////////////////////////////////////////////////////////////////////// +--- a/go.mod ++++ b/go.mod +@@ -1,3 +1,3 @@ + require ( +- gopkg.in/yaml.v3 v3.0.0 // indirect ++ gopkg.in/yaml.v3 v3.0.1 // indirect + ) +diff --git a/go.sum b/go.sum +index f71a068..b8cc3c2 100644 +//////////////////////////////////////////////////////////////////////////////// +// Update go.sum checksums +//////////////////////////////////////////////////////////////////////////////// +--- a/go.sum ++++ b/go.sum +@@ -1,3 +1,5 @@ + gopkg.in/yaml.v3 v3.0.0 h1:old + gopkg.in/yaml.v3 v3.0.0/go.mod h1:oldmod ++gopkg.in/yaml.v3 v3.0.1 h1:new ++gopkg.in/yaml.v3 v3.0.1/go.mod h1:newmod +diff --git a/lib/handler.go b/lib/handler.go +index 1827a70..d9b3ed1 100644 +//////////////////////////////////////////////////////////////////////////////// +// Fix error wrapping +//////////////////////////////////////////////////////////////////////////////// +--- a/lib/handler.go ++++ b/lib/handler.go +@@ -1,3 +1,3 @@ +- return fmt.Errorf("failed: %s", err) ++ return fmt.Errorf("failed: %w", err) +"#; + + let (_source, target) = split_ordered_commit(commit, 3); + + assert!( + target.starts_with( + "////////////////////////////////////////////////////////////////////////////////\n// Update go.sum checksums\n////////////////////////////////////////////////////////////////////////////////\n" + ), + "target patch should continue with the active group header:\n{target}" + ); + assert!(!target.starts_with( + "////////////////////////////////////////////////////////////////////////////////\n// Update dependency version\n////////////////////////////////////////////////////////////////////////////////\n" + )); + } + #[test] fn test_generate_evaluation_example() { let commit = r#"commit abc123