Introduce zeta2 format with cursor content in original order (#46732)

This one does `fim_prefix`, `fim_middle`, and `fim_suffix` in that
order, in the prompt, instead of putting the current middle last.

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
This commit is contained in:
Max Brunsfeld 2026-01-13 13:53:44 -08:00 committed by GitHub
parent c9003e1a12
commit 20284e4f21
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 384 additions and 162 deletions

2
Cargo.lock generated
View file

@ -21179,7 +21179,9 @@ dependencies = [
name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
"serde",
"strum 0.27.2",
]
[[package]]

View file

@ -38,6 +38,7 @@ use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
use std::collections::{VecDeque, hash_map};
use text::Edit;
use workspace::Workspace;
use zeta_prompt::ZetaVersion;
use std::ops::Range;
use std::path::Path;
@ -183,7 +184,9 @@ pub struct EditPredictionStore {
pub enum EditPredictionModel {
#[default]
Zeta1,
Zeta2,
Zeta2 {
version: ZetaVersion,
},
Sweep,
Mercury,
}
@ -654,7 +657,9 @@ impl EditPredictionStore {
update_required: false,
#[cfg(feature = "cli-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
edit_prediction_model: EditPredictionModel::Zeta2 {
version: Default::default(),
},
sweep_ai: SweepAi::new(cx),
mercury: Mercury::new(cx),
data_collection_choice,
@ -794,7 +799,10 @@ impl EditPredictionStore {
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
if self.edit_prediction_model == EditPredictionModel::Zeta2 {
if matches!(
self.edit_prediction_model,
EditPredictionModel::Zeta2 { .. }
) {
self.user_store.read(cx).edit_prediction_usage()
} else {
None
@ -1204,7 +1212,7 @@ impl EditPredictionStore {
sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
}
EditPredictionModel::Mercury => {}
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
zeta2::edit_prediction_accepted(self, current_prediction, cx)
}
}
@ -1338,7 +1346,7 @@ impl EditPredictionStore {
was_shown: bool,
) {
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
if self.custom_predict_edits_url.is_some() {
return;
}
@ -1773,7 +1781,9 @@ impl EditPredictionStore {
}
let task = match self.edit_prediction_model {
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
EditPredictionModel::Zeta2 { version } => {
zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
}
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};

View file

@ -1332,12 +1332,20 @@ fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawComp
let current_marker = "<|fim_middle|>current\n";
let updated_marker = "<|fim_middle|>updated\n";
let suffix_marker = "<|fim_suffix|>\n";
let cursor = "<|user_cursor|>";
let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
// In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content.
// Strip that out to get just the editable region.
let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) {
&excerpt[..suffix_pos]
} else {
&excerpt
};
let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap();
RawCompletionResponse {
id: Uuid::new_v4().to_string(),
@ -1629,6 +1637,82 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
);
}
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
// Test that zeta2's newline normalization logic doesn't insert spurious newlines.
// When the buffer ends without a trailing newline, but the model returns output
// with a trailing newline, zeta2 should normalize both sides before diffing
// so no spurious newline is inserted.
let (ep_store, mut requests) = init_test_with_fake_client(cx);
let fs = FakeFs::new(cx.executor());
// Single line buffer with no trailing newline
fs.insert_tree(
"/root",
json!({
"foo.txt": "hello"
}),
)
.await;
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
let path = project
.find_project_path(path!("root/foo.txt"), cx)
.unwrap();
project.open_buffer(path, cx)
})
.await
.unwrap();
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot.anchor_before(language::Point::new(0, 5));
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_request, respond_tx) = requests.predict.next().await.unwrap();
// Model returns output WITH a trailing newline, even though the buffer doesn't have one.
// Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
let response = RawCompletionResponse {
id: Uuid::new_v4().to_string(),
object: "text_completion".into(),
created: 0,
model: "model".into(),
choices: vec![RawCompletionChoice {
text: "hello world\n".to_string(),
finish_reason: None,
}],
usage: RawCompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
};
respond_tx.send(response).unwrap();
cx.run_until_parked();
// The prediction should insert " world" without adding a newline
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.prediction_at(&buffer, None, &project, cx)
.expect("should have prediction");
let edits: Vec<_> = prediction
.edits
.iter()
.map(|(range, text)| {
let snapshot = buffer.read(cx).snapshot();
(range.to_offset(&snapshot), text.clone())
})
.collect();
assert_eq!(edits, vec![(5..5, " world".into())]);
});
}
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
init_test(cx);

View file

@ -15,8 +15,8 @@ use release_channel::AppVersion;
use std::env;
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::CURSOR_MARKER;
use zeta_prompt::format_zeta_prompt;
use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
pub const MAX_CONTEXT_TOKENS: usize = 350;
pub const MAX_EDITABLE_TOKENS: usize = 150;
@ -32,6 +32,7 @@ pub fn request_prediction_with_zeta2(
debug_tx,
..
}: EditPredictionModelInput,
zeta_version: ZetaVersion,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer_snapshotted_at = Instant::now();
@ -62,7 +63,7 @@ pub fn request_prediction_with_zeta2(
cursor_offset,
);
let prompt = format_zeta_prompt(&prompt_input);
let prompt = format_zeta_prompt(&prompt_input, zeta_version);
if let Some(debug_tx) = &debug_tx {
debug_tx
@ -125,9 +126,17 @@ pub fn request_prediction_with_zeta2(
output_text = output_text.replace(CURSOR_MARKER, "");
}
let old_text = snapshot
let mut old_text = snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
if !output_text.is_empty() && !output_text.ends_with('\n') {
output_text.push('\n');
}
if !old_text.is_empty() && !old_text.ends_with('\n') {
old_text.push('\n');
}
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
.into_iter()
.map(|(range, text)| {

View file

@ -1,5 +1,5 @@
use crate::PredictionProvider;
use crate::paths::WORKTREES_DIR;
use crate::{PredictionProvider, PromptFormat};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::example_spec::ExampleSpec;
@ -9,11 +9,12 @@ use http_client::Url;
use language::{Anchor, Buffer};
use project::Project;
use serde::{Deserialize, Serialize};
use std::ops::Range;
use std::{
borrow::Cow,
io::Read,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
use zeta_prompt::RelatedFile;
@ -25,12 +26,7 @@ pub struct Example {
/// The full content of the file where an edit is being predicted, and the
/// actual cursor offset.
#[serde(skip_serializing_if = "Option::is_none")]
pub buffer: Option<ExampleBuffer>,
/// The context retrieved for the prediction. This requires the worktree to
/// be loaded and the language server to be started.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<ExampleContext>,
pub prompt_inputs: Option<ExamplePromptInputs>,
/// The input and expected output from the edit prediction model.
#[serde(skip_serializing_if = "Option::is_none")]
@ -59,25 +55,22 @@ pub struct ExampleState {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
pub files: Vec<RelatedFile>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
pub struct ExamplePromptInputs {
pub content: String,
pub cursor_row: u32,
pub cursor_column: u32,
pub cursor_offset: usize,
pub context_range: Range<usize>,
pub editable_range: Range<usize>,
pub edit_history: Vec<Arc<zeta_prompt::Event>>,
pub related_files: Option<Vec<RelatedFile>>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
pub input: String,
pub expected_output: String,
pub format: PromptFormat,
pub provider: PredictionProvider,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@ -239,8 +232,7 @@ fn parse_markdown_example(input: &str) -> Result<Example> {
let spec = ExampleSpec::from_markdown(input)?;
Ok(Example {
spec,
buffer: None,
context: None,
prompt_inputs: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),

View file

@ -1,14 +1,12 @@
use crate::{
PromptFormat,
FormatPromptArgs, PredictionProvider,
example::{Example, ExamplePrompt},
headless::EpAppState,
load_project::run_load_project,
progress::{Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::{Context as _, Result};
use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
use gpui::{AsyncApp, Entity};
use gpui::AsyncApp;
use similar::DiffableStr;
use std::fmt::Write as _;
use std::sync::Arc;
@ -16,16 +14,21 @@ use zeta_prompt::format_zeta_prompt;
pub async fn run_format_prompt(
example: &mut Example,
prompt_format: PromptFormat,
args: &FormatPromptArgs,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
cx: AsyncApp,
) -> Result<()> {
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
run_context_retrieval(example, app_state, cx).await?;
let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
match prompt_format {
PromptFormat::Teacher => {
let prompt_inputs = example
.prompt_inputs
.as_ref()
.context("prompt_inputs must be set after context retrieval")?;
match args.provider {
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
step_progress.set_substatus("formatting teacher prompt");
let prompt = TeacherPrompt::format_prompt(example);
example.prompt = Some(ExamplePrompt {
@ -36,47 +39,27 @@ pub async fn run_format_prompt(
.first()
.cloned()
.unwrap_or_default(),
format: prompt_format,
provider: args.provider,
});
}
PromptFormat::Zeta2 => {
step_progress.set_substatus("loading project");
run_load_project(example, app_state, cx.clone()).await?;
PredictionProvider::Zeta2 => {
step_progress.set_substatus("formatting zeta2 prompt");
let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})?;
let state = example.state.as_ref().context("state must be set")?;
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
let project = state.project.clone();
let (_, input) =
ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
let events = ep_store
.edit_history_for_project(&project, cx)
.into_iter()
.map(|e| e.event)
.collect();
anyhow::Ok(zeta2_prompt_input(
&snapshot,
example
.context
.as_ref()
.context("context must be set")?
.files
.clone(),
events,
example.spec.cursor_path.clone(),
example
.buffer
.as_ref()
.context("buffer must be set")?
.cursor_offset,
))
})?;
let prompt = format_zeta_prompt(&input);
let context_start = prompt_inputs.context_range.start;
let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start)
..(prompt_inputs.editable_range.end - context_start);
let input = zeta_prompt::ZetaPromptInput {
cursor_path: example.spec.cursor_path.clone(),
cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()]
.to_string()
.into(),
editable_range_in_excerpt,
cursor_offset_in_excerpt,
events: prompt_inputs.edit_history.clone(),
related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
};
let prompt = format_zeta_prompt(&input, args.version);
let expected_output = zeta2_output_for_patch(
&input,
&example
@ -89,9 +72,12 @@ pub async fn run_format_prompt(
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,
format: prompt_format,
provider: args.provider,
});
}
_ => {
panic!("Cannot format prompt for {:?}", args.provider);
}
};
Ok(())
}
@ -144,10 +130,10 @@ impl TeacherPrompt {
// 2. Context retriever just didn't include cursor line.
//
// In that case, fallback to using `cursor_position` as excerpt.
let example_buffer = example
.buffer
let prompt_inputs = example
.prompt_inputs
.as_ref()
.context("`buffer` should be filled in in the context collection step")?;
.context("`prompt_inputs` should be filled in in the context collection step")?;
// Extract updated (new) editable region from the model response.
// The model may include editable region markers in its output, so we need to strip them.
@ -155,7 +141,7 @@ impl TeacherPrompt {
let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
let old_editable_region =
example_buffer.content[example_buffer.editable_range.clone()].to_string();
prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string();
// Normalize leading newlines: if old starts with newline but new doesn't,
// prepend newline to new to preserve whitespace structure.
@ -164,8 +150,8 @@ impl TeacherPrompt {
new_editable_region.insert(0, '\n');
}
let editable_region_start_line = example_buffer.content
[..example_buffer.editable_range.start]
let editable_region_start_line = prompt_inputs.content
[..prompt_inputs.editable_range.start]
.matches('\n')
.count();
@ -208,17 +194,21 @@ impl TeacherPrompt {
}
fn format_context(example: &Example) -> String {
let context = example
.context
let related_files = example
.prompt_inputs
.as_ref()
.expect("Missing context retriever step");
.and_then(|pi| pi.related_files.as_ref());
if context.files.is_empty() {
let Some(related_files) = related_files else {
return "(No context)".to_string();
};
if related_files.is_empty() {
return "(No context)".to_string();
}
let mut prompt = String::new();
for file in context.files.iter() {
for file in related_files {
let path_str = file.path.to_string_lossy();
writeln!(&mut prompt, "`````{path_str}").ok();
let mut prev_row = 0;
@ -242,28 +232,26 @@ impl TeacherPrompt {
fn format_cursor_excerpt(example: &Example) -> String {
let mut result = String::new();
let example_buffer = example.buffer.as_ref().unwrap();
let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
let path_str = example.spec.cursor_path.to_string_lossy();
result.push_str(&format!("`````{path_str}\n"));
result.push_str(
&example_buffer.content
[example_buffer.context_range.start..example_buffer.editable_range.start],
&prompt_inputs.content
[prompt_inputs.context_range.start..prompt_inputs.editable_range.start],
);
result.push_str(Self::EDITABLE_REGION_START);
result.push_str(
&example_buffer.content
[example_buffer.editable_range.start..example_buffer.cursor_offset],
&prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset],
);
result.push_str(Self::USER_CURSOR_MARKER);
result.push_str(
&example_buffer.content
[example_buffer.cursor_offset..example_buffer.editable_range.end],
&prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end],
);
result.push_str(Self::EDITABLE_REGION_END);
result.push_str(
&example_buffer.content
[example_buffer.editable_range.end..example_buffer.context_range.end],
&prompt_inputs.content
[prompt_inputs.editable_range.end..prompt_inputs.context_range.end],
);
result.push_str("\n`````");

View file

@ -1,5 +1,5 @@
use crate::{
example::{Example, ExampleBuffer, ExampleState},
example::{Example, ExamplePromptInputs, ExampleState},
git,
headless::EpAppState,
progress::{InfoStyle, Progress, Step, StepProgress},
@ -38,7 +38,20 @@ pub async fn run_load_project(
buffer
.read_with(&cx, |buffer, _| buffer.parsing_idle())
.await;
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx))
.context("EditPredictionStore not initialized")?;
let edit_history = ep_store.update(&mut cx, |store, cx| {
store
.edit_history_for_project(&project, cx)
.into_iter()
.map(|e| e.event)
.collect()
});
let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let snapshot = buffer.snapshot();
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
@ -54,13 +67,15 @@ pub async fn run_load_project(
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
ExamplePromptInputs {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
context_range,
editable_range,
edit_history,
related_files: None,
},
language_name,
)
@ -68,7 +83,7 @@ pub async fn run_load_project(
progress.set_info(language_name, InfoStyle::Normal);
example.buffer = Some(example_buffer);
example.prompt_inputs = Some(prompt_inputs);
example.state = Some(ExampleState {
buffer,
project,

View file

@ -22,6 +22,7 @@ use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application};
use zeta_prompt::ZetaVersion;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
@ -155,7 +156,7 @@ impl Display for Command {
f,
"format-prompt --prompt-format={}",
format_prompt_args
.prompt_format
.provider
.to_possible_value()
.unwrap()
.get_name()
@ -204,22 +205,31 @@ impl Display for Command {
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
#[clap(long, short('p'))]
prompt_format: PromptFormat,
}
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
Teacher,
Zeta2,
#[clap(long, short)]
provider: PredictionProvider,
#[clap(
long,
short,
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
value_parser = ZetaVersion::parse,
default_value_t = ZetaVersion::default(),
)]
version: ZetaVersion,
}
#[derive(Debug, Args, Clone)]
struct PredictArgs {
#[clap(long)]
#[clap(long, short)]
provider: PredictionProvider,
#[clap(long, default_value_t = 1)]
repetitions: usize,
#[clap(
long,
short,
help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
value_parser = ZetaVersion::parse,
)]
version: ZetaVersion,
}
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
@ -514,7 +524,7 @@ fn main() {
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
args,
app_state.clone(),
cx.clone(),
)
@ -523,8 +533,7 @@ fn main() {
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
args,
app_state.clone(),
cx.clone(),
)

View file

@ -1,5 +1,5 @@
use crate::{
PredictionProvider, PromptFormat,
FormatPromptArgs, PredictArgs, PredictionProvider,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction, ExamplePrompt},
format_prompt::{TeacherPrompt, run_format_prompt},
@ -25,12 +25,13 @@ static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
pub async fn run_prediction(
example: &mut Example,
provider: Option<PredictionProvider>,
repetition_count: usize,
args: &PredictArgs,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
let provider = provider.context("provider is required")?;
let provider = args.provider;
let repetition_count = args.repetitions;
let zeta_version = args.version;
if let Some(existing_prediction) = example.predictions.first() {
if existing_prediction.provider == provider {
@ -48,7 +49,16 @@ pub async fn run_prediction(
) {
let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
run_format_prompt(
example,
&FormatPromptArgs {
provider,
version: args.version,
},
app_state.clone(),
cx,
)
.await?;
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
@ -85,7 +95,9 @@ pub async fn run_prediction(
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
version: zeta_version,
},
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
@ -127,7 +139,7 @@ pub async fn run_prediction(
updated_example.prompt.get_or_insert(ExamplePrompt {
input: prompt,
expected_output: String::new(),
format: PromptFormat::Zeta2,
provider,
});
}
}

View file

@ -149,8 +149,7 @@ fn examples_from_response(
match parse_result {
Ok(spec) => Some(Example {
spec,
buffer: None,
context: None,
prompt_inputs: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),

View file

@ -1,5 +1,5 @@
use crate::{
example::{Example, ExampleContext},
example::Example,
headless::EpAppState,
load_project::run_load_project,
progress::{InfoStyle, Progress, Step, StepProgress},
@ -19,7 +19,11 @@ pub async fn run_context_retrieval(
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
if example.context.is_some() {
if example
.prompt_inputs
.as_ref()
.is_some_and(|inputs| inputs.related_files.is_some())
{
return Ok(());
}
@ -63,9 +67,9 @@ pub async fn run_context_retrieval(
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
example.context = Some(ExampleContext {
files: context_files,
});
if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
prompt_inputs.related_files = Some(context_files);
}
Ok(())
}

View file

@ -17,19 +17,12 @@ pub async fn run_scoring(
app_state: Arc<EpAppState>,
cx: AsyncApp,
) -> anyhow::Result<()> {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state,
cx,
)
.await?;
run_prediction(example, args, app_state, cx).await?;
let progress = Progress::global().start(Step::Score, &example.spec.name);
progress.set_substatus("applying patches");
let original_text = &example.buffer.as_ref().unwrap().content;
let original_text = &example.prompt_inputs.as_ref().unwrap().content;
let expected_texts: Vec<String> = example
.spec
.expected_patches

View file

@ -204,7 +204,9 @@ fn assign_edit_prediction_provider(
} else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<Zeta2FeatureFlag>()
{
edit_prediction::EditPredictionModel::Zeta2
edit_prediction::EditPredictionModel::Zeta2 {
version: Default::default(),
}
} else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<MercuryFeatureFlag>()
{

View file

@ -12,4 +12,6 @@ workspace = true
path = "src/zeta_prompt.rs"
[dependencies]
serde.workspace = true
anyhow.workspace = true
serde.workspace = true
strum.workspace = true

View file

@ -1,8 +1,10 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::sync::Arc;
use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
@ -16,6 +18,54 @@ pub struct ZetaPromptInput {
pub related_files: Vec<RelatedFile>,
}
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
V0112_MiddleAtEnd,
#[default]
V0113_Ordered,
}
impl std::fmt::Display for ZetaVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", <&'static str>::from(self))
}
}
impl ZetaVersion {
pub fn parse(version_string: &str) -> Result<Self> {
let mut results = ZetaVersion::iter().filter(|version| {
<&'static str>::from(version)
.to_lowercase()
.contains(&version_string.to_lowercase())
});
let Some(result) = results.next() else {
anyhow::bail!(
"`{version_string}` did not match any of:\n{}",
Self::options_as_string()
);
};
if results.next().is_some() {
anyhow::bail!(
"`{version_string}` matched more than one of:\n{}",
Self::options_as_string()
);
}
Ok(result)
}
fn options_as_string() -> String {
ZetaVersion::iter()
.map(|version| format!("- {}\n", <&'static str>::from(version)))
.collect::<Vec<_>>()
.concat()
}
pub fn default_as_string() -> String {
<&'static str>::from(Self::default()).to_string()
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
@ -69,11 +119,20 @@ pub struct RelatedExcerpt {
pub text: Arc<str>,
}
pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
let mut prompt = String::new();
write_related_files(&mut prompt, &input.related_files);
write_edit_history_section(&mut prompt, input);
write_cursor_excerpt_section(&mut prompt, input);
match version {
ZetaVersion::V0112_MiddleAtEnd => {
v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
}
ZetaVersion::V0113_Ordered => {
v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
}
}
prompt
}
@ -100,31 +159,73 @@ fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
}
}
fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
let path_str = input.cursor_path.to_string_lossy();
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
mod v0112_middle_at_end {
use super::*;
prompt.push_str("<|fim_prefix|>\n");
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
let path_str = input.cursor_path.to_string_lossy();
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
prompt.push_str("<|fim_suffix|>\n");
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
if !prompt.ends_with('\n') {
prompt.push('\n');
prompt.push_str("<|fim_prefix|>\n");
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
prompt.push_str("<|fim_suffix|>\n");
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>current\n");
prompt.push_str(
&input.cursor_excerpt
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_MARKER);
prompt.push_str(
&input.cursor_excerpt
[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>updated\n");
}
}
mod v0113_ordered {
use super::*;
pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
let path_str = input.cursor_path.to_string_lossy();
write!(prompt, "<|file_sep|>{}\n", path_str).ok();
prompt.push_str("<|fim_prefix|>\n");
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>current\n");
prompt.push_str(
&input.cursor_excerpt
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_MARKER);
prompt.push_str(
&input.cursor_excerpt
[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_suffix|>\n");
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>updated\n");
}
prompt.push_str("<|fim_middle|>current\n");
prompt.push_str(
&input.cursor_excerpt
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_MARKER);
prompt.push_str(
&input.cursor_excerpt[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("<|fim_middle|>updated\n");
}