mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Introduce zeta2 format with cursor content in original order (#46732)
This one does `fim_prefix`, `fim_middle`, and `fim_suffix` in that order, in the prompt, instead of putting the current middle last. Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <agus@zed.dev> Co-authored-by: Ben Kunkle <ben@zed.dev>
This commit is contained in:
parent
c9003e1a12
commit
20284e4f21
15 changed files with 384 additions and 162 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -21179,7 +21179,9 @@ dependencies = [
|
|||
name = "zeta_prompt"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"serde",
|
||||
"strum 0.27.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
|
|||
use std::collections::{VecDeque, hash_map};
|
||||
use text::Edit;
|
||||
use workspace::Workspace;
|
||||
use zeta_prompt::ZetaVersion;
|
||||
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
|
|
@ -183,7 +184,9 @@ pub struct EditPredictionStore {
|
|||
pub enum EditPredictionModel {
|
||||
#[default]
|
||||
Zeta1,
|
||||
Zeta2,
|
||||
Zeta2 {
|
||||
version: ZetaVersion,
|
||||
},
|
||||
Sweep,
|
||||
Mercury,
|
||||
}
|
||||
|
|
@ -654,7 +657,9 @@ impl EditPredictionStore {
|
|||
update_required: false,
|
||||
#[cfg(feature = "cli-support")]
|
||||
eval_cache: None,
|
||||
edit_prediction_model: EditPredictionModel::Zeta2,
|
||||
edit_prediction_model: EditPredictionModel::Zeta2 {
|
||||
version: Default::default(),
|
||||
},
|
||||
sweep_ai: SweepAi::new(cx),
|
||||
mercury: Mercury::new(cx),
|
||||
data_collection_choice,
|
||||
|
|
@ -794,7 +799,10 @@ impl EditPredictionStore {
|
|||
}
|
||||
|
||||
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
|
||||
if self.edit_prediction_model == EditPredictionModel::Zeta2 {
|
||||
if matches!(
|
||||
self.edit_prediction_model,
|
||||
EditPredictionModel::Zeta2 { .. }
|
||||
) {
|
||||
self.user_store.read(cx).edit_prediction_usage()
|
||||
} else {
|
||||
None
|
||||
|
|
@ -1204,7 +1212,7 @@ impl EditPredictionStore {
|
|||
sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
|
||||
}
|
||||
EditPredictionModel::Mercury => {}
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
|
||||
zeta2::edit_prediction_accepted(self, current_prediction, cx)
|
||||
}
|
||||
}
|
||||
|
|
@ -1338,7 +1346,7 @@ impl EditPredictionStore {
|
|||
was_shown: bool,
|
||||
) {
|
||||
match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
|
||||
if self.custom_predict_edits_url.is_some() {
|
||||
return;
|
||||
}
|
||||
|
|
@ -1773,7 +1781,9 @@ impl EditPredictionStore {
|
|||
}
|
||||
let task = match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
|
||||
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
|
||||
EditPredictionModel::Zeta2 { version } => {
|
||||
zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
|
||||
}
|
||||
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1332,12 +1332,20 @@ fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawComp
|
|||
|
||||
let current_marker = "<|fim_middle|>current\n";
|
||||
let updated_marker = "<|fim_middle|>updated\n";
|
||||
let suffix_marker = "<|fim_suffix|>\n";
|
||||
let cursor = "<|user_cursor|>";
|
||||
|
||||
let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
|
||||
let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
|
||||
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
|
||||
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
|
||||
// In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content.
|
||||
// Strip that out to get just the editable region.
|
||||
let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) {
|
||||
&excerpt[..suffix_pos]
|
||||
} else {
|
||||
&excerpt
|
||||
};
|
||||
let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap();
|
||||
|
||||
RawCompletionResponse {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
|
|
@ -1629,6 +1637,82 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
|
||||
// Test that zeta2's newline normalization logic doesn't insert spurious newlines.
|
||||
// When the buffer ends without a trailing newline, but the model returns output
|
||||
// with a trailing newline, zeta2 should normalize both sides before diffing
|
||||
// so no spurious newline is inserted.
|
||||
let (ep_store, mut requests) = init_test_with_fake_client(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
// Single line buffer with no trailing newline
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"foo.txt": "hello"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.find_project_path(path!("root/foo.txt"), cx)
|
||||
.unwrap();
|
||||
project.open_buffer(path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let position = snapshot.anchor_before(language::Point::new(0, 5));
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
// Model returns output WITH a trailing newline, even though the buffer doesn't have one.
|
||||
// Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
|
||||
let response = RawCompletionResponse {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
object: "text_completion".into(),
|
||||
created: 0,
|
||||
model: "model".into(),
|
||||
choices: vec![RawCompletionChoice {
|
||||
text: "hello world\n".to_string(),
|
||||
finish_reason: None,
|
||||
}],
|
||||
usage: RawCompletionUsage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
},
|
||||
};
|
||||
respond_tx.send(response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// The prediction should insert " world" without adding a newline
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.expect("should have prediction");
|
||||
let edits: Vec<_> = prediction
|
||||
.edits
|
||||
.iter()
|
||||
.map(|(range, text)| {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
(range.to_offset(&snapshot), text.clone())
|
||||
})
|
||||
.collect();
|
||||
assert_eq!(edits, vec![(5..5, " world".into())]);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_can_collect_data(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ use release_channel::AppVersion;
|
|||
|
||||
use std::env;
|
||||
use std::{path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
|
||||
|
||||
pub const MAX_CONTEXT_TOKENS: usize = 350;
|
||||
pub const MAX_EDITABLE_TOKENS: usize = 150;
|
||||
|
|
@ -32,6 +32,7 @@ pub fn request_prediction_with_zeta2(
|
|||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
zeta_version: ZetaVersion,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
|
@ -62,7 +63,7 @@ pub fn request_prediction_with_zeta2(
|
|||
cursor_offset,
|
||||
);
|
||||
|
||||
let prompt = format_zeta_prompt(&prompt_input);
|
||||
let prompt = format_zeta_prompt(&prompt_input, zeta_version);
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
|
|
@ -125,9 +126,17 @@ pub fn request_prediction_with_zeta2(
|
|||
output_text = output_text.replace(CURSOR_MARKER, "");
|
||||
}
|
||||
|
||||
let old_text = snapshot
|
||||
let mut old_text = snapshot
|
||||
.text_for_range(editable_offset_range.clone())
|
||||
.collect::<String>();
|
||||
|
||||
if !output_text.is_empty() && !output_text.ends_with('\n') {
|
||||
output_text.push('\n');
|
||||
}
|
||||
if !old_text.is_empty() && !old_text.ends_with('\n') {
|
||||
old_text.push('\n');
|
||||
}
|
||||
|
||||
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::PredictionProvider;
|
||||
use crate::paths::WORKTREES_DIR;
|
||||
use crate::{PredictionProvider, PromptFormat};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::example_spec::ExampleSpec;
|
||||
|
|
@ -9,11 +9,12 @@ use http_client::Url;
|
|||
use language::{Anchor, Buffer};
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Range;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
io::Read,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use zeta_prompt::RelatedFile;
|
||||
|
||||
|
|
@ -25,12 +26,7 @@ pub struct Example {
|
|||
/// The full content of the file where an edit is being predicted, and the
|
||||
/// actual cursor offset.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub buffer: Option<ExampleBuffer>,
|
||||
|
||||
/// The context retrieved for the prediction. This requires the worktree to
|
||||
/// be loaded and the language server to be started.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<ExampleContext>,
|
||||
pub prompt_inputs: Option<ExamplePromptInputs>,
|
||||
|
||||
/// The input and expected output from the edit prediction model.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
|
@ -59,25 +55,22 @@ pub struct ExampleState {
|
|||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleContext {
|
||||
pub files: Vec<RelatedFile>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleBuffer {
|
||||
pub struct ExamplePromptInputs {
|
||||
pub content: String,
|
||||
pub cursor_row: u32,
|
||||
pub cursor_column: u32,
|
||||
pub cursor_offset: usize,
|
||||
pub context_range: Range<usize>,
|
||||
pub editable_range: Range<usize>,
|
||||
pub edit_history: Vec<Arc<zeta_prompt::Event>>,
|
||||
pub related_files: Option<Vec<RelatedFile>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExamplePrompt {
|
||||
pub input: String,
|
||||
pub expected_output: String,
|
||||
pub format: PromptFormat,
|
||||
pub provider: PredictionProvider,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
|
@ -239,8 +232,7 @@ fn parse_markdown_example(input: &str) -> Result<Example> {
|
|||
let spec = ExampleSpec::from_markdown(input)?;
|
||||
Ok(Example {
|
||||
spec,
|
||||
buffer: None,
|
||||
context: None,
|
||||
prompt_inputs: None,
|
||||
prompt: None,
|
||||
predictions: Vec::new(),
|
||||
score: Vec::new(),
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
use crate::{
|
||||
PromptFormat,
|
||||
FormatPromptArgs, PredictionProvider,
|
||||
example::{Example, ExamplePrompt},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
progress::{Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use gpui::AsyncApp;
|
||||
use similar::DiffableStr;
|
||||
use std::fmt::Write as _;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -16,16 +14,21 @@ use zeta_prompt::format_zeta_prompt;
|
|||
|
||||
pub async fn run_format_prompt(
|
||||
example: &mut Example,
|
||||
prompt_format: PromptFormat,
|
||||
args: &FormatPromptArgs,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
run_context_retrieval(example, app_state, cx).await?;
|
||||
|
||||
let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
|
||||
|
||||
match prompt_format {
|
||||
PromptFormat::Teacher => {
|
||||
let prompt_inputs = example
|
||||
.prompt_inputs
|
||||
.as_ref()
|
||||
.context("prompt_inputs must be set after context retrieval")?;
|
||||
|
||||
match args.provider {
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
step_progress.set_substatus("formatting teacher prompt");
|
||||
let prompt = TeacherPrompt::format_prompt(example);
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
|
|
@ -36,47 +39,27 @@ pub async fn run_format_prompt(
|
|||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_default(),
|
||||
format: prompt_format,
|
||||
provider: args.provider,
|
||||
});
|
||||
}
|
||||
PromptFormat::Zeta2 => {
|
||||
step_progress.set_substatus("loading project");
|
||||
run_load_project(example, app_state, cx.clone()).await?;
|
||||
|
||||
PredictionProvider::Zeta2 => {
|
||||
step_progress.set_substatus("formatting zeta2 prompt");
|
||||
|
||||
let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})?;
|
||||
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
|
||||
let project = state.project.clone();
|
||||
let (_, input) =
|
||||
ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
|
||||
let events = ep_store
|
||||
.edit_history_for_project(&project, cx)
|
||||
.into_iter()
|
||||
.map(|e| e.event)
|
||||
.collect();
|
||||
anyhow::Ok(zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example
|
||||
.context
|
||||
.as_ref()
|
||||
.context("context must be set")?
|
||||
.files
|
||||
.clone(),
|
||||
events,
|
||||
example.spec.cursor_path.clone(),
|
||||
example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.context("buffer must be set")?
|
||||
.cursor_offset,
|
||||
))
|
||||
})?;
|
||||
let prompt = format_zeta_prompt(&input);
|
||||
let context_start = prompt_inputs.context_range.start;
|
||||
let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
|
||||
let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start)
|
||||
..(prompt_inputs.editable_range.end - context_start);
|
||||
let input = zeta_prompt::ZetaPromptInput {
|
||||
cursor_path: example.spec.cursor_path.clone(),
|
||||
cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()]
|
||||
.to_string()
|
||||
.into(),
|
||||
editable_range_in_excerpt,
|
||||
cursor_offset_in_excerpt,
|
||||
events: prompt_inputs.edit_history.clone(),
|
||||
related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
|
||||
};
|
||||
let prompt = format_zeta_prompt(&input, args.version);
|
||||
let expected_output = zeta2_output_for_patch(
|
||||
&input,
|
||||
&example
|
||||
|
|
@ -89,9 +72,12 @@ pub async fn run_format_prompt(
|
|||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output,
|
||||
format: prompt_format,
|
||||
provider: args.provider,
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
panic!("Cannot format prompt for {:?}", args.provider);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -144,10 +130,10 @@ impl TeacherPrompt {
|
|||
// 2. Context retriever just didn't include cursor line.
|
||||
//
|
||||
// In that case, fallback to using `cursor_position` as excerpt.
|
||||
let example_buffer = example
|
||||
.buffer
|
||||
let prompt_inputs = example
|
||||
.prompt_inputs
|
||||
.as_ref()
|
||||
.context("`buffer` should be filled in in the context collection step")?;
|
||||
.context("`prompt_inputs` should be filled in in the context collection step")?;
|
||||
|
||||
// Extract updated (new) editable region from the model response.
|
||||
// The model may include editable region markers in its output, so we need to strip them.
|
||||
|
|
@ -155,7 +141,7 @@ impl TeacherPrompt {
|
|||
let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
|
||||
|
||||
let old_editable_region =
|
||||
example_buffer.content[example_buffer.editable_range.clone()].to_string();
|
||||
prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string();
|
||||
|
||||
// Normalize leading newlines: if old starts with newline but new doesn't,
|
||||
// prepend newline to new to preserve whitespace structure.
|
||||
|
|
@ -164,8 +150,8 @@ impl TeacherPrompt {
|
|||
new_editable_region.insert(0, '\n');
|
||||
}
|
||||
|
||||
let editable_region_start_line = example_buffer.content
|
||||
[..example_buffer.editable_range.start]
|
||||
let editable_region_start_line = prompt_inputs.content
|
||||
[..prompt_inputs.editable_range.start]
|
||||
.matches('\n')
|
||||
.count();
|
||||
|
||||
|
|
@ -208,17 +194,21 @@ impl TeacherPrompt {
|
|||
}
|
||||
|
||||
fn format_context(example: &Example) -> String {
|
||||
let context = example
|
||||
.context
|
||||
let related_files = example
|
||||
.prompt_inputs
|
||||
.as_ref()
|
||||
.expect("Missing context retriever step");
|
||||
.and_then(|pi| pi.related_files.as_ref());
|
||||
|
||||
if context.files.is_empty() {
|
||||
let Some(related_files) = related_files else {
|
||||
return "(No context)".to_string();
|
||||
};
|
||||
|
||||
if related_files.is_empty() {
|
||||
return "(No context)".to_string();
|
||||
}
|
||||
|
||||
let mut prompt = String::new();
|
||||
for file in context.files.iter() {
|
||||
for file in related_files {
|
||||
let path_str = file.path.to_string_lossy();
|
||||
writeln!(&mut prompt, "`````{path_str}").ok();
|
||||
let mut prev_row = 0;
|
||||
|
|
@ -242,28 +232,26 @@ impl TeacherPrompt {
|
|||
fn format_cursor_excerpt(example: &Example) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
let example_buffer = example.buffer.as_ref().unwrap();
|
||||
let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
|
||||
|
||||
let path_str = example.spec.cursor_path.to_string_lossy();
|
||||
result.push_str(&format!("`````{path_str}\n"));
|
||||
result.push_str(
|
||||
&example_buffer.content
|
||||
[example_buffer.context_range.start..example_buffer.editable_range.start],
|
||||
&prompt_inputs.content
|
||||
[prompt_inputs.context_range.start..prompt_inputs.editable_range.start],
|
||||
);
|
||||
result.push_str(Self::EDITABLE_REGION_START);
|
||||
result.push_str(
|
||||
&example_buffer.content
|
||||
[example_buffer.editable_range.start..example_buffer.cursor_offset],
|
||||
&prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset],
|
||||
);
|
||||
result.push_str(Self::USER_CURSOR_MARKER);
|
||||
result.push_str(
|
||||
&example_buffer.content
|
||||
[example_buffer.cursor_offset..example_buffer.editable_range.end],
|
||||
&prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end],
|
||||
);
|
||||
result.push_str(Self::EDITABLE_REGION_END);
|
||||
result.push_str(
|
||||
&example_buffer.content
|
||||
[example_buffer.editable_range.end..example_buffer.context_range.end],
|
||||
&prompt_inputs.content
|
||||
[prompt_inputs.editable_range.end..prompt_inputs.context_range.end],
|
||||
);
|
||||
result.push_str("\n`````");
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
example::{Example, ExampleBuffer, ExampleState},
|
||||
example::{Example, ExamplePromptInputs, ExampleState},
|
||||
git,
|
||||
headless::EpAppState,
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
|
|
@ -38,7 +38,20 @@ pub async fn run_load_project(
|
|||
buffer
|
||||
.read_with(&cx, |buffer, _| buffer.parsing_idle())
|
||||
.await;
|
||||
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx))
|
||||
.context("EditPredictionStore not initialized")?;
|
||||
|
||||
let edit_history = ep_store.update(&mut cx, |store, cx| {
|
||||
store
|
||||
.edit_history_for_project(&project, cx)
|
||||
.into_iter()
|
||||
.map(|e| e.event)
|
||||
.collect()
|
||||
});
|
||||
|
||||
let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
let snapshot = buffer.snapshot();
|
||||
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
|
||||
|
|
@ -54,13 +67,15 @@ pub async fn run_load_project(
|
|||
.map(|l| l.name().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
(
|
||||
ExampleBuffer {
|
||||
ExamplePromptInputs {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
context_range,
|
||||
editable_range,
|
||||
edit_history,
|
||||
related_files: None,
|
||||
},
|
||||
language_name,
|
||||
)
|
||||
|
|
@ -68,7 +83,7 @@ pub async fn run_load_project(
|
|||
|
||||
progress.set_info(language_name, InfoStyle::Normal);
|
||||
|
||||
example.buffer = Some(example_buffer);
|
||||
example.prompt_inputs = Some(prompt_inputs);
|
||||
example.state = Some(ExampleState {
|
||||
buffer,
|
||||
project,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ use edit_prediction::EditPredictionStore;
|
|||
use futures::channel::mpsc;
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
use gpui::{AppContext as _, Application};
|
||||
use zeta_prompt::ZetaVersion;
|
||||
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -155,7 +156,7 @@ impl Display for Command {
|
|||
f,
|
||||
"format-prompt --prompt-format={}",
|
||||
format_prompt_args
|
||||
.prompt_format
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
|
|
@ -204,22 +205,31 @@ impl Display for Command {
|
|||
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long, short('p'))]
|
||||
prompt_format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
enum PromptFormat {
|
||||
Teacher,
|
||||
Zeta2,
|
||||
#[clap(long, short)]
|
||||
provider: PredictionProvider,
|
||||
#[clap(
|
||||
long,
|
||||
short,
|
||||
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
|
||||
value_parser = ZetaVersion::parse,
|
||||
default_value_t = ZetaVersion::default(),
|
||||
)]
|
||||
version: ZetaVersion,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args, Clone)]
|
||||
struct PredictArgs {
|
||||
#[clap(long)]
|
||||
#[clap(long, short)]
|
||||
provider: PredictionProvider,
|
||||
#[clap(long, default_value_t = 1)]
|
||||
repetitions: usize,
|
||||
#[clap(
|
||||
long,
|
||||
short,
|
||||
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
|
||||
value_parser = ZetaVersion::parse,
|
||||
)]
|
||||
version: ZetaVersion,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
|
||||
|
|
@ -514,7 +524,7 @@ fn main() {
|
|||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(
|
||||
example,
|
||||
args.prompt_format,
|
||||
args,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
|
|
@ -523,8 +533,7 @@ fn main() {
|
|||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
args,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
PredictionProvider, PromptFormat,
|
||||
FormatPromptArgs, PredictArgs, PredictionProvider,
|
||||
anthropic_client::AnthropicClient,
|
||||
example::{Example, ExamplePrediction, ExamplePrompt},
|
||||
format_prompt::{TeacherPrompt, run_format_prompt},
|
||||
|
|
@ -25,12 +25,13 @@ static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
|
|||
|
||||
pub async fn run_prediction(
|
||||
example: &mut Example,
|
||||
provider: Option<PredictionProvider>,
|
||||
repetition_count: usize,
|
||||
args: &PredictArgs,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
let provider = provider.context("provider is required")?;
|
||||
let provider = args.provider;
|
||||
let repetition_count = args.repetitions;
|
||||
let zeta_version = args.version;
|
||||
|
||||
if let Some(existing_prediction) = example.predictions.first() {
|
||||
if existing_prediction.provider == provider {
|
||||
|
|
@ -48,7 +49,16 @@ pub async fn run_prediction(
|
|||
) {
|
||||
let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);
|
||||
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
|
||||
run_format_prompt(
|
||||
example,
|
||||
&FormatPromptArgs {
|
||||
provider,
|
||||
version: args.version,
|
||||
},
|
||||
app_state.clone(),
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let batched = matches!(provider, PredictionProvider::Teacher);
|
||||
return predict_anthropic(example, repetition_count, batched).await;
|
||||
|
|
@ -85,7 +95,9 @@ pub async fn run_prediction(
|
|||
ep_store.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
|
||||
version: zeta_version,
|
||||
},
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
|
|
@ -127,7 +139,7 @@ pub async fn run_prediction(
|
|||
updated_example.prompt.get_or_insert(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: String::new(),
|
||||
format: PromptFormat::Zeta2,
|
||||
provider,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -149,8 +149,7 @@ fn examples_from_response(
|
|||
match parse_result {
|
||||
Ok(spec) => Some(Example {
|
||||
spec,
|
||||
buffer: None,
|
||||
context: None,
|
||||
prompt_inputs: None,
|
||||
prompt: None,
|
||||
predictions: Vec::new(),
|
||||
score: Vec::new(),
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
example::{Example, ExampleContext},
|
||||
example::Example,
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
|
|
@ -19,7 +19,11 @@ pub async fn run_context_retrieval(
|
|||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
if example.context.is_some() {
|
||||
if example
|
||||
.prompt_inputs
|
||||
.as_ref()
|
||||
.is_some_and(|inputs| inputs.related_files.is_some())
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
|
@ -63,9 +67,9 @@ pub async fn run_context_retrieval(
|
|||
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
|
||||
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
|
||||
|
||||
example.context = Some(ExampleContext {
|
||||
files: context_files,
|
||||
});
|
||||
if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
|
||||
prompt_inputs.related_files = Some(context_files);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,19 +17,12 @@ pub async fn run_scoring(
|
|||
app_state: Arc<EpAppState>,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
run_prediction(example, args, app_state, cx).await?;
|
||||
|
||||
let progress = Progress::global().start(Step::Score, &example.spec.name);
|
||||
|
||||
progress.set_substatus("applying patches");
|
||||
let original_text = &example.buffer.as_ref().unwrap().content;
|
||||
let original_text = &example.prompt_inputs.as_ref().unwrap().content;
|
||||
let expected_texts: Vec<String> = example
|
||||
.spec
|
||||
.expected_patches
|
||||
|
|
|
|||
|
|
@ -204,7 +204,9 @@ fn assign_edit_prediction_provider(
|
|||
} else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
|
||||
&& cx.has_flag::<Zeta2FeatureFlag>()
|
||||
{
|
||||
edit_prediction::EditPredictionModel::Zeta2
|
||||
edit_prediction::EditPredictionModel::Zeta2 {
|
||||
version: Default::default(),
|
||||
}
|
||||
} else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
|
||||
&& cx.has_flag::<MercuryFeatureFlag>()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -12,4 +12,6 @@ workspace = true
|
|||
path = "src/zeta_prompt.rs"
|
||||
|
||||
[dependencies]
|
||||
serde.workspace = true
|
||||
anyhow.workspace = true
|
||||
serde.workspace = true
|
||||
strum.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
|
||||
|
||||
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
|
||||
|
|
@ -16,6 +18,54 @@ pub struct ZetaPromptInput {
|
|||
pub related_files: Vec<RelatedFile>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum ZetaVersion {
|
||||
V0112_MiddleAtEnd,
|
||||
#[default]
|
||||
V0113_Ordered,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ZetaVersion {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", <&'static str>::from(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl ZetaVersion {
|
||||
pub fn parse(version_string: &str) -> Result<Self> {
|
||||
let mut results = ZetaVersion::iter().filter(|version| {
|
||||
<&'static str>::from(version)
|
||||
.to_lowercase()
|
||||
.contains(&version_string.to_lowercase())
|
||||
});
|
||||
let Some(result) = results.next() else {
|
||||
anyhow::bail!(
|
||||
"`{version_string}` did not match any of:\n{}",
|
||||
Self::options_as_string()
|
||||
);
|
||||
};
|
||||
if results.next().is_some() {
|
||||
anyhow::bail!(
|
||||
"`{version_string}` matched more than one of:\n{}",
|
||||
Self::options_as_string()
|
||||
);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn options_as_string() -> String {
|
||||
ZetaVersion::iter()
|
||||
.map(|version| format!("- {}\n", <&'static str>::from(version)))
|
||||
.collect::<Vec<_>>()
|
||||
.concat()
|
||||
}
|
||||
|
||||
pub fn default_as_string() -> String {
|
||||
<&'static str>::from(Self::default()).to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "event")]
|
||||
pub enum Event {
|
||||
|
|
@ -69,11 +119,20 @@ pub struct RelatedExcerpt {
|
|||
pub text: Arc<str>,
|
||||
}
|
||||
|
||||
pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
|
||||
pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
|
||||
let mut prompt = String::new();
|
||||
write_related_files(&mut prompt, &input.related_files);
|
||||
write_edit_history_section(&mut prompt, input);
|
||||
write_cursor_excerpt_section(&mut prompt, input);
|
||||
|
||||
match version {
|
||||
ZetaVersion::V0112_MiddleAtEnd => {
|
||||
v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
|
||||
}
|
||||
ZetaVersion::V0113_Ordered => {
|
||||
v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
|
||||
}
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
|
|
@ -100,31 +159,73 @@ fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
|
|||
}
|
||||
}
|
||||
|
||||
fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
|
||||
let path_str = input.cursor_path.to_string_lossy();
|
||||
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
|
||||
mod v0112_middle_at_end {
|
||||
use super::*;
|
||||
|
||||
prompt.push_str("<|fim_prefix|>\n");
|
||||
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
|
||||
pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
|
||||
let path_str = input.cursor_path.to_string_lossy();
|
||||
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
|
||||
|
||||
prompt.push_str("<|fim_suffix|>\n");
|
||||
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
|
||||
if !prompt.ends_with('\n') {
|
||||
prompt.push('\n');
|
||||
prompt.push_str("<|fim_prefix|>\n");
|
||||
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
|
||||
|
||||
prompt.push_str("<|fim_suffix|>\n");
|
||||
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
|
||||
if !prompt.ends_with('\n') {
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("<|fim_middle|>current\n");
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt
|
||||
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
|
||||
);
|
||||
prompt.push_str(CURSOR_MARKER);
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt
|
||||
[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
|
||||
);
|
||||
if !prompt.ends_with('\n') {
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("<|fim_middle|>updated\n");
|
||||
}
|
||||
}
|
||||
|
||||
mod v0113_ordered {
|
||||
use super::*;
|
||||
|
||||
pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
|
||||
let path_str = input.cursor_path.to_string_lossy();
|
||||
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
|
||||
|
||||
prompt.push_str("<|fim_prefix|>\n");
|
||||
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
|
||||
if !prompt.ends_with('\n') {
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("<|fim_middle|>current\n");
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt
|
||||
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
|
||||
);
|
||||
prompt.push_str(CURSOR_MARKER);
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt
|
||||
[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
|
||||
);
|
||||
if !prompt.ends_with('\n') {
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("<|fim_suffix|>\n");
|
||||
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
|
||||
if !prompt.ends_with('\n') {
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("<|fim_middle|>updated\n");
|
||||
}
|
||||
|
||||
prompt.push_str("<|fim_middle|>current\n");
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt
|
||||
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
|
||||
);
|
||||
prompt.push_str(CURSOR_MARKER);
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
|
||||
);
|
||||
if !prompt.ends_with('\n') {
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("<|fim_middle|>updated\n");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue