zeta2: Merge Sweep and Zeta2 Providers (#43097)

Closes #ISSUE

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Ben Kunkle 2025-11-19 15:40:06 -08:00 committed by GitHub
parent c70f2d16ad
commit f2f40a5099
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 554 additions and 1199 deletions

32
Cargo.lock generated
View file

@ -5314,13 +5314,13 @@ dependencies = [
"serde_json",
"settings",
"supermaven",
"sweep_ai",
"telemetry",
"theme",
"ui",
"workspace",
"zed_actions",
"zeta",
"zeta2",
]
[[package]]
@ -16590,33 +16590,6 @@ dependencies = [
"zeno",
]
[[package]]
name = "sweep_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"brotli",
"client",
"collections",
"edit_prediction",
"feature_flags",
"futures 0.3.31",
"gpui",
"http_client",
"indoc",
"language",
"project",
"release_channel",
"reqwest_client",
"serde",
"serde_json",
"tree-sitter-rust",
"util",
"workspace",
"zlog",
]
[[package]]
name = "symphonia"
version = "0.5.5"
@ -21343,7 +21316,6 @@ dependencies = [
"snippets_ui",
"supermaven",
"svg_preview",
"sweep_ai",
"sysinfo 0.37.2",
"system_specs",
"tab_switcher",
@ -21754,6 +21726,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"brotli",
"chrono",
"client",
"clock",
@ -21864,7 +21837,6 @@ dependencies = [
"shellexpand 2.1.2",
"smol",
"soa-rs",
"sweep_ai",
"terminal_view",
"toml 0.8.23",
"util",

View file

@ -165,7 +165,6 @@ members = [
"crates/sum_tree",
"crates/supermaven",
"crates/supermaven_api",
"crates/sweep_ai",
"crates/codestral",
"crates/svg_preview",
"crates/system_specs",
@ -399,7 +398,6 @@ streaming_diff = { path = "crates/streaming_diff" }
sum_tree = { path = "crates/sum_tree" }
supermaven = { path = "crates/supermaven" }
supermaven_api = { path = "crates/supermaven_api" }
sweep_ai = { path = "crates/sweep_ai" }
codestral = { path = "crates/codestral" }
system_specs = { path = "crates/system_specs" }
tab_switcher = { path = "crates/tab_switcher" }

View file

@ -30,12 +30,12 @@ project.workspace = true
regex.workspace = true
settings.workspace = true
supermaven.workspace = true
sweep_ai.workspace = true
telemetry.workspace = true
ui.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta.workspace = true
zeta2.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }

View file

@ -28,7 +28,6 @@ use std::{
time::Duration,
};
use supermaven::{AccountStatus, Supermaven};
use sweep_ai::SweepFeatureFlag;
use ui::{
Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton,
IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*,
@ -39,6 +38,7 @@ use workspace::{
};
use zed_actions::OpenBrowser;
use zeta::RateCompletions;
use zeta2::SweepFeatureFlag;
actions!(
edit_prediction,

View file

@ -1,43 +0,0 @@
[package]
name = "sweep_ai"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
exclude = ["fixtures"]
[lints]
workspace = true
[lib]
path = "src/sweep_ai.rs"
doctest = false
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
brotli.workspace = true
client.workspace = true
collections.workspace = true
edit_prediction.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
language.workspace = true
project.workspace = true
release_channel.workspace = true
serde.workspace = true
serde_json.workspace = true
util.workspace = true
workspace.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
reqwest_client = { workspace = true, features = ["test-support"] }
tree-sitter-rust.workspace = true
workspace = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View file

@ -1 +0,0 @@
../../LICENSE-GPL

View file

@ -1,784 +0,0 @@
mod api;
use anyhow::{Context as _, Result};
use arrayvec::ArrayVec;
use client::telemetry;
use collections::HashMap;
use feature_flags::FeatureFlag;
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
use http_client::{AsyncBody, Method};
use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::{AppCommitSha, AppVersion};
use std::collections::{VecDeque, hash_map};
use std::fmt::{self, Display};
use std::mem;
use std::{
cmp,
fmt::Write,
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
use util::ResultExt;
use util::rel_path::RelPath;
use workspace::Workspace;
use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
const MAX_EVENT_COUNT: usize = 6;
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
/// Marker type for the feature flag that gates the Sweep edit prediction provider.
pub struct SweepFeatureFlag;
impl FeatureFlag for SweepFeatureFlag {
// Name under which this flag is identified.
const NAME: &str = "sweep-ai";
}
/// Newtype wrapper that lets the shared `SweepAi` entity be stored as an app global
/// (see `SweepAi::global` and `SweepAi::register`).
#[derive(Clone)]
struct SweepAiGlobal(Entity<SweepAi>);
impl Global for SweepAiGlobal {}
/// A prediction produced for a buffer, expressed as anchored edits relative to
/// `snapshot`.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    pub id: EditPredictionId,
    /// Path of the file the prediction was requested for.
    pub path: Arc<Path>,
    /// Replacement ranges and the text to insert for each of them.
    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Buffer snapshot the edits are expressed against.
    pub snapshot: BufferSnapshot,
    /// Preview of the buffer with the edits applied.
    pub edit_preview: EditPreview,
}

impl EditPrediction {
    /// Re-expresses the stored edits against `new_snapshot` by delegating to
    /// `edit_prediction::interpolate_edits`; yields `None` when that helper does.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
        let Self {
            snapshot, edits, ..
        } = self;
        edit_prediction::interpolate_edits(snapshot, new_snapshot, edits)
    }
}

impl fmt::Debug for EditPrediction {
    /// Prints only the path and the edits; the remaining fields are elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { path, edits, .. } = self;
        let mut debug = f.debug_struct("EditPrediction");
        debug.field("path", path);
        debug.field("edits", edits);
        debug.finish_non_exhaustive()
    }
}
/// Opaque string identifier of an edit prediction.
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(String);

impl Display for EditPredictionId {
    /// Writes the raw identifier string, ignoring any width or padding flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        f.write_str(raw)
    }
}
/// State shared by all Sweep edit prediction providers.
pub struct SweepAi {
/// Per-project state, keyed by the project's entity id.
projects: HashMap<EntityId, SweepAiProject>,
/// Client description string sent with every request.
debug_info: Arc<str>,
/// API token read from the `SWEEP_AI_TOKEN` environment variable; `None` when it is unset.
api_token: Option<String>,
}
/// Edit history and buffer bookkeeping tracked for a single project.
struct SweepAiProject {
/// Most recent buffer change events, oldest first; capped at `MAX_EVENT_COUNT`.
events: VecDeque<Event>,
/// Buffers being observed for edits, keyed by the buffer's entity id.
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
impl SweepAi {
pub fn global(cx: &mut App) -> Option<Entity<Self>> {
cx.try_global::<SweepAiGlobal>()
.map(|global| global.0.clone())
}
/// Returns the global `SweepAi` entity, creating and installing it on first use.
pub fn register(cx: &mut App) -> Entity<Self> {
    if let Some(existing) = Self::global(cx) {
        return existing;
    }

    let created = cx.new(|cx| Self::new(cx));
    cx.set_global(SweepAiGlobal(created.clone()));
    created
}
/// Drops the recorded edit events of every tracked project.
pub fn clear_history(&mut self) {
    self.projects
        .values_mut()
        .for_each(|project_state| project_state.events.clear());
}
/// Builds an empty `SweepAi`, reading the API token from the `SWEEP_AI_TOKEN`
/// environment variable and precomputing the client description string.
pub fn new(cx: &mut Context<Self>) -> Self {
    let api_token = std::env::var("SWEEP_AI_TOKEN").ok();

    let version = AppVersion::global(cx);
    let sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full());
    let os = telemetry::os_name();
    let debug_info = format!("Zed v{version} ({sha}) - OS: {os} - Zed v{version}");

    Self {
        api_token,
        projects: HashMap::default(),
        debug_info: debug_info.into(),
    }
}
/// Returns the state tracked for `project`, creating it on first access.
///
/// On creation, a release observer is installed so the state is removed again
/// once the project entity is released.
fn get_or_init_sweep_ai_project(
    &mut self,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &mut SweepAiProject {
    let project_id = project.entity_id();
    let vacant = match self.projects.entry(project_id) {
        hash_map::Entry::Occupied(occupied) => return occupied.into_mut(),
        hash_map::Entry::Vacant(vacant) => vacant,
    };

    cx.observe_release(project, move |this, _, _cx| {
        this.projects.remove(&project_id);
    })
    .detach();

    vacant.insert(SweepAiProject {
        events: VecDeque::with_capacity(MAX_EVENT_COUNT),
        registered_buffers: HashMap::default(),
    })
}
/// Starts tracking edits of `buffer` as part of `project`; does nothing new if the
/// buffer is already registered.
pub fn register_buffer(
    &mut self,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_sweep_ai_project(project, cx);
    Self::register_buffer_impl(project_state, buffer, project, cx);
}
/// Returns the registration record for `buffer`, creating it if necessary.
///
/// A new record stores the buffer's current snapshot and two subscriptions: one
/// that reports changes whenever the buffer emits `Edited`, and one that removes
/// the record when the buffer entity is released.
fn register_buffer_impl<'a>(
    sweep_ai_project: &'a mut SweepAiProject,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &'a mut RegisteredBuffer {
    let buffer_id = buffer.entity_id();
    let vacant = match sweep_ai_project.registered_buffers.entry(buffer_id) {
        hash_map::Entry::Occupied(occupied) => return occupied.into_mut(),
        hash_map::Entry::Vacant(vacant) => vacant,
    };

    let snapshot = buffer.read(cx).snapshot();
    let project_entity_id = project.entity_id();

    let on_edit = cx.subscribe(buffer, {
        // Hold the project weakly so the subscription does not keep it alive.
        let project = project.downgrade();
        move |this, buffer, event, cx| {
            if !matches!(event, language::BufferEvent::Edited) {
                return;
            }
            let Some(project) = project.upgrade() else {
                return;
            };
            this.report_changes_for_buffer(&buffer, &project, cx);
        }
    });
    let on_release = cx.observe_release(buffer, move |this, _buffer, _cx| {
        if let Some(project_state) = this.projects.get_mut(&project_entity_id) {
            project_state.registered_buffers.remove(&buffer_id);
        }
    });

    vacant.insert(RegisteredBuffer {
        snapshot,
        _subscriptions: [on_edit, on_release],
    })
}
/// Requests an edit prediction for `active_buffer` at `position` from the Sweep API.
///
/// The request carries the full text of the active buffer, the edit events recorded
/// for `project`, and the first 30 lines of up to three other open buffers taken from
/// `recent_buffers`. The request body is brotli-compressed JSON. The response is
/// diffed against the addressed text range, turned into anchored edits,
/// re-interpolated against the buffer's latest contents, and returned with a preview.
///
/// Resolves to `Ok(None)` when no API token is configured, when the response yields
/// no edits, or when the edits no longer apply to the current buffer contents.
/// Resolves to an error when the request cannot be built or sent, the server answers
/// with a non-success status, the response cannot be parsed, or the response
/// addresses a range outside the buffer.
pub fn request_completion(
    &mut self,
    project: &Entity<Project>,
    recent_buffers: impl Iterator<Item = ProjectPath>,
    active_buffer: &Entity<Buffer>,
    position: language::Anchor,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
    let snapshot = active_buffer.read(cx).snapshot();
    let debug_info = self.debug_info.clone();
    let Some(api_token) = self.api_token.clone() else {
        return Task::ready(Ok(None));
    };
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|file| file.full_path(cx))
        .unwrap_or_else(|| "untitled".into())
        .into();
    let project_file = project::File::from_dyn(snapshot.file());
    let repo_name = project_file
        .map(|file| file.worktree.read(cx).root_name_str())
        .unwrap_or("untitled")
        .into();
    let offset = position.to_offset(&snapshot);
    let project_state = self.get_or_init_sweep_ai_project(project, cx);
    let events = project_state.events.clone();
    let http_client = cx.http_client();

    // Snapshot up to three other open buffers so they can be read off the main thread.
    let recent_buffer_snapshots = recent_buffers
        .filter_map(|project_path| {
            let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
            if active_buffer == &buffer {
                None
            } else {
                Some(buffer.read(cx).snapshot())
            }
        })
        .take(3)
        .collect::<Vec<_>>();

    let result = cx.background_spawn({
        let full_path = full_path.clone();
        async move {
            let text = snapshot.text();

            let mut recent_changes = String::new();
            for event in events {
                writeln!(&mut recent_changes, "{event}")?;
            }

            let file_chunks = recent_buffer_snapshots
                .into_iter()
                .map(|snapshot| {
                    // Only the head of each recent file is sent as context.
                    let end_point = language::Point::new(30, 0).min(snapshot.max_point());
                    FileChunk {
                        content: snapshot
                            .text_for_range(language::Point::zero()..end_point)
                            .collect(),
                        file_path: snapshot
                            .file()
                            .map(|f| f.path().as_unix_str())
                            .unwrap_or("untitled")
                            .to_string(),
                        start_line: 0,
                        end_line: end_point.row as usize,
                        timestamp: snapshot.file().and_then(|file| {
                            Some(
                                file.disk_state()
                                    .mtime()?
                                    .to_seconds_and_nanos_for_persistence()?
                                    .0,
                            )
                        }),
                    }
                })
                .collect();

            let request_body = AutocompleteRequest {
                debug_info,
                repo_name,
                file_path: full_path,
                file_contents: text.clone(),
                original_file_contents: text,
                cursor_position: offset,
                recent_changes,
                changes_above_cursor: true,
                multiple_suggestions: false,
                branch: None,
                file_chunks,
                retrieval_chunks: vec![],
                recent_user_actions: vec![],
                // TODO
                privacy_mode_enabled: false,
            };

            // The compressor is moved into `to_writer` and dropped when it returns,
            // which finishes the brotli stream before `buf` is read below.
            let mut buf: Vec<u8> = Vec::new();
            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
            serde_json::to_writer(writer, &request_body)?;
            let body: AsyncBody = buf.into();

            let request = http_client::Request::builder()
                .uri(SWEEP_API_URL)
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", api_token))
                .header("Connection", "keep-alive")
                .header("Content-Encoding", "br")
                .method(Method::POST)
                .body(body)?;

            let mut response = http_client.send(request).await?;
            let mut body: Vec<u8> = Vec::new();
            response.body_mut().read_to_end(&mut body).await?;
            if !response.status().is_success() {
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    String::from_utf8_lossy(&body),
                );
            };

            let response: AutocompleteResponse = serde_json::from_slice(&body)?;

            // The offsets come from the server; reject ranges that do not fit the
            // snapshot instead of panicking while slicing the buffer text.
            anyhow::ensure!(
                response.start_index <= response.end_index
                    && response.end_index <= snapshot.len(),
                "Sweep response range {}..{} is out of bounds for a buffer of length {}",
                response.start_index,
                response.end_index,
                snapshot.len(),
            );

            let old_text = snapshot
                .text_for_range(response.start_index..response.end_index)
                .collect::<String>();
            let edits = text_diff(&old_text, &response.completion)
                .into_iter()
                .map(|(range, text)| {
                    (
                        snapshot.anchor_after(response.start_index + range.start)
                            ..snapshot.anchor_before(response.start_index + range.end),
                        text,
                    )
                })
                .collect::<Vec<_>>();
            anyhow::Ok((response.autocomplete_id, edits, snapshot))
        }
    });

    let buffer = active_buffer.clone();
    cx.spawn(async move |_, cx| {
        let (id, edits, old_snapshot) = result.await?;
        if edits.is_empty() {
            return anyhow::Ok(None);
        }

        // The buffer may have changed while the request was in flight; carry the
        // edits over to its current contents or give up if they no longer apply.
        let Some((edits, new_snapshot, preview_task)) =
            buffer.read_with(cx, |buffer, cx| {
                let new_snapshot = buffer.snapshot();
                let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                    edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
                        .into();
                let preview_task = buffer.preview_edits(edits.clone(), cx);
                Some((edits, new_snapshot, preview_task))
            })?
        else {
            return anyhow::Ok(None);
        };

        let prediction = EditPrediction {
            id: EditPredictionId(id),
            path: full_path,
            edits,
            snapshot: new_snapshot,
            edit_preview: preview_task.await,
        };
        anyhow::Ok(Some(prediction))
    })
}
/// Records the difference between the last stored snapshot of `buffer` and its
/// current contents as a `BufferChange` event for `project`.
///
/// A change that directly continues the most recent event (same buffer, consecutive
/// versions) and ends within `CHANGE_GROUPING_LINE_SPAN` rows of it is merged into
/// that event. Otherwise a new event is appended, evicting the oldest one once
/// `MAX_EVENT_COUNT` is reached.
fn report_changes_for_buffer(
    &mut self,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_sweep_ai_project(project, cx);
    let registered = Self::register_buffer_impl(project_state, buffer, project, cx);

    let new_snapshot = buffer.read(cx).snapshot();
    if new_snapshot.version == registered.snapshot.version {
        // Nothing changed since the last report.
        return;
    }
    let old_snapshot = mem::replace(&mut registered.snapshot, new_snapshot.clone());

    let end_edit_anchor = new_snapshot
        .anchored_edits_since::<Point>(&old_snapshot.version)
        .last()
        .map(|(_, range)| range.end);

    let events = &mut project_state.events;
    if let Some(Event::BufferChange {
        new_snapshot: previous_snapshot,
        end_edit_anchor: previous_end_anchor,
        ..
    }) = events.back_mut()
    {
        let continues_previous = old_snapshot.remote_id() == previous_snapshot.remote_id()
            && old_snapshot.version == previous_snapshot.version;

        // The anchors are only compared once both events are known to belong to
        // the same buffer.
        let should_coalesce = continues_previous
            && match (end_edit_anchor.as_ref(), previous_end_anchor.as_ref()) {
                (Some(current), Some(previous)) => {
                    let current_row = current.to_point(&new_snapshot).row;
                    let previous_row = previous.to_point(&new_snapshot).row;
                    current_row.abs_diff(previous_row) <= CHANGE_GROUPING_LINE_SPAN
                }
                _ => false,
            };

        if should_coalesce {
            *previous_end_anchor = end_edit_anchor;
            *previous_snapshot = new_snapshot;
            return;
        }
    }

    if events.len() >= MAX_EVENT_COUNT {
        events.pop_front();
    }
    events.push_back(Event::BufferChange {
        old_snapshot,
        new_snapshot,
        end_edit_anchor,
    });
}
}
/// Bookkeeping for a buffer whose edits are being tracked.
struct RegisteredBuffer {
/// Snapshot taken when the buffer was registered or when its changes were last reported.
snapshot: BufferSnapshot,
/// Keeps the edit subscription and the release observer alive for as long as this record exists.
_subscriptions: [gpui::Subscription; 2],
}
/// An entry in a project's recent edit history.
#[derive(Clone)]
pub enum Event {
/// The contents of a buffer changed from `old_snapshot` to `new_snapshot`.
BufferChange {
old_snapshot: BufferSnapshot,
new_snapshot: BufferSnapshot,
/// End of the last edit in the change, if any; used to decide whether a
/// following change is close enough to be merged into this event.
end_edit_anchor: Option<Anchor>,
},
}
impl Display for Event {
    /// Renders the event as `File: <path>:` followed by a unified diff of the old and
    /// new buffer text. Nothing is written when the diff is empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Event::BufferChange {
            old_snapshot,
            new_snapshot,
            ..
        } = self;

        // Buffers without a backing file are reported under a placeholder path.
        let untitled = RelPath::unix("untitled").unwrap();
        let old_path = old_snapshot
            .file()
            .map_or(untitled, |file| file.path().as_ref());
        let new_path = new_snapshot
            .file()
            .map_or(untitled, |file| file.path().as_ref());

        if old_path != new_path {
            // TODO confirm how to do this for sweep
            // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
        }

        let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
        if diff.is_empty() {
            return Ok(());
        }
        write!(
            f,
            "File: {}:\n{}\n",
            new_path.display(util::paths::PathStyle::Posix),
            diff
        )
    }
}
/// The prediction currently held by a provider, tagged with the buffer it was
/// generated for.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    buffer_id: EntityId,
    completion: EditPrediction,
}

impl CurrentEditPrediction {
    /// Decides whether `self` (a newly received prediction) should replace
    /// `old_completion`, judged against the buffer's current `snapshot`.
    ///
    /// The new prediction wins when it targets a different buffer or when the old
    /// one no longer applies. It loses when it no longer applies itself. When both
    /// consist of exactly one edit, the new one wins only if it covers the same range
    /// and its text starts with the old text. In all other cases the new one wins.
    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
        if self.buffer_id != old_completion.buffer_id {
            return true;
        }

        let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
            return true;
        };
        let Some(new_edits) = self.completion.interpolate(snapshot) else {
            return false;
        };

        match (old_edits.as_slice(), new_edits.as_slice()) {
            ([(old_range, old_text)], [(new_range, new_text)]) => {
                new_range == old_range && new_text.starts_with(old_text.as_ref())
            }
            _ => true,
        }
    }
}
/// An in-flight prediction request.
struct PendingCompletion {
/// Sequence number assigned when the request was started.
id: usize,
/// Handle of the spawned request task, stored alongside its id.
_task: Task<()>,
}
/// Edit prediction provider that serves predictions obtained through `SweepAi`.
pub struct SweepAiEditPredictionProvider {
/// Workspace used to look up recently navigated buffers for request context.
workspace: WeakEntity<Workspace>,
/// Shared Sweep state that performs the requests.
sweep_ai: Entity<SweepAi>,
/// In-flight requests; at most two are kept at a time.
pending_completions: ArrayVec<PendingCompletion, 2>,
/// Id to assign to the next request.
next_pending_completion_id: usize,
/// Prediction currently offered to the editor, if any.
current_completion: Option<CurrentEditPrediction>,
/// Time the last request was issued; used for throttling.
last_request_timestamp: Instant,
/// Project the predictions are requested for.
project: Entity<Project>,
}
impl SweepAiEditPredictionProvider {
    /// Minimum time that must pass after the previous request before the next one
    /// is issued.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);

    /// Creates a provider with no pending requests and no current prediction.
    pub fn new(
        sweep_ai: Entity<SweepAi>,
        workspace: WeakEntity<Workspace>,
        project: Entity<Project>,
    ) -> Self {
        let last_request_timestamp = Instant::now();
        Self {
            workspace,
            sweep_ai,
            project,
            pending_completions: ArrayVec::new(),
            next_pending_completion_id: 0,
            current_completion: None,
            last_request_timestamp,
        }
    }
}
impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
/// Provider identifier.
// NOTE(review): this is "zed-predict" rather than a Sweep-specific name — presumably
// copied from the Zed provider; confirm this is intended.
fn name() -> &'static str {
"zed-predict"
}
/// Human-readable provider name.
// NOTE(review): the text does not mention Sweep — confirm this is intended.
fn display_name() -> &'static str {
"Zed's Edit Predictions"
}
/// Always returns `true`.
fn show_completions_in_menu() -> bool {
true
}
/// Always returns `true`.
fn show_tab_accept_marker() -> bool {
true
}
/// The provider is enabled only when a Sweep API token is configured; the buffer
/// and cursor position are not consulted.
fn is_enabled(
&self,
_buffer: &Entity<Buffer>,
_cursor_position: language::Anchor,
cx: &App,
) -> bool {
self.sweep_ai.read(cx).api_token.is_some()
}
/// Returns `true` while at least one prediction request is pending.
fn is_refreshing(&self) -> bool {
!self.pending_completions.is_empty()
}
fn refresh(
&mut self,
buffer: Entity<Buffer>,
position: language::Anchor,
_debounce: bool,
cx: &mut Context<Self>,
) {
if let Some(current_completion) = self.current_completion.as_ref() {
let snapshot = buffer.read(cx).snapshot();
if current_completion
.completion
.interpolate(&snapshot)
.is_some()
{
return;
}
}
let pending_completion_id = self.next_pending_completion_id;
self.next_pending_completion_id += 1;
let last_request_timestamp = self.last_request_timestamp;
let project = self.project.clone();
let workspace = self.workspace.clone();
let task = cx.spawn(async move |this, cx| {
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
.checked_duration_since(Instant::now())
{
cx.background_executor().timer(timeout).await;
}
let completion_request = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.sweep_ai.update(cx, |sweep_ai, cx| {
let Some(recent_buffers) = workspace
.read_with(cx, |workspace, cx| {
workspace.recent_navigation_history_iter(cx)
})
.log_err()
else {
return Task::ready(Ok(None));
};
sweep_ai.request_completion(
&project,
recent_buffers.map(move |(project_path, _)| project_path),
&buffer,
position,
cx,
)
})
});
let completion = match completion_request {
Ok(completion_request) => {
let completion_request = completion_request.await;
completion_request.map(|c| {
c.map(|completion| CurrentEditPrediction {
buffer_id: buffer.entity_id(),
completion,
})
})
}
Err(error) => Err(error),
};
let Some(new_completion) = completion
.context("edit prediction failed")
.log_err()
.flatten()
else {
this.update(cx, |this, cx| {
if this.pending_completions[0].id == pending_completion_id {
this.pending_completions.remove(0);
} else {
this.pending_completions.clear();
}
cx.notify();
})
.ok();
return;
};
this.update(cx, |this, cx| {
if this.pending_completions[0].id == pending_completion_id {
this.pending_completions.remove(0);
} else {
this.pending_completions.clear();
}
if let Some(old_completion) = this.current_completion.as_ref() {
let snapshot = buffer.read(cx).snapshot();
if new_completion.should_replace_completion(old_completion, &snapshot) {
this.current_completion = Some(new_completion);
}
} else {
this.current_completion = Some(new_completion);
}
cx.notify();
})
.ok();
});
// We always maintain at most two pending completions. When we already
// have two, we replace the newest one.
if self.pending_completions.len() <= 1 {
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
_task: task,
});
} else if self.pending_completions.len() == 2 {
self.pending_completions.pop();
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
_task: task,
});
}
}
/// No-op: this provider offers a single prediction at a time, so every argument
/// is ignored.
fn cycle(
&mut self,
_buffer: Entity<Buffer>,
_cursor_position: language::Anchor,
_direction: edit_prediction::Direction,
_cx: &mut Context<Self>,
) {
// Right now we don't support cycling.
}
/// Drops all pending requests; the current prediction is left in place.
fn accept(&mut self, _cx: &mut Context<Self>) {
self.pending_completions.clear();
}
/// Drops all pending requests and forgets the current prediction.
fn discard(&mut self, _cx: &mut Context<Self>) {
self.pending_completions.clear();
self.current_completion.take();
}
fn suggest(
&mut self,
buffer: &Entity<Buffer>,
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) -> Option<edit_prediction::EditPrediction> {
let CurrentEditPrediction {
buffer_id,
completion,
..
} = self.current_completion.as_mut()?;
// Invalidate previous completion if it was generated for a different buffer.
if *buffer_id != buffer.entity_id() {
self.current_completion.take();
return None;
}
let buffer = buffer.read(cx);
let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
self.current_completion.take();
return None;
};
let cursor_row = cursor_position.to_point(buffer).row;
let (closest_edit_ix, (closest_edit_range, _)) =
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
cmp::min(distance_from_start, distance_from_end)
})?;
let mut edit_start_ix = closest_edit_ix;
for (range, _) in edits[..edit_start_ix].iter().rev() {
let distance_from_closest_edit =
closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
if distance_from_closest_edit <= 1 {
edit_start_ix -= 1;
} else {
break;
}
}
let mut edit_end_ix = closest_edit_ix + 1;
for (range, _) in &edits[edit_end_ix..] {
let distance_from_closest_edit =
range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
if distance_from_closest_edit <= 1 {
edit_end_ix += 1;
} else {
break;
}
}
Some(edit_prediction::EditPrediction::Local {
id: Some(completion.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(completion.edit_preview.clone()),
})
}
}

View file

@ -133,7 +133,6 @@ snippet_provider.workspace = true
snippets_ui.workspace = true
supermaven.workspace = true
svg_preview.workspace = true
sweep_ai.workspace = true
sysinfo.workspace = true
tab_switcher.workspace = true
task.workspace = true

View file

@ -10,9 +10,9 @@ use language_models::MistralLanguageModelProvider;
use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenCompletionProvider};
use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag};
use ui::Window;
use zeta::ZetaEditPredictionProvider;
use zeta2::SweepFeatureFlag;
use zeta2::Zeta2FeatureFlag;
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
@ -203,55 +203,41 @@ fn assign_edit_prediction_provider(
let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
EditPredictionProvider::Experimental(name) => {
if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<SweepFeatureFlag>()
{
if let Some(project) = editor.project()
&& let Some(workspace) = editor.workspace()
{
let sweep_ai = sweep_ai::SweepAi::register(cx);
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
sweep_ai.update(cx, |sweep_ai, cx| {
sweep_ai.register_buffer(buffer, project, cx);
});
}
let provider = cx.new(|_| {
sweep_ai::SweepAiEditPredictionProvider::new(
sweep_ai,
workspace.downgrade(),
project.clone(),
)
});
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
} else {
editor.set_edit_prediction_provider::<SweepAiEditPredictionProvider>(
None, window, cx,
);
}
}
EditPredictionProvider::Zed => {
if user_store.read(cx).current_user().is_some() {
value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
if let Some(project) = editor.project() {
let mut worktree = None;
if let Some(buffer) = &singleton_buffer
&& let Some(file) = buffer.read(cx).file()
{
let id = file.worktree_id(cx);
if let Some(inner_worktree) = editor
.project()
.and_then(|project| project.read(cx).worktree_for_id(id, cx))
{
worktree = Some(inner_worktree);
}
worktree = project.read(cx).worktree_for_id(id, cx);
}
if let Some(project) = editor.project() {
if let EditPredictionProvider::Experimental(name) = value
&& name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<SweepFeatureFlag>()
{
let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
zeta2::ZetaEditPredictionProvider::new(
project.clone(),
&client,
&user_store,
cx,
)
});
if let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
zeta2.update(cx, |zeta, cx| {
zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep);
zeta.register_buffer(buffer, project, cx);
});
}
editor.set_edit_prediction_provider(Some(provider), window, cx);
} else if user_store.read(cx).current_user().is_some() {
if cx.has_flag::<Zeta2FeatureFlag>() {
let zeta = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
@ -268,6 +254,9 @@ fn assign_edit_prediction_provider(
&& buffer.read(cx).file().is_some()
{
zeta.update(cx, |zeta, cx| {
zeta.set_edit_prediction_model(
zeta2::ZetaEditPredictionModel::ZedCloud,
);
zeta.register_buffer(buffer, project, cx);
});
}

View file

@ -17,6 +17,7 @@ eval-support = []
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
brotli.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true

View file

@ -12,7 +12,7 @@ use language::ToPoint as _;
use project::Project;
use util::ResultExt as _;
use crate::{BufferEditPrediction, Zeta};
use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
@ -85,9 +85,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
&self,
_buffer: &Entity<language::Buffer>,
_cursor_position: language::Anchor,
_cx: &App,
cx: &App,
) -> bool {
true
let zeta = self.zeta.read(cx);
if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
zeta.sweep_api_token.is_some()
} else {
true
}
}
fn is_refreshing(&self) -> bool {

View file

@ -1,6 +1,8 @@
use std::fmt;
use std::{path::Path, sync::Arc};
use serde::{Deserialize, Serialize};
use util::rel_path::RelPath;
#[derive(Debug, Clone, Serialize)]
pub struct AutocompleteRequest {
@ -88,3 +90,49 @@ pub struct AdditionalCompletion {
pub logprobs: Option<serde_json::Value>,
pub finish_reason: Option<String>,
}
/// Writes `event` to `f` as `File: <path>:` followed by a unified diff of the old
/// and new buffer text. Nothing is written when the diff is empty.
pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result {
    let crate::Event::BufferChange {
        old_snapshot,
        new_snapshot,
        ..
    } = event;

    // Buffers without a backing file are reported under a placeholder path.
    let untitled = RelPath::unix("untitled").unwrap();
    let old_path = old_snapshot
        .file()
        .map_or(untitled, |file| file.path().as_ref());
    let new_path = new_snapshot
        .file()
        .map_or(untitled, |file| file.path().as_ref());

    if old_path != new_path {
        // TODO confirm how to do this for sweep
        // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
    }

    let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
    if diff.is_empty() {
        return Ok(());
    }
    write!(
        f,
        "File: {}:\n{}\n",
        new_path.display(util::paths::PathStyle::Posix),
        diff
    )
}
/// Builds the client description string from the app version, the commit SHA
/// (`unknown` when unavailable) and the OS name.
pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
    let version = release_channel::AppVersion::global(cx);
    let sha = release_channel::AppCommitSha::try_global(cx)
        .map_or("unknown".to_string(), |sha| sha.full());
    let os = client::telemetry::os_name();

    let description = format!("Zed v{version} ({sha}) - OS: {os} - Zed v{version}");
    description.into()
}

View file

@ -22,30 +22,31 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::Project;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
use std::env;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use std::{env, mem};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, TryFutureExt};
use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
pub mod assemble_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
mod sweep_ai;
pub mod udiff;
mod xml_edits;
@ -55,8 +56,15 @@ pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
const EVENT_COUNT_MAX_SWEEP: usize = 6;
const EVENT_COUNT_MAX_ZETA: usize = 16;
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Marker type for the feature flag that gates the Sweep edit prediction model.
pub struct SweepFeatureFlag;
impl FeatureFlag for SweepFeatureFlag {
// Name under which this flag is identified.
const NAME: &str = "sweep-ai";
}
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
max_bytes: 512,
min_bytes: 128,
@ -143,6 +151,15 @@ pub struct Zeta {
debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: ZetaEditPredictionModel,
sweep_api_token: Option<String>,
sweep_ai_debug_info: Arc<str>,
}
/// Selects which backend `Zeta` uses to produce edit predictions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    /// The model `Zeta` is initialized with.
    ZedCloud,
    /// The Sweep model; selecting it reads the API token from the `SWEEP_AI_TOKEN`
    /// environment variable.
    Sweep,
}
#[derive(Debug, Clone, PartialEq)]
@ -219,12 +236,14 @@ pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ZetaProject {
syntax_index: Option<Entity<SyntaxIndex>>,
events: VecDeque<Event>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
refresh_context_debounce_task: Option<Task<Option<()>>>,
refresh_context_timestamp: Option<Instant>,
_subscription: gpui::Subscription,
}
#[derive(Debug, Clone)]
@ -287,6 +306,7 @@ pub enum Event {
BufferChange {
old_snapshot: BufferSnapshot,
new_snapshot: BufferSnapshot,
end_edit_anchor: Option<Anchor>,
timestamp: Instant,
},
}
@ -381,9 +401,21 @@ impl Zeta {
debug_tx: None,
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
sweep_api_token: None,
sweep_ai_debug_info: sweep_ai::debug_info(cx),
}
}
/// Switches the model used for edit predictions.
///
/// Selecting `Sweep` also (re-)reads the API token from the `SWEEP_AI_TOKEN`
/// environment variable. When the variable is missing, the error is logged and the
/// token is cleared.
pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
    if matches!(model, ZetaEditPredictionModel::Sweep) {
        let token = std::env::var("SWEEP_AI_TOKEN")
            .context("No SWEEP_AI_TOKEN environment variable set");
        self.sweep_api_token = token.log_err();
    }
    self.edit_prediction_model = model;
}
#[cfg(feature = "eval-support")]
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
self.eval_cache = Some(cache);
@ -443,7 +475,7 @@ impl Zeta {
self.user_store.read(cx).edit_prediction_usage()
}
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
self.get_or_init_zeta_project(project, cx);
}
@ -460,7 +492,7 @@ impl Zeta {
fn get_or_init_zeta_project(
&mut self,
project: &Entity<Project>,
cx: &mut App,
cx: &mut Context<Self>,
) -> &mut ZetaProject {
self.projects
.entry(project.entity_id())
@ -473,12 +505,31 @@ impl Zeta {
None
},
events: VecDeque::new(),
recent_paths: VecDeque::new(),
registered_buffers: HashMap::default(),
current_prediction: None,
context: None,
refresh_context_task: None,
refresh_context_debounce_task: None,
refresh_context_timestamp: None,
_subscription: cx.subscribe(&project, |this, project, event, cx| {
// TODO [zeta2] init with recent paths
if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
let path = project.read(cx).path_for_entry(*active_entry_id, cx);
if let Some(path) = path {
if let Some(ix) = zeta_project
.recent_paths
.iter()
.position(|probe| probe == &path)
{
zeta_project.recent_paths.remove(ix);
}
zeta_project.recent_paths.push_front(path);
}
}
}
}),
})
}
@ -525,66 +576,64 @@ impl Zeta {
buffer: &Entity<Buffer>,
project: &Entity<Project>,
cx: &mut Context<Self>,
) -> BufferSnapshot {
let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
let zeta_project = self.get_or_init_zeta_project(project, cx);
let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
) {
let event_count_max = match self.edit_prediction_model {
ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
};
let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
let new_snapshot = buffer.read(cx).snapshot();
if new_snapshot.version != registered_buffer.snapshot.version {
let old_snapshot =
std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
Self::push_event(
zeta_project,
buffer_change_grouping_interval,
Event::BufferChange {
old_snapshot,
new_snapshot: new_snapshot.clone(),
timestamp: Instant::now(),
},
);
if new_snapshot.version == registered_buffer.snapshot.version {
return;
}
new_snapshot
}
let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
let end_edit_anchor = new_snapshot
.anchored_edits_since::<Point>(&old_snapshot.version)
.last()
.map(|(_, range)| range.end);
let events = &mut sweep_ai_project.events;
fn push_event(
zeta_project: &mut ZetaProject,
buffer_change_grouping_interval: Duration,
event: Event,
) {
let events = &mut zeta_project.events;
if buffer_change_grouping_interval > Duration::ZERO
&& let Some(Event::BufferChange {
new_snapshot: last_new_snapshot,
timestamp: last_timestamp,
..
}) = events.back_mut()
if let Some(Event::BufferChange {
new_snapshot: last_new_snapshot,
end_edit_anchor: last_end_edit_anchor,
..
}) = events.back_mut()
{
// Coalesce edits for the same buffer when they happen one after the other.
let Event::BufferChange {
old_snapshot,
new_snapshot,
timestamp,
} = &event;
let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
== last_new_snapshot.remote_id()
&& old_snapshot.version == last_new_snapshot.version;
if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
&& old_snapshot.remote_id() == last_new_snapshot.remote_id()
&& old_snapshot.version == last_new_snapshot.version
{
*last_new_snapshot = new_snapshot.clone();
*last_timestamp = *timestamp;
let should_coalesce = is_next_snapshot_of_same_buffer
&& end_edit_anchor
.as_ref()
.zip(last_end_edit_anchor.as_ref())
.is_some_and(|(a, b)| {
let a = a.to_point(&new_snapshot);
let b = b.to_point(&new_snapshot);
a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
});
if should_coalesce {
*last_end_edit_anchor = end_edit_anchor;
*last_new_snapshot = new_snapshot;
return;
}
}
if events.len() >= MAX_EVENT_COUNT {
// These are halved instead of popping to improve prompt caching.
events.drain(..MAX_EVENT_COUNT / 2);
if events.len() >= event_count_max {
events.pop_front();
}
events.push_back(event);
events.push_back(Event::BufferChange {
old_snapshot,
new_snapshot,
end_edit_anchor,
timestamp: Instant::now(),
});
}
fn current_prediction_for_buffer(
@ -706,6 +755,203 @@ impl Zeta {
active_buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
match self.edit_prediction_model {
ZetaEditPredictionModel::ZedCloud => {
self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
}
ZetaEditPredictionModel::Sweep => {
self.request_prediction_with_sweep(project, active_buffer, position, cx)
}
}
}
/// Requests an edit prediction for `active_buffer` at `position` from the Sweep
/// autocomplete endpoint.
///
/// The request carries the full buffer text, the cursor offset, the project's recorded
/// buffer-change events and the leading lines of up to three other recently active open
/// buffers. The body is JSON, brotli-compressed, and sent with a bearer token.
///
/// The returned task resolves to:
/// - `Ok(None)` when no Sweep API token is stored, when the returned completion produces
///   no edits, or when the edits cannot be interpolated onto the buffer's latest snapshot;
/// - `Ok(Some(prediction))` otherwise;
/// - `Err(_)` when serializing, sending, reading or parsing fails, or when the endpoint
///   answers with a non-success status.
fn request_prediction_with_sweep(
    &mut self,
    project: &Entity<Project>,
    active_buffer: &Entity<Buffer>,
    position: language::Anchor,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
    let snapshot = active_buffer.read(cx).snapshot();
    let debug_info = self.sweep_ai_debug_info.clone();
    // Without a token there is nothing to send; report "no prediction" rather than an error.
    let Some(api_token) = self.sweep_api_token.clone() else {
        return Task::ready(Ok(None));
    };
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|file| file.full_path(cx))
        .unwrap_or_else(|| "untitled".into())
        .into();

    let project_file = project::File::from_dyn(snapshot.file());
    let repo_name = project_file
        .map(|file| file.worktree.read(cx).root_name_str())
        .unwrap_or("untitled")
        .into();
    let offset = position.to_offset(&snapshot);

    let project_state = self.get_or_init_zeta_project(project, cx);
    let events = project_state.events.clone();
    let recent_buffers = project_state.recent_paths.iter().cloned();
    let http_client = cx.http_client();

    // Snapshot up to three recently active buffers that are still open, skipping the
    // buffer the prediction is being requested for.
    let recent_buffer_snapshots = recent_buffers
        .filter_map(|project_path| {
            let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
            if active_buffer == &buffer {
                None
            } else {
                Some(buffer.read(cx).snapshot())
            }
        })
        .take(3)
        .collect::<Vec<_>>();

    let result = cx.background_spawn(async move {
        let text = snapshot.text();

        let mut recent_changes = String::new();
        for event in events {
            // Propagate a formatting failure through the task's `Result` instead of
            // panicking inside the spawned task.
            sweep_ai::write_event(event, &mut recent_changes)?;
        }

        let file_chunks = recent_buffer_snapshots
            .into_iter()
            .map(|snapshot| {
                // Send only the start of each file: everything before row 30, or the
                // whole buffer when it is shorter than that.
                let end_point = language::Point::new(30, 0).min(snapshot.max_point());
                sweep_ai::FileChunk {
                    content: snapshot
                        .text_for_range(language::Point::zero()..end_point)
                        .collect(),
                    file_path: snapshot
                        .file()
                        .map(|f| f.path().as_unix_str())
                        .unwrap_or("untitled")
                        .to_string(),
                    start_line: 0,
                    end_line: end_point.row as usize,
                    timestamp: snapshot.file().and_then(|file| {
                        Some(
                            file.disk_state()
                                .mtime()?
                                .to_seconds_and_nanos_for_persistence()?
                                .0,
                        )
                    }),
                }
            })
            .collect();

        let request_body = sweep_ai::AutocompleteRequest {
            debug_info,
            repo_name,
            // `full_path` and `recent_changes` are not used again, so they are moved
            // into the request rather than cloned.
            file_path: full_path,
            file_contents: text.clone(),
            original_file_contents: text,
            cursor_position: offset,
            recent_changes,
            changes_above_cursor: true,
            multiple_suggestions: false,
            branch: None,
            file_chunks,
            retrieval_chunks: vec![],
            recent_user_actions: vec![],
            // TODO
            privacy_mode_enabled: false,
        };

        let mut buf: Vec<u8> = Vec::new();
        // `writer` is moved into `to_writer`, so its borrow of `buf` has ended by the
        // time `buf` is turned into the request body.
        let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
        serde_json::to_writer(writer, &request_body)?;
        let body: AsyncBody = buf.into();

        const SWEEP_API_URL: &str =
            "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

        let request = http_client::Request::builder()
            .uri(SWEEP_API_URL)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_token))
            .header("Connection", "keep-alive")
            .header("Content-Encoding", "br")
            .method(Method::POST)
            .body(body)?;

        let mut response = http_client.send(request).await?;

        // Read the body before checking the status so it can be included in the error.
        let mut body: Vec<u8> = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        if !response.status().is_success() {
            anyhow::bail!(
                "Request failed with status: {:?}\nBody: {}",
                response.status(),
                String::from_utf8_lossy(&body),
            );
        };

        let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

        // NOTE(review): `start_index`/`end_index` come from the server response and are
        // used as offsets into `snapshot` without validation — confirm this is safe.
        let old_text = snapshot
            .text_for_range(response.start_index..response.end_index)
            .collect::<String>();
        // Diff the replaced range against the completion and shift each edit by
        // `start_index` so it is expressed relative to the whole buffer.
        let edits = language::text_diff(&old_text, &response.completion)
            .into_iter()
            .map(|(range, text)| {
                (
                    snapshot.anchor_after(response.start_index + range.start)
                        ..snapshot.anchor_before(response.start_index + range.end),
                    text,
                )
            })
            .collect::<Vec<_>>();

        anyhow::Ok((response.autocomplete_id, edits, snapshot))
    });

    let buffer = active_buffer.clone();

    cx.spawn(async move |_, cx| {
        let (id, edits, old_snapshot) = result.await?;

        if edits.is_empty() {
            return anyhow::Ok(None);
        }

        // The buffer may have changed while the request was in flight: re-express the
        // edits against its current snapshot, giving up if that is not possible.
        let Some((edits, new_snapshot, preview_task)) =
            buffer.read_with(cx, |buffer, cx| {
                let new_snapshot = buffer.snapshot();

                let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                    edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
                        .into();
                let preview_task = buffer.preview_edits(edits.clone(), cx);

                Some((edits, new_snapshot, preview_task))
            })?
        else {
            return anyhow::Ok(None);
        };

        let prediction = EditPrediction {
            id: EditPredictionId(id.into()),
            edits,
            snapshot: new_snapshot,
            edit_preview: preview_task.await,
            buffer,
        };

        anyhow::Ok(Some(prediction))
    })
}
fn request_prediction_with_zed_cloud(
&mut self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let project_state = self.projects.get(&project.entity_id());
@ -1653,7 +1899,7 @@ impl Zeta {
pub fn wait_for_initial_indexing(
&mut self,
project: &Entity<Project>,
cx: &mut App,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let zeta_project = self.get_or_init_zeta_project(project, cx);
if let Some(syntax_index) = &zeta_project.syntax_index {

View file

@ -49,7 +49,6 @@ settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
soa-rs = "0.8.1"
sweep_ai.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true

View file

@ -8,16 +8,15 @@ use anyhow::Result;
use collections::HashSet;
use gpui::{AsyncApp, Entity};
use project::Project;
use sweep_ai::SweepAi;
use util::ResultExt as _;
use zeta2::{Zeta, udiff::DiffLine};
use crate::{
EvaluateArguments, PredictionOptions, PredictionProvider,
EvaluateArguments, PredictionOptions,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta},
predict::{PredictionDetails, perform_predict, setup_zeta},
};
#[derive(Debug)]
@ -46,46 +45,35 @@ pub async fn run_evaluate(
let project = example.setup_project(&app_state, cx).await.unwrap();
let providers = (0..args.repetitions)
.map(|_| {
(
setup_zeta(&project, &app_state, cx).unwrap(),
if matches!(args.options.provider, PredictionProvider::Sweep) {
Some(setup_sweep(&project, cx).unwrap())
} else {
None
},
)
})
.map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
.collect::<Vec<_>>();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let tasks =
providers
.into_iter()
.enumerate()
.map(move |(repetition_ix, (zeta, sweep))| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
let example = example.clone();
let project = project.clone();
let options = options.clone();
let tasks = providers
.into_iter()
.enumerate()
.map(move |(repetition_ix, zeta)| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
let example = example.clone();
let project = project.clone();
let options = options.clone();
cx.spawn(async move |cx| {
let name = example.name.clone();
run_evaluate_one(
example,
repetition_ix,
project,
zeta,
sweep,
options,
!args.skip_prediction,
cx,
)
.await
.map_err(|err| (err, name, repetition_ix))
})
});
cx.spawn(async move |cx| {
let name = example.name.clone();
run_evaluate_one(
example,
repetition_ix,
project,
zeta,
options,
!args.skip_prediction,
cx,
)
.await
.map_err(|err| (err, name, repetition_ix))
})
});
futures::future::join_all(tasks).await
})
});
@ -177,7 +165,6 @@ pub async fn run_evaluate_one(
repetition_ix: Option<u16>,
project: Entity<Project>,
zeta: Entity<Zeta>,
sweep: Option<Entity<SweepAi>>,
prediction_options: PredictionOptions,
predict: bool,
cx: &mut AsyncApp,
@ -186,7 +173,6 @@ pub async fn run_evaluate_one(
example.clone(),
project,
zeta,
sweep,
repetition_ix,
prediction_options,
cx,

View file

@ -191,7 +191,7 @@ pub struct EvaluateArguments {
skip_prediction: bool,
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
#[default]
Zeta2,

View file

@ -21,7 +21,6 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use sweep_ai::SweepAi;
use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
pub async fn run_predict(
@ -31,14 +30,9 @@ pub async fn run_predict(
) {
let example = NamedExample::load(args.example_path).unwrap();
let project = example.setup_project(app_state, cx).await.unwrap();
let zeta = setup_zeta(&project, app_state, cx).unwrap();
let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
Some(setup_sweep(&project, cx).unwrap())
} else {
None
};
let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
let result = perform_predict(example, project, zeta, None, args.options, cx)
.await
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
@ -47,6 +41,7 @@ pub async fn run_predict(
}
pub fn setup_zeta(
provider: PredictionProvider,
project: &Entity<Project>,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
@ -54,6 +49,14 @@ pub fn setup_zeta(
let zeta =
cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
zeta.update(cx, |zeta, _cx| {
let model = match provider {
PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
};
zeta.set_edit_prediction_model(model);
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
@ -71,31 +74,10 @@ pub fn setup_zeta(
anyhow::Ok(zeta)
}
/// Creates a `SweepAi` entity and arranges for buffers added to `project`'s buffer store
/// to be registered with it.
///
/// Returns the new entity, or an error if creating it, reading the project, or
/// subscribing to the buffer store fails.
pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
    let sweep = cx.new(|cx| SweepAi::new(cx))?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    cx.subscribe(&buffer_store, {
        let project = project.clone();
        let sweep = sweep.clone();
        // Only `BufferAdded` is handled; every other buffer-store event is ignored.
        move |_, event, cx| match event {
            BufferStoreEvent::BufferAdded(buffer) => {
                sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    // The subscription is detached rather than stored on the returned entity.
    .detach();

    anyhow::Ok(sweep)
}
pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
zeta: Entity<Zeta>,
sweep: Option<Entity<SweepAi>>,
repetition_ix: Option<u16>,
options: PredictionOptions,
cx: &mut AsyncApp,
@ -147,194 +129,152 @@ pub async fn perform_predict(
zeta.set_options(options);
})?;
let prediction = match options.provider {
crate::PredictionProvider::Zeta2 => {
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
let mut debug_task = gpui::Task::ready(Ok(()));
let debug_task = cx.background_spawn({
let result = result.clone();
async move {
let mut start_time = None;
let mut search_queries_generated_at = None;
let mut search_queries_executed_at = None;
while let Some(event) = debug_rx.next().await {
match event {
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
}
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
fs::write(
example_run_dir.join("search_queries.json"),
serde_json::to_string_pretty(&info.search_queries).unwrap(),
)?;
}
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
search_queries_executed_at = Some(info.timestamp);
}
zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
if options.provider == crate::PredictionProvider::Zeta2 {
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
{
let mut result = result.lock().unwrap();
result.prompt_len = prompt.chars().count();
for included_file in request.request.included_files {
let insertions =
vec![(request.request.cursor_point, CURSOR_MARKER)];
result.excerpts.extend(included_file.excerpts.iter().map(
|excerpt| {
ActualExcerpt {
path: included_file
.path
.components()
.skip(1)
.collect(),
text: String::from(excerpt.text.as_ref()),
}
},
));
write_codeblock(
&included_file.path,
included_file.excerpts.iter(),
if included_file.path == request.request.excerpt_path {
&insertions
} else {
&[]
},
included_file.max_row,
false,
&mut result.excerpts_text,
);
}
}
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response =
zeta2::text_from_response(response).unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(
example_run_dir.join("prediction_response.md"),
&response,
)?;
debug_task = cx.background_spawn({
let result = result.clone();
async move {
let mut start_time = None;
let mut search_queries_generated_at = None;
let mut search_queries_executed_at = None;
while let Some(event) = debug_rx.next().await {
match event {
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
}
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
fs::write(
example_run_dir.join("search_queries.json"),
serde_json::to_string_pretty(&info.search_queries).unwrap(),
)?;
}
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
search_queries_executed_at = Some(info.timestamp);
}
zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
{
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
result.prompt_len = prompt.chars().count();
if !options.use_expected_context {
result.planning_search_time = Some(
search_queries_generated_at.unwrap() - start_time.unwrap(),
);
result.running_search_time = Some(
search_queries_executed_at.unwrap()
- search_queries_generated_at.unwrap(),
for included_file in request.request.included_files {
let insertions =
vec![(request.request.cursor_point, CURSOR_MARKER)];
result.excerpts.extend(included_file.excerpts.iter().map(
|excerpt| ActualExcerpt {
path: included_file.path.components().skip(1).collect(),
text: String::from(excerpt.text.as_ref()),
},
));
write_codeblock(
&included_file.path,
included_file.excerpts.iter(),
if included_file.path == request.request.excerpt_path {
&insertions
} else {
&[]
},
included_file.max_row,
false,
&mut result.excerpts_text,
);
}
result.prediction_time =
prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
break;
}
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response = zeta2::text_from_response(response).unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
if !options.use_expected_context {
result.planning_search_time = Some(
search_queries_generated_at.unwrap() - start_time.unwrap(),
);
result.running_search_time = Some(
search_queries_executed_at.unwrap()
- search_queries_generated_at.unwrap(),
);
}
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
break;
}
}
anyhow::Ok(())
}
});
anyhow::Ok(())
}
});
if options.use_expected_context {
let context_excerpts_tasks = example
.example
.expected_context
.iter()
.flat_map(|section| {
section.alternatives[0].excerpts.iter().map(|excerpt| {
resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
})
if options.use_expected_context {
let context_excerpts_tasks = example
.example
.expected_context
.iter()
.flat_map(|section| {
section.alternatives[0].excerpts.iter().map(|excerpt| {
resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
})
.collect::<Vec<_>>();
let context_excerpts_vec =
futures::future::try_join_all(context_excerpts_tasks).await?;
})
.collect::<Vec<_>>();
let context_excerpts_vec =
futures::future::try_join_all(context_excerpts_tasks).await?;
let mut context_excerpts = HashMap::default();
for (buffer, mut excerpts) in context_excerpts_vec {
context_excerpts
.entry(buffer)
.or_insert(Vec::new())
.append(&mut excerpts);
}
zeta.update(cx, |zeta, _cx| {
zeta.set_context(project.clone(), context_excerpts)
})?;
} else {
zeta.update(cx, |zeta, cx| {
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?
.await?;
let mut context_excerpts = HashMap::default();
for (buffer, mut excerpts) in context_excerpts_vec {
context_excerpts
.entry(buffer)
.or_insert(Vec::new())
.append(&mut excerpts);
}
let prediction = zeta
.update(cx, |zeta, cx| {
zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
})?
.await?
.map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));
debug_task.await?;
prediction
}
crate::PredictionProvider::Sweep => sweep
.unwrap()
.update(cx, |sweep, cx| {
let mut recent_paths = Vec::new();
for path in zeta
.read(cx)
.history_for_project(&project)
.rev()
.filter_map(|event| event.project_path(cx))
{
if !recent_paths.contains(&path) {
recent_paths.push(path);
}
}
sweep.request_completion(
&project,
recent_paths.into_iter(),
&cursor_buffer,
cursor_anchor,
cx,
)
zeta.update(cx, |zeta, _cx| {
zeta.set_context(project.clone(), context_excerpts)
})?;
} else {
zeta.update(cx, |zeta, cx| {
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?
.await?
.map(
|sweep_ai::EditPrediction {
edits, snapshot, ..
}| { (cursor_buffer.clone(), snapshot, edits) },
),
};
.await?;
}
}
let prediction = zeta
.update(cx, |zeta, cx| {
zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
})?
.await?;
debug_task.await?;
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
.map(|(buffer, snapshot, edits)| {
let old_text = snapshot.text();
let new_text = buffer
.map(|prediction| {
let old_text = prediction.snapshot.text();
let new_text = prediction
.buffer
.update(cx, |buffer, cx| {
let branch = buffer.branch(cx);
branch.update(cx, |branch, cx| {
branch.edit(edits.iter().cloned(), None, cx);
branch.edit(prediction.edits.iter().cloned(), None, cx);
branch.text()
})
})