mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
ep: Various ep CLI improvements (#56587)
- limit diagnostics sent to teacher - improve parallelism in `ep` commands when using `--max-parallelism` by only grouping by repo when instructed too (repo grouping is only useful for context collection) - update rejected and rated queries to also fetch settled editable region Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A or Added/Fixed/Improved ...
This commit is contained in:
parent
3d8495757d
commit
481854f754
5 changed files with 169 additions and 90 deletions
|
|
@ -314,24 +314,32 @@ impl TeacherPrompt {
|
|||
}
|
||||
|
||||
fn format_diagnostics(example: &Example) -> String {
|
||||
example
|
||||
.prompt_inputs
|
||||
.as_ref()
|
||||
.map(|prompt_inputs| {
|
||||
prompt_inputs
|
||||
.active_buffer_diagnostics
|
||||
.iter()
|
||||
.map(|diagnostic| {
|
||||
format!(
|
||||
"*{}*:\n```\n{}\n```\n",
|
||||
&diagnostic.message, &diagnostic.snippet
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
})
|
||||
.filter(|m| !m.is_empty())
|
||||
.unwrap_or("No Diagnostics".to_string())
|
||||
let Some(prompt_inputs) = example.prompt_inputs.as_ref() else {
|
||||
return "No Diagnostics".to_string();
|
||||
};
|
||||
|
||||
let cursor_buffer_row = prompt_inputs.excerpt_start_row.map(|excerpt_start_row| {
|
||||
excerpt_start_row
|
||||
+ prompt_inputs.cursor_excerpt[..prompt_inputs.cursor_offset_in_excerpt]
|
||||
.bytes()
|
||||
.filter(|byte| *byte == b'\n')
|
||||
.count() as u32
|
||||
});
|
||||
let diagnostics = zeta_prompt::format_active_buffer_diagnostics_with_budget(
|
||||
&prompt_inputs.active_buffer_diagnostics,
|
||||
cursor_buffer_row,
|
||||
2_000,
|
||||
);
|
||||
|
||||
let diagnostics = diagnostics
|
||||
.strip_prefix("<filename>diagnostics\n")
|
||||
.unwrap_or(&diagnostics);
|
||||
|
||||
if diagnostics.is_empty() {
|
||||
"No Diagnostics".to_string()
|
||||
} else {
|
||||
diagnostics.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ use zeta_prompt::ZetaFormat;
|
|||
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::collections::VecDeque;
|
||||
use std::env;
|
||||
use std::fmt::Display;
|
||||
use std::fs::{File, OpenOptions};
|
||||
|
|
@ -72,6 +73,9 @@ struct EpArgs {
|
|||
printenv: bool,
|
||||
#[clap(long, default_value_t = 10, global = true)]
|
||||
max_parallelism: usize,
|
||||
/// Process all examples from a repository together instead of distributing examples across workers.
|
||||
#[clap(long, default_value_t = false, global = true)]
|
||||
group_by_repo: bool,
|
||||
/// The limit for the number of examples to process
|
||||
/// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
|
||||
#[clap(long, global = true)]
|
||||
|
|
@ -899,6 +903,18 @@ fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
|
|||
hasher.finish()
|
||||
}
|
||||
|
||||
fn chunk_examples(examples: Vec<Example>, max_parallelism: usize) -> VecDeque<Vec<Example>> {
|
||||
if examples.is_empty() || max_parallelism == 0 {
|
||||
return VecDeque::new();
|
||||
}
|
||||
|
||||
let chunk_size = examples.len().div_ceil(max_parallelism);
|
||||
examples
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| chunk.to_vec())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
|
||||
let file = match File::open(path) {
|
||||
Ok(f) => f,
|
||||
|
|
@ -1173,7 +1189,12 @@ fn main() {
|
|||
output_sender = Some(sender);
|
||||
}
|
||||
|
||||
let grouped_examples = Mutex::new(group_examples_by_repo(examples));
|
||||
let example_batches = if args.group_by_repo {
|
||||
group_examples_by_repo(examples)
|
||||
} else {
|
||||
chunk_examples(examples, args.max_parallelism)
|
||||
};
|
||||
let example_batches = Mutex::new(example_batches);
|
||||
let finished_examples = Mutex::new(Vec::new());
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
|
|
@ -1181,7 +1202,7 @@ fn main() {
|
|||
tasks.push(async {
|
||||
loop {
|
||||
let Some(mut repo_examples) =
|
||||
grouped_examples.lock().unwrap().pop_front()
|
||||
example_batches.lock().unwrap().pop_front()
|
||||
else {
|
||||
break;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -576,6 +576,7 @@ pub async fn fetch_rejected_examples_after(
|
|||
input_payload AS input,
|
||||
prompt AS prompt,
|
||||
requested_output AS output,
|
||||
settled_editable_region AS settled_editable_region,
|
||||
is_ep_shown_before_rejected AS was_shown,
|
||||
ep_rejected_reason AS reason,
|
||||
zed_version AS zed_version
|
||||
|
|
@ -623,6 +624,7 @@ pub async fn fetch_rejected_examples_after(
|
|||
"input",
|
||||
"prompt",
|
||||
"output",
|
||||
"settled_editable_region",
|
||||
"was_shown",
|
||||
"reason",
|
||||
"zed_version",
|
||||
|
|
@ -928,6 +930,7 @@ pub async fn fetch_rated_examples_after(
|
|||
ep_request_id AS request_id,
|
||||
rated_inputs AS inputs,
|
||||
rated_output AS output,
|
||||
settled_editable_region AS settled_editable_region,
|
||||
rating AS rating,
|
||||
feedback AS feedback,
|
||||
device_id AS device_id,
|
||||
|
|
@ -971,6 +974,7 @@ pub async fn fetch_rated_examples_after(
|
|||
"request_id",
|
||||
"inputs",
|
||||
"output",
|
||||
"settled_editable_region",
|
||||
"rating",
|
||||
"feedback",
|
||||
"device_id",
|
||||
|
|
@ -1043,6 +1047,7 @@ fn rated_examples_from_response<'a>(
|
|||
None => None,
|
||||
};
|
||||
let output = get_string("output");
|
||||
let settled_editable_region = get_string("settled_editable_region");
|
||||
let rating = get_string("rating");
|
||||
let feedback = get_string("feedback").unwrap_or_default();
|
||||
let device_id = get_string("device_id");
|
||||
|
|
@ -1059,6 +1064,7 @@ fn rated_examples_from_response<'a>(
|
|||
time,
|
||||
inputs,
|
||||
output,
|
||||
settled_editable_region,
|
||||
rating,
|
||||
feedback,
|
||||
experiment_name,
|
||||
|
|
@ -1088,6 +1094,7 @@ fn build_rated_example(
|
|||
time: String,
|
||||
input: ZetaPromptInput,
|
||||
output: String,
|
||||
settled_editable_region: Option<String>,
|
||||
rating: String,
|
||||
feedback: String,
|
||||
experiment_name: Option<String>,
|
||||
|
|
@ -1115,6 +1122,16 @@ fn build_rated_example(
|
|||
tags.push(format!("environment:{env}"));
|
||||
}
|
||||
|
||||
let expected_patch = settled_editable_region
|
||||
.as_ref()
|
||||
.map(|settled_editable_region| {
|
||||
build_output_patch(
|
||||
&input.cursor_path,
|
||||
input.cursor_excerpt.as_ref(),
|
||||
&input.excerpt_ranges.editable_350,
|
||||
settled_editable_region,
|
||||
)
|
||||
});
|
||||
let mut example =
|
||||
build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
|
||||
|
||||
|
|
@ -1127,9 +1144,13 @@ fn build_rated_example(
|
|||
.push(edit_prediction::example_spec::HumanFeedback { message: feedback });
|
||||
}
|
||||
|
||||
if is_positive {
|
||||
example.spec.expected_patches = vec![output];
|
||||
} else {
|
||||
if let Some(expected_patch) = expected_patch {
|
||||
example.spec.expected_patches = vec![expected_patch];
|
||||
} else if is_positive {
|
||||
example.spec.expected_patches = vec![output.clone()];
|
||||
}
|
||||
|
||||
if !is_positive {
|
||||
example.spec.rejected_patch = Some(output);
|
||||
}
|
||||
|
||||
|
|
@ -1608,6 +1629,7 @@ fn rejected_examples_from_response<'a>(
|
|||
input_json.clone().and_then(|v| serde_json::from_value(v).ok());
|
||||
let prompt = get_string("prompt");
|
||||
let output = get_string("output");
|
||||
let settled_editable_region = get_string("settled_editable_region");
|
||||
let was_shown = get_bool("was_shown");
|
||||
let reason = get_string("reason");
|
||||
let zed_version = get_string("zed_version");
|
||||
|
|
@ -1621,6 +1643,7 @@ fn rejected_examples_from_response<'a>(
|
|||
input,
|
||||
prompt,
|
||||
output,
|
||||
settled_editable_region,
|
||||
was_shown,
|
||||
reason,
|
||||
zed_version,
|
||||
|
|
@ -1652,6 +1675,7 @@ fn build_rejected_example(
|
|||
input: ZetaPromptInput,
|
||||
prompt: Option<String>,
|
||||
output: String,
|
||||
settled_editable_region: Option<String>,
|
||||
was_shown: bool,
|
||||
reason: String,
|
||||
zed_version: Option<String>,
|
||||
|
|
@ -1662,6 +1686,16 @@ fn build_rejected_example(
|
|||
&input.excerpt_ranges.editable_350,
|
||||
&output,
|
||||
);
|
||||
let expected_patch = settled_editable_region
|
||||
.as_ref()
|
||||
.map(|settled_editable_region| {
|
||||
build_output_patch(
|
||||
&input.cursor_path,
|
||||
input.cursor_excerpt.as_ref(),
|
||||
&input.excerpt_ranges.editable_350,
|
||||
settled_editable_region,
|
||||
)
|
||||
});
|
||||
let mut example = build_example_from_snowflake(
|
||||
request_id,
|
||||
device_id,
|
||||
|
|
@ -1672,6 +1706,9 @@ fn build_rejected_example(
|
|||
zed_version,
|
||||
);
|
||||
example.spec.rejected_patch = Some(rejected_patch);
|
||||
if let Some(expected_patch) = expected_patch {
|
||||
example.spec.expected_patches = vec![expected_patch];
|
||||
}
|
||||
example.prompt = prompt.map(|prompt| ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: None,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use anyhow::Context as _;
|
|||
use edit_prediction_metrics::{
|
||||
ActualPredictionCursor, PredictionReversalContext, PredictionScoringInput,
|
||||
};
|
||||
use gpui::AsyncApp;
|
||||
use gpui::{AppContext as _, AsyncApp};
|
||||
use std::fs::File;
|
||||
use std::io::BufWriter;
|
||||
use std::path::Path;
|
||||
|
|
@ -24,79 +24,92 @@ pub async fn run_scoring(
|
|||
example_progress: &ExampleProgress,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
run_prediction(example, args, app_state, example_progress, cx).await?;
|
||||
run_prediction(example, args, app_state, example_progress, cx.clone()).await?;
|
||||
|
||||
let progress = example_progress.start(Step::Score);
|
||||
|
||||
progress.set_substatus("applying patches");
|
||||
let prompt_inputs = example
|
||||
.prompt_inputs
|
||||
.as_ref()
|
||||
.context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
|
||||
let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
|
||||
let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
|
||||
|
||||
let old_editable_region = if let Some(p) = example.prompt.as_ref() {
|
||||
if matches!(
|
||||
p.provider,
|
||||
PredictionProvider::Teacher(_, _) | PredictionProvider::TeacherNonBatching(_, _)
|
||||
) {
|
||||
Some(
|
||||
TeacherPrompt::extract_editable_region(&p.input)?
|
||||
.replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let prepared_expected_patches = edit_prediction_metrics::prepare_expected_patches(
|
||||
&expected_patches_with_cursors,
|
||||
original_text,
|
||||
old_editable_region.as_deref(),
|
||||
)
|
||||
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))?;
|
||||
|
||||
let cursor_path = example.spec.cursor_path.as_ref();
|
||||
|
||||
progress.set_substatus("computing metrics");
|
||||
let mut scores = vec![];
|
||||
for prediction in &example.predictions {
|
||||
let actual_patch = prediction.actual_patch.clone().or_else(|| {
|
||||
parse_prediction_output(example, &prediction.actual_output, prediction.provider)
|
||||
.ok()
|
||||
.map(|(patch, _)| patch)
|
||||
});
|
||||
|
||||
let actual_cursor =
|
||||
prediction
|
||||
.actual_cursor
|
||||
let example_for_scoring = example.clone();
|
||||
example.score = cx
|
||||
.background_spawn(async move {
|
||||
let prompt_inputs = example_for_scoring
|
||||
.prompt_inputs
|
||||
.as_ref()
|
||||
.map(|cursor| ActualPredictionCursor {
|
||||
row: cursor.row,
|
||||
editable_region_offset: cursor.editable_region_offset,
|
||||
.context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
|
||||
let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
|
||||
let expected_patches_with_cursors = example_for_scoring
|
||||
.spec
|
||||
.expected_patches_with_cursor_positions();
|
||||
|
||||
let old_editable_region = if let Some(p) = example_for_scoring.prompt.as_ref() {
|
||||
if matches!(
|
||||
p.provider,
|
||||
PredictionProvider::Teacher(_, _) | PredictionProvider::TeacherNonBatching(_, _)
|
||||
) {
|
||||
Some(
|
||||
TeacherPrompt::extract_editable_region(&p.input)?
|
||||
.replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let prepared_expected_patches = edit_prediction_metrics::prepare_expected_patches(
|
||||
&expected_patches_with_cursors,
|
||||
original_text,
|
||||
old_editable_region.as_deref(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Expected patch did not apply for {}",
|
||||
example_for_scoring.spec.name
|
||||
)
|
||||
})?;
|
||||
|
||||
let cursor_path = example_for_scoring.spec.cursor_path.as_ref();
|
||||
|
||||
let mut scores = vec![];
|
||||
for prediction in &example_for_scoring.predictions {
|
||||
let actual_patch = prediction.actual_patch.clone().or_else(|| {
|
||||
parse_prediction_output(
|
||||
&example_for_scoring,
|
||||
&prediction.actual_output,
|
||||
prediction.provider,
|
||||
)
|
||||
.ok()
|
||||
.map(|(patch, _)| patch)
|
||||
});
|
||||
|
||||
scores.push(edit_prediction_metrics::score_prediction(
|
||||
PredictionScoringInput {
|
||||
original_text,
|
||||
expected_patches: &prepared_expected_patches,
|
||||
actual_patch: actual_patch.as_deref(),
|
||||
actual_cursor,
|
||||
reversal_context: Some(PredictionReversalContext {
|
||||
edit_history: &prompt_inputs.events,
|
||||
excerpt_start_row: prompt_inputs.excerpt_start_row,
|
||||
cursor_path,
|
||||
}),
|
||||
cumulative_logprob: prediction.cumulative_logprob,
|
||||
avg_logprob: prediction.avg_logprob,
|
||||
},
|
||||
));
|
||||
}
|
||||
let actual_cursor = prediction.actual_cursor.as_ref().map(|cursor| {
|
||||
ActualPredictionCursor {
|
||||
row: cursor.row,
|
||||
editable_region_offset: cursor.editable_region_offset,
|
||||
}
|
||||
});
|
||||
|
||||
example.score = scores;
|
||||
scores.push(edit_prediction_metrics::score_prediction(
|
||||
PredictionScoringInput {
|
||||
original_text,
|
||||
expected_patches: &prepared_expected_patches,
|
||||
actual_patch: actual_patch.as_deref(),
|
||||
actual_cursor,
|
||||
reversal_context: Some(PredictionReversalContext {
|
||||
edit_history: &prompt_inputs.events,
|
||||
excerpt_start_row: prompt_inputs.excerpt_start_row,
|
||||
cursor_path,
|
||||
}),
|
||||
cumulative_logprob: prediction.cumulative_logprob,
|
||||
avg_logprob: prediction.avg_logprob,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
anyhow::Ok(scores)
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -840,7 +840,7 @@ pub fn format_prompt_with_budget_for_format(
|
|||
return Some(prompt);
|
||||
}
|
||||
|
||||
fn format_active_buffer_diagnostics_with_budget(
|
||||
pub fn format_active_buffer_diagnostics_with_budget(
|
||||
diagnostics: &[ActiveBufferDiagnostic],
|
||||
cursor_buffer_row: Option<u32>,
|
||||
budget: usize,
|
||||
|
|
|
|||
Loading…
Reference in a new issue