ep: Various ep CLI improvements (#56587)

- limit diagnostics sent to teacher
- improve parallelism in `ep` commands when using `--max-parallelism` by
only grouping by repo when instructed to (repo grouping is only useful
for context collection)
- update rejected and rated queries to also fetch settled editable
region

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- N/A or Added/Fixed/Improved ...
This commit is contained in:
Ben Kunkle 2026-05-13 07:34:50 -04:00 committed by GitHub
parent 3d8495757d
commit 481854f754
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 169 additions and 90 deletions

View file

@ -314,24 +314,32 @@ impl TeacherPrompt {
}
/// Renders the active-buffer diagnostics of `example` as text, returning the
/// literal string "No Diagnostics" when there is nothing to show.
fn format_diagnostics(example: &Example) -> String {
// NOTE(review): this listing has no +/- markers; the expression chain from
// here down to `.unwrap_or("No Diagnostics".to_string())` looks like the
// pre-change body (every diagnostic rendered as `*message*` followed by a
// fenced snippet, joined with blank lines, with no size limit) — confirm
// against the real diff before treating it as live code.
example
.prompt_inputs
.as_ref()
.map(|prompt_inputs| {
prompt_inputs
.active_buffer_diagnostics
.iter()
.map(|diagnostic| {
format!(
"*{}*:\n```\n{}\n```\n",
&diagnostic.message, &diagnostic.snippet
)
})
.collect::<Vec<_>>()
.join("\n")
})
.filter(|m| !m.is_empty())
.unwrap_or("No Diagnostics".to_string())
// NOTE(review): the statements below appear to be the replacement body.
// Without prompt inputs there are no diagnostics to report.
let Some(prompt_inputs) = example.prompt_inputs.as_ref() else {
return "No Diagnostics".to_string();
};
// Cursor row = excerpt start row + number of `\n` bytes that precede the
// cursor offset inside the excerpt. Stays `None` when the excerpt start row
// is unknown.
// NOTE(review): the range index presumably panics if
// `cursor_offset_in_excerpt` is past the end of the excerpt or not on a char
// boundary — verify the offset is always valid here.
let cursor_buffer_row = prompt_inputs.excerpt_start_row.map(|excerpt_start_row| {
excerpt_start_row
+ prompt_inputs.cursor_excerpt[..prompt_inputs.cursor_offset_in_excerpt]
.bytes()
.filter(|byte| *byte == b'\n')
.count() as u32
});
// Delegate formatting to the shared zeta_prompt helper with a budget of 2_000.
// NOTE(review): the unit of the budget (bytes vs. tokens) and how the cursor
// row is used are defined by the helper, not visible here — confirm.
let diagnostics = zeta_prompt::format_active_buffer_diagnostics_with_budget(
&prompt_inputs.active_buffer_diagnostics,
cursor_buffer_row,
2_000,
);
// Drop the leading "<filename>diagnostics\n" header when the helper emitted
// one; otherwise keep the text unchanged.
let diagnostics = diagnostics
.strip_prefix("<filename>diagnostics\n")
.unwrap_or(&diagnostics);
// An empty rendering is reported with the same placeholder as missing inputs.
if diagnostics.is_empty() {
"No Diagnostics".to_string()
} else {
diagnostics.to_string()
}
}
}

View file

@ -40,6 +40,7 @@ use zeta_prompt::ZetaFormat;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::VecDeque;
use std::env;
use std::fmt::Display;
use std::fs::{File, OpenOptions};
@ -72,6 +73,9 @@ struct EpArgs {
printenv: bool,
#[clap(long, default_value_t = 10, global = true)]
max_parallelism: usize,
/// Process all examples from a repository together instead of distributing examples across workers.
#[clap(long, default_value_t = false, global = true)]
group_by_repo: bool,
/// The limit for the number of examples to process
/// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
#[clap(long, global = true)]
@ -899,6 +903,18 @@ fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
hasher.finish()
}
/// Splits `examples` into at most `max_parallelism` contiguous batches of
/// near-equal size, preserving the original order.
///
/// Every batch except possibly the last holds
/// `examples.len().div_ceil(max_parallelism)` items. Returns an empty queue
/// when `examples` is empty or `max_parallelism` is zero.
///
/// The examples are moved into the batches rather than cloned, since the
/// input vector is owned by this function.
fn chunk_examples(examples: Vec<Example>, max_parallelism: usize) -> VecDeque<Vec<Example>> {
    if examples.is_empty() || max_parallelism == 0 {
        return VecDeque::new();
    }

    // `examples` is non-empty here, so `chunk_size` is always at least 1.
    let chunk_size = examples.len().div_ceil(max_parallelism);
    let batch_count = examples.len().div_ceil(chunk_size);

    let mut batches = VecDeque::with_capacity(batch_count);
    let mut remaining = examples.into_iter();
    loop {
        // `by_ref` lets each `take` consume from the same underlying iterator.
        let batch: Vec<Example> = remaining.by_ref().take(chunk_size).collect();
        if batch.is_empty() {
            break;
        }
        batches.push_back(batch);
    }
    batches
}
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
let file = match File::open(path) {
Ok(f) => f,
@ -1173,7 +1189,12 @@ fn main() {
output_sender = Some(sender);
}
let grouped_examples = Mutex::new(group_examples_by_repo(examples));
let example_batches = if args.group_by_repo {
group_examples_by_repo(examples)
} else {
chunk_examples(examples, args.max_parallelism)
};
let example_batches = Mutex::new(example_batches);
let finished_examples = Mutex::new(Vec::new());
let mut tasks = Vec::new();
@ -1181,7 +1202,7 @@ fn main() {
tasks.push(async {
loop {
let Some(mut repo_examples) =
grouped_examples.lock().unwrap().pop_front()
example_batches.lock().unwrap().pop_front()
else {
break;
};

View file

@ -576,6 +576,7 @@ pub async fn fetch_rejected_examples_after(
input_payload AS input,
prompt AS prompt,
requested_output AS output,
settled_editable_region AS settled_editable_region,
is_ep_shown_before_rejected AS was_shown,
ep_rejected_reason AS reason,
zed_version AS zed_version
@ -623,6 +624,7 @@ pub async fn fetch_rejected_examples_after(
"input",
"prompt",
"output",
"settled_editable_region",
"was_shown",
"reason",
"zed_version",
@ -928,6 +930,7 @@ pub async fn fetch_rated_examples_after(
ep_request_id AS request_id,
rated_inputs AS inputs,
rated_output AS output,
settled_editable_region AS settled_editable_region,
rating AS rating,
feedback AS feedback,
device_id AS device_id,
@ -971,6 +974,7 @@ pub async fn fetch_rated_examples_after(
"request_id",
"inputs",
"output",
"settled_editable_region",
"rating",
"feedback",
"device_id",
@ -1043,6 +1047,7 @@ fn rated_examples_from_response<'a>(
None => None,
};
let output = get_string("output");
let settled_editable_region = get_string("settled_editable_region");
let rating = get_string("rating");
let feedback = get_string("feedback").unwrap_or_default();
let device_id = get_string("device_id");
@ -1059,6 +1064,7 @@ fn rated_examples_from_response<'a>(
time,
inputs,
output,
settled_editable_region,
rating,
feedback,
experiment_name,
@ -1088,6 +1094,7 @@ fn build_rated_example(
time: String,
input: ZetaPromptInput,
output: String,
settled_editable_region: Option<String>,
rating: String,
feedback: String,
experiment_name: Option<String>,
@ -1115,6 +1122,16 @@ fn build_rated_example(
tags.push(format!("environment:{env}"));
}
let expected_patch = settled_editable_region
.as_ref()
.map(|settled_editable_region| {
build_output_patch(
&input.cursor_path,
input.cursor_excerpt.as_ref(),
&input.excerpt_ranges.editable_350,
settled_editable_region,
)
});
let mut example =
build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
@ -1127,9 +1144,13 @@ fn build_rated_example(
.push(edit_prediction::example_spec::HumanFeedback { message: feedback });
}
if is_positive {
example.spec.expected_patches = vec![output];
} else {
if let Some(expected_patch) = expected_patch {
example.spec.expected_patches = vec![expected_patch];
} else if is_positive {
example.spec.expected_patches = vec![output.clone()];
}
if !is_positive {
example.spec.rejected_patch = Some(output);
}
@ -1608,6 +1629,7 @@ fn rejected_examples_from_response<'a>(
input_json.clone().and_then(|v| serde_json::from_value(v).ok());
let prompt = get_string("prompt");
let output = get_string("output");
let settled_editable_region = get_string("settled_editable_region");
let was_shown = get_bool("was_shown");
let reason = get_string("reason");
let zed_version = get_string("zed_version");
@ -1621,6 +1643,7 @@ fn rejected_examples_from_response<'a>(
input,
prompt,
output,
settled_editable_region,
was_shown,
reason,
zed_version,
@ -1652,6 +1675,7 @@ fn build_rejected_example(
input: ZetaPromptInput,
prompt: Option<String>,
output: String,
settled_editable_region: Option<String>,
was_shown: bool,
reason: String,
zed_version: Option<String>,
@ -1662,6 +1686,16 @@ fn build_rejected_example(
&input.excerpt_ranges.editable_350,
&output,
);
let expected_patch = settled_editable_region
.as_ref()
.map(|settled_editable_region| {
build_output_patch(
&input.cursor_path,
input.cursor_excerpt.as_ref(),
&input.excerpt_ranges.editable_350,
settled_editable_region,
)
});
let mut example = build_example_from_snowflake(
request_id,
device_id,
@ -1672,6 +1706,9 @@ fn build_rejected_example(
zed_version,
);
example.spec.rejected_patch = Some(rejected_patch);
if let Some(expected_patch) = expected_patch {
example.spec.expected_patches = vec![expected_patch];
}
example.prompt = prompt.map(|prompt| ExamplePrompt {
input: prompt,
expected_output: None,

View file

@ -11,7 +11,7 @@ use anyhow::Context as _;
use edit_prediction_metrics::{
ActualPredictionCursor, PredictionReversalContext, PredictionScoringInput,
};
use gpui::AsyncApp;
use gpui::{AppContext as _, AsyncApp};
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
@ -24,79 +24,92 @@ pub async fn run_scoring(
example_progress: &ExampleProgress,
cx: AsyncApp,
) -> anyhow::Result<()> {
run_prediction(example, args, app_state, example_progress, cx).await?;
run_prediction(example, args, app_state, example_progress, cx.clone()).await?;
let progress = example_progress.start(Step::Score);
progress.set_substatus("applying patches");
let prompt_inputs = example
.prompt_inputs
.as_ref()
.context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
let old_editable_region = if let Some(p) = example.prompt.as_ref() {
if matches!(
p.provider,
PredictionProvider::Teacher(_, _) | PredictionProvider::TeacherNonBatching(_, _)
) {
Some(
TeacherPrompt::extract_editable_region(&p.input)?
.replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
)
} else {
None
}
} else {
None
};
let prepared_expected_patches = edit_prediction_metrics::prepare_expected_patches(
&expected_patches_with_cursors,
original_text,
old_editable_region.as_deref(),
)
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))?;
let cursor_path = example.spec.cursor_path.as_ref();
progress.set_substatus("computing metrics");
let mut scores = vec![];
for prediction in &example.predictions {
let actual_patch = prediction.actual_patch.clone().or_else(|| {
parse_prediction_output(example, &prediction.actual_output, prediction.provider)
.ok()
.map(|(patch, _)| patch)
});
let actual_cursor =
prediction
.actual_cursor
let example_for_scoring = example.clone();
example.score = cx
.background_spawn(async move {
let prompt_inputs = example_for_scoring
.prompt_inputs
.as_ref()
.map(|cursor| ActualPredictionCursor {
row: cursor.row,
editable_region_offset: cursor.editable_region_offset,
.context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
let expected_patches_with_cursors = example_for_scoring
.spec
.expected_patches_with_cursor_positions();
let old_editable_region = if let Some(p) = example_for_scoring.prompt.as_ref() {
if matches!(
p.provider,
PredictionProvider::Teacher(_, _) | PredictionProvider::TeacherNonBatching(_, _)
) {
Some(
TeacherPrompt::extract_editable_region(&p.input)?
.replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
)
} else {
None
}
} else {
None
};
let prepared_expected_patches = edit_prediction_metrics::prepare_expected_patches(
&expected_patches_with_cursors,
original_text,
old_editable_region.as_deref(),
)
.with_context(|| {
format!(
"Expected patch did not apply for {}",
example_for_scoring.spec.name
)
})?;
let cursor_path = example_for_scoring.spec.cursor_path.as_ref();
let mut scores = vec![];
for prediction in &example_for_scoring.predictions {
let actual_patch = prediction.actual_patch.clone().or_else(|| {
parse_prediction_output(
&example_for_scoring,
&prediction.actual_output,
prediction.provider,
)
.ok()
.map(|(patch, _)| patch)
});
scores.push(edit_prediction_metrics::score_prediction(
PredictionScoringInput {
original_text,
expected_patches: &prepared_expected_patches,
actual_patch: actual_patch.as_deref(),
actual_cursor,
reversal_context: Some(PredictionReversalContext {
edit_history: &prompt_inputs.events,
excerpt_start_row: prompt_inputs.excerpt_start_row,
cursor_path,
}),
cumulative_logprob: prediction.cumulative_logprob,
avg_logprob: prediction.avg_logprob,
},
));
}
let actual_cursor = prediction.actual_cursor.as_ref().map(|cursor| {
ActualPredictionCursor {
row: cursor.row,
editable_region_offset: cursor.editable_region_offset,
}
});
example.score = scores;
scores.push(edit_prediction_metrics::score_prediction(
PredictionScoringInput {
original_text,
expected_patches: &prepared_expected_patches,
actual_patch: actual_patch.as_deref(),
actual_cursor,
reversal_context: Some(PredictionReversalContext {
edit_history: &prompt_inputs.events,
excerpt_start_row: prompt_inputs.excerpt_start_row,
cursor_path,
}),
cumulative_logprob: prediction.cumulative_logprob,
avg_logprob: prediction.avg_logprob,
},
));
}
anyhow::Ok(scores)
})
.await?;
Ok(())
}

View file

@ -840,7 +840,7 @@ pub fn format_prompt_with_budget_for_format(
return Some(prompt);
}
fn format_active_buffer_diagnostics_with_budget(
pub fn format_active_buffer_diagnostics_with_budget(
diagnostics: &[ActiveBufferDiagnostic],
cursor_buffer_row: Option<u32>,
budget: usize,