Add ep split subcommand for dataset splitting (#46364)

Adds a new `ep split` command that splits JSONL datasets into multiple
output files with stratification by `repository_url` when present.

Example usage:

  ep split input.jsonl train.jsonl=80% valid.jsonl=rest

Release Notes:

- N/A
This commit is contained in:
Oleksiy Syvokon 2026-01-08 15:31:26 +02:00 committed by GitHub
parent 1fa9dd6284
commit 11cfdb1e62
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 537 additions and 11 deletions

1
Cargo.lock generated
View file

@ -5308,6 +5308,7 @@ dependencies = [
"smol",
"sqlez",
"sqlez_macros",
"tempfile",
"terminal_view",
"util",
"wasmtime",

View file

@ -71,3 +71,4 @@ indoc.workspace = true
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
tempfile.workspace = true

View file

@ -14,6 +14,7 @@ mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
@ -40,6 +41,7 @@ use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
#[derive(Parser, Debug)]
@ -124,6 +126,8 @@ enum Command {
Clean,
/// Generate an evaluation example by splitting a chronologically-ordered commit
SplitCommit(SplitCommitArgs),
/// Split a JSONL dataset into multiple files (stratified by repository_url if present)
Split(SplitArgs),
}
impl Display for Command {
@ -178,6 +182,7 @@ impl Display for Command {
}
Command::Clean => write!(f, "clean"),
Command::SplitCommit(_) => write!(f, "split-commit"),
Command::Split(_) => write!(f, "split"),
}
}
}
@ -416,6 +421,13 @@ fn main() {
}
return;
}
Command::Split(split_args) => {
if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
eprintln!("{error:#}");
std::process::exit(1);
}
return;
}
_ => {}
}
@ -509,7 +521,8 @@ fn main() {
}
Command::Clean
| Command::Synthesize(_)
| Command::SplitCommit(_) => {
| Command::SplitCommit(_)
| Command::Split(_) => {
unreachable!()
}
}
@ -603,17 +616,12 @@ async fn handle_error(
indoc::indoc! {"
While processing \"{}\":
{:?}
\x1b[31m{:?}\x1b[0m
Written to: \x1b[36m{}\x1b[0m
Cursor File: \x1b[36m{}\x1b[0m
Explore this example data with:
fx \x1b[36m{}\x1b[0m
Re-run this example with:
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
Example: \x1b[36m{}\x1b[0m
Error file: \x1b[36m{}\x1b[0m
Cursor file: \x1b[36m{}\x1b[0m
Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
"},
example.spec.name,
error,

View file

@ -0,0 +1,516 @@
//! `ep split` implementation.
//!
//! This command splits a JSONL dataset into multiple files based on size specifications,
//! with stratification by repository URL (if the field is present).
//!
//! # Usage
//!
//! ```text
//! ep split [input.jsonl] <out1>=<size1> <out2>=<size2> ...
//! ```
//!
//! If `input.jsonl` is not provided or is `-`, reads from stdin.
//!
//! # Size specifications
//!
//! - `80%` - percentage of total (repositories if stratified, examples otherwise)
//! - `100` - absolute count of repositories (if stratified) or examples
//! - `rest` - all remaining items (only one split can use this)
//!
//! # Stratification
//!
//! When examples have a `repository_url` field, the split is stratified by repository.
//! This ensures each output file contains examples from non-overlapping repositories.
//! Size specifications apply to the number of repositories, not individual examples.
//!
//! Examples without `repository_url` are distributed proportionally across all outputs.
use anyhow::{Context as _, Result, bail};
use clap::Args;
use rand::SeedableRng;
use rand::seq::SliceRandom;
use serde_json::Value;
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
/// `ep split` CLI args.
#[derive(Debug, Args, Clone)]
#[command(
about = "Split a JSONL dataset into multiple files (stratified by repository_url if present)",
after_help = r#"SIZE SPECIFICATIONS:
<percentage>% Percentage of total (e.g., 80%)
<count> Absolute number (e.g., 100)
rest All remaining items (only one output can use this)
When stratifying by repository_url, sizes apply to repositories, not examples.
EXAMPLES:
# Split 80% train, 20% validation
ep split input.jsonl train.jsonl=80% valid.jsonl=rest
# Split into train/valid/test
ep split input.jsonl train.jsonl=80% valid.jsonl=10% test.jsonl=rest
# Use absolute counts (100 repos to train, rest to valid)
ep split input.jsonl train.jsonl=100 valid.jsonl=rest
# Read from stdin
cat input.jsonl | ep split train.jsonl=80% valid.jsonl=rest
# Reproducible split with seed
ep split --seed 42 input.jsonl train.jsonl=80% valid.jsonl=rest
STRATIFICATION:
When examples have a "repository_url" field, the split ensures each output
file contains examples from non-overlapping repositories. This prevents
data leakage between train/test splits.
"#
)]
pub struct SplitArgs {
/// Random seed for reproducibility
#[arg(long)]
pub seed: Option<u64>,
}
#[derive(Debug, Clone)]
pub enum SplitSize {
Percentage(f64),
Absolute(usize),
Rest,
}
#[derive(Debug, Clone)]
pub struct SplitSpec {
pub path: PathBuf,
pub size: SplitSize,
}
fn parse_split_spec(spec: &str) -> Result<SplitSpec> {
let (path, size_str) = spec
.rsplit_once('=')
.with_context(|| format!("invalid split spec '{}': expected <path>=<size>", spec))?;
let size = if size_str == "rest" {
SplitSize::Rest
} else if size_str.ends_with('%') {
let pct_str = size_str.trim_end_matches('%');
let pct: f64 = pct_str
.parse()
.with_context(|| format!("invalid percentage '{}' in '{}'", pct_str, spec))?;
if !(0.0..=100.0).contains(&pct) {
bail!("percentage must be between 0 and 100, got {}", pct);
}
SplitSize::Percentage(pct / 100.0)
} else {
let count: usize = size_str
.parse()
.with_context(|| format!("invalid count '{}' in '{}'", size_str, spec))?;
SplitSize::Absolute(count)
};
Ok(SplitSpec {
path: PathBuf::from(path),
size,
})
}
fn read_lines_from_input(input: Option<&Path>) -> Result<Vec<String>> {
let reader: Box<dyn BufRead> = match input {
Some(path) => {
let file =
File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
Box::new(BufReader::new(file))
}
None => Box::new(BufReader::new(io::stdin())),
};
let lines: Vec<String> = reader
.lines()
.collect::<io::Result<Vec<_>>>()
.context("failed to read input lines")?;
Ok(lines)
}
fn get_repository_url(line: &str) -> Option<String> {
let value: Value = serde_json::from_str(line).ok()?;
value
.get("repository_url")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
}
fn group_lines_by_repo(lines: Vec<String>) -> (HashMap<String, Vec<String>>, Vec<String>) {
let mut by_repo: HashMap<String, Vec<String>> = HashMap::new();
let mut without_repo: Vec<String> = Vec::new();
for line in lines {
if let Some(repo_url) = get_repository_url(&line) {
by_repo.entry(repo_url).or_default().push(line);
} else {
without_repo.push(line);
}
}
(by_repo, without_repo)
}
fn compute_split_counts(specs: &[SplitSpec], total: usize) -> Result<Vec<usize>> {
let mut counts = vec![0usize; specs.len()];
let mut remaining = total;
let mut rest_index: Option<usize> = None;
for (i, spec) in specs.iter().enumerate() {
match &spec.size {
SplitSize::Percentage(pct) => {
let count = (total as f64 * pct).round() as usize;
counts[i] = count.min(remaining);
remaining = remaining.saturating_sub(counts[i]);
}
SplitSize::Absolute(count) => {
counts[i] = (*count).min(remaining);
remaining = remaining.saturating_sub(counts[i]);
}
SplitSize::Rest => {
if rest_index.is_some() {
bail!("only one split can use 'rest'");
}
rest_index = Some(i);
}
}
}
if let Some(idx) = rest_index {
counts[idx] = remaining;
}
Ok(counts)
}
fn write_lines_to_file(path: &Path, lines: &[String]) -> Result<()> {
if let Some(parent) = path.parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent)
.with_context(|| format!("failed to create directory '{}'", parent.display()))?;
}
}
let file =
File::create(path).with_context(|| format!("failed to create '{}'", path.display()))?;
let mut writer = BufWriter::new(file);
for line in lines {
writeln!(writer, "{}", line)
.with_context(|| format!("failed to write to '{}'", path.display()))?;
}
writer
.flush()
.with_context(|| format!("failed to flush '{}'", path.display()))?;
Ok(())
}
pub fn run_split(args: &SplitArgs, inputs: &[PathBuf]) -> Result<()> {
if inputs.is_empty() {
bail!("usage: ep split [input.jsonl] train.jsonl=80% valid.jsonl=rest");
}
let (input_path, split_specs_raw): (Option<&Path>, &[PathBuf]) =
if inputs.first().is_some_and(|p| {
let s = p.to_string_lossy();
!s.contains('=')
}) {
let first = inputs.first().map(|p| p.as_path());
let first = if first == Some(Path::new("-")) {
None
} else {
first
};
(first, &inputs[1..])
} else {
(None, inputs)
};
if split_specs_raw.is_empty() {
bail!("no split specifications provided");
}
let specs: Vec<SplitSpec> = split_specs_raw
.iter()
.map(|p| parse_split_spec(&p.to_string_lossy()))
.collect::<Result<Vec<_>>>()?;
let lines = read_lines_from_input(input_path)?;
let total_lines = lines.len();
if total_lines == 0 {
for spec in &specs {
write_lines_to_file(&spec.path, &[])?;
}
return Ok(());
}
let (by_repo, without_repo) = group_lines_by_repo(lines);
let has_repos = !by_repo.is_empty();
if has_repos {
eprintln!(
"Stratifying by repository_url ({} unique repositories, {} examples)",
by_repo.len(),
total_lines - without_repo.len()
);
if !without_repo.is_empty() {
eprintln!(
" + {} examples without repository_url (distributed proportionally)",
without_repo.len()
);
}
}
let mut rng = match args.seed {
Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
None => rand::rngs::StdRng::from_os_rng(),
};
let mut split_outputs: Vec<Vec<String>> = vec![Vec::new(); specs.len()];
if has_repos {
let mut repos: Vec<String> = by_repo.keys().cloned().collect();
repos.shuffle(&mut rng);
let repo_counts = compute_split_counts(&specs, repos.len())?;
let mut repo_iter = repos.into_iter();
for (split_idx, &count) in repo_counts.iter().enumerate() {
for _ in 0..count {
if let Some(repo) = repo_iter.next() {
if let Some(repo_lines) = by_repo.get(&repo) {
split_outputs[split_idx].extend(repo_lines.iter().cloned());
}
}
}
}
if !without_repo.is_empty() {
let no_repo_counts = compute_split_counts(&specs, without_repo.len())?;
let mut no_repo_shuffled = without_repo;
no_repo_shuffled.shuffle(&mut rng);
let mut line_iter = no_repo_shuffled.into_iter();
for (split_idx, &count) in no_repo_counts.iter().enumerate() {
for _ in 0..count {
if let Some(line) = line_iter.next() {
split_outputs[split_idx].push(line);
}
}
}
}
} else {
let line_counts = compute_split_counts(&specs, total_lines)?;
let mut shuffled_lines = without_repo;
shuffled_lines.shuffle(&mut rng);
let mut line_iter = shuffled_lines.into_iter();
for (split_idx, &count) in line_counts.iter().enumerate() {
for _ in 0..count {
if let Some(line) = line_iter.next() {
split_outputs[split_idx].push(line);
}
}
}
}
for (spec, output_lines) in specs.iter().zip(split_outputs.iter()) {
write_lines_to_file(&spec.path, output_lines)?;
eprintln!("{}: {} examples", spec.path.display(), output_lines.len());
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
fn create_temp_jsonl(lines: &[&str]) -> NamedTempFile {
let mut file = NamedTempFile::new().unwrap();
for line in lines {
writeln!(file, "{}", line).unwrap();
}
file.flush().unwrap();
file
}
#[test]
fn test_parse_split_spec_percentage() {
let spec = parse_split_spec("train.jsonl=80%").unwrap();
assert_eq!(spec.path, PathBuf::from("train.jsonl"));
match spec.size {
SplitSize::Percentage(p) => assert!((p - 0.8).abs() < 0.001),
_ => panic!("expected percentage"),
}
}
#[test]
fn test_parse_split_spec_absolute() {
let spec = parse_split_spec("test.jsonl=100").unwrap();
assert_eq!(spec.path, PathBuf::from("test.jsonl"));
match spec.size {
SplitSize::Absolute(n) => assert_eq!(n, 100),
_ => panic!("expected absolute"),
}
}
#[test]
fn test_parse_split_spec_rest() {
let spec = parse_split_spec("valid.jsonl=rest").unwrap();
assert_eq!(spec.path, PathBuf::from("valid.jsonl"));
assert!(matches!(spec.size, SplitSize::Rest));
}
#[test]
fn test_get_repository_url() {
let line = r#"{"repository_url": "https://github.com/example/repo", "data": 123}"#;
assert_eq!(
get_repository_url(line),
Some("https://github.com/example/repo".to_string())
);
let line_no_repo = r#"{"data": 123}"#;
assert_eq!(get_repository_url(line_no_repo), None);
}
#[test]
fn test_compute_split_counts_percentage() {
let specs = vec![
SplitSpec {
path: PathBuf::from("a"),
size: SplitSize::Percentage(0.8),
},
SplitSpec {
path: PathBuf::from("b"),
size: SplitSize::Percentage(0.2),
},
];
let counts = compute_split_counts(&specs, 100).unwrap();
assert_eq!(counts, vec![80, 20]);
}
#[test]
fn test_compute_split_counts_with_rest() {
let specs = vec![
SplitSpec {
path: PathBuf::from("a"),
size: SplitSize::Percentage(0.8),
},
SplitSpec {
path: PathBuf::from("b"),
size: SplitSize::Rest,
},
];
let counts = compute_split_counts(&specs, 100).unwrap();
assert_eq!(counts, vec![80, 20]);
}
#[test]
fn test_compute_split_counts_absolute() {
let specs = vec![
SplitSpec {
path: PathBuf::from("a"),
size: SplitSize::Absolute(50),
},
SplitSpec {
path: PathBuf::from("b"),
size: SplitSize::Rest,
},
];
let counts = compute_split_counts(&specs, 100).unwrap();
assert_eq!(counts, vec![50, 50]);
}
#[test]
fn test_group_lines_by_repo() {
let lines = vec![
r#"{"repository_url": "repo1", "id": 1}"#.to_string(),
r#"{"repository_url": "repo1", "id": 2}"#.to_string(),
r#"{"repository_url": "repo2", "id": 3}"#.to_string(),
r#"{"id": 4}"#.to_string(),
];
let (by_repo, without_repo) = group_lines_by_repo(lines);
assert_eq!(by_repo.len(), 2);
assert_eq!(by_repo.get("repo1").unwrap().len(), 2);
assert_eq!(by_repo.get("repo2").unwrap().len(), 1);
assert_eq!(without_repo.len(), 1);
}
#[test]
fn test_run_split_basic() {
let input = create_temp_jsonl(&[
r#"{"repository_url": "repo1", "id": 1}"#,
r#"{"repository_url": "repo1", "id": 2}"#,
r#"{"repository_url": "repo2", "id": 3}"#,
r#"{"repository_url": "repo2", "id": 4}"#,
r#"{"repository_url": "repo3", "id": 5}"#,
r#"{"repository_url": "repo3", "id": 6}"#,
r#"{"repository_url": "repo4", "id": 7}"#,
r#"{"repository_url": "repo4", "id": 8}"#,
]);
let temp_dir = tempfile::tempdir().unwrap();
let train_path = temp_dir.path().join("train.jsonl");
let valid_path = temp_dir.path().join("valid.jsonl");
let args = SplitArgs { seed: Some(42) };
let inputs = vec![
input.path().to_path_buf(),
PathBuf::from(format!("{}=50%", train_path.display())),
PathBuf::from(format!("{}=rest", valid_path.display())),
];
run_split(&args, &inputs).unwrap();
let train_content = std::fs::read_to_string(&train_path).unwrap();
let valid_content = std::fs::read_to_string(&valid_path).unwrap();
let train_lines: Vec<&str> = train_content.lines().collect();
let valid_lines: Vec<&str> = valid_content.lines().collect();
assert_eq!(train_lines.len() + valid_lines.len(), 8);
let train_repos: std::collections::HashSet<_> = train_lines
.iter()
.filter_map(|l| get_repository_url(l))
.collect();
let valid_repos: std::collections::HashSet<_> = valid_lines
.iter()
.filter_map(|l| get_repository_url(l))
.collect();
assert!(
train_repos.is_disjoint(&valid_repos),
"train and valid should have non-overlapping repos"
);
}
#[test]
fn test_multiple_rest_fails() {
let specs = vec![
SplitSpec {
path: PathBuf::from("a"),
size: SplitSize::Rest,
},
SplitSpec {
path: PathBuf::from("b"),
size: SplitSize::Rest,
},
];
assert!(compute_split_counts(&specs, 100).is_err());
}
}