mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
zeta2: Merge Sweep and Zeta2 Providers (#43097)
Closes #ISSUE Release Notes: - N/A *or* Added/Fixed/Improved ... --------- Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
c70f2d16ad
commit
f2f40a5099
17 changed files with 554 additions and 1199 deletions
32
Cargo.lock
generated
32
Cargo.lock
generated
|
|
@ -5314,13 +5314,13 @@ dependencies = [
|
|||
"serde_json",
|
||||
"settings",
|
||||
"supermaven",
|
||||
"sweep_ai",
|
||||
"telemetry",
|
||||
"theme",
|
||||
"ui",
|
||||
"workspace",
|
||||
"zed_actions",
|
||||
"zeta",
|
||||
"zeta2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -16590,33 +16590,6 @@ dependencies = [
|
|||
"zeno",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sweep_ai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arrayvec",
|
||||
"brotli",
|
||||
"client",
|
||||
"collections",
|
||||
"edit_prediction",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"http_client",
|
||||
"indoc",
|
||||
"language",
|
||||
"project",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tree-sitter-rust",
|
||||
"util",
|
||||
"workspace",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "symphonia"
|
||||
version = "0.5.5"
|
||||
|
|
@ -21343,7 +21316,6 @@ dependencies = [
|
|||
"snippets_ui",
|
||||
"supermaven",
|
||||
"svg_preview",
|
||||
"sweep_ai",
|
||||
"sysinfo 0.37.2",
|
||||
"system_specs",
|
||||
"tab_switcher",
|
||||
|
|
@ -21754,6 +21726,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"arrayvec",
|
||||
"brotli",
|
||||
"chrono",
|
||||
"client",
|
||||
"clock",
|
||||
|
|
@ -21864,7 +21837,6 @@ dependencies = [
|
|||
"shellexpand 2.1.2",
|
||||
"smol",
|
||||
"soa-rs",
|
||||
"sweep_ai",
|
||||
"terminal_view",
|
||||
"toml 0.8.23",
|
||||
"util",
|
||||
|
|
|
|||
|
|
@ -165,7 +165,6 @@ members = [
|
|||
"crates/sum_tree",
|
||||
"crates/supermaven",
|
||||
"crates/supermaven_api",
|
||||
"crates/sweep_ai",
|
||||
"crates/codestral",
|
||||
"crates/svg_preview",
|
||||
"crates/system_specs",
|
||||
|
|
@ -399,7 +398,6 @@ streaming_diff = { path = "crates/streaming_diff" }
|
|||
sum_tree = { path = "crates/sum_tree" }
|
||||
supermaven = { path = "crates/supermaven" }
|
||||
supermaven_api = { path = "crates/supermaven_api" }
|
||||
sweep_ai = { path = "crates/sweep_ai" }
|
||||
codestral = { path = "crates/codestral" }
|
||||
system_specs = { path = "crates/system_specs" }
|
||||
tab_switcher = { path = "crates/tab_switcher" }
|
||||
|
|
|
|||
|
|
@ -30,12 +30,12 @@ project.workspace = true
|
|||
regex.workspace = true
|
||||
settings.workspace = true
|
||||
supermaven.workspace = true
|
||||
sweep_ai.workspace = true
|
||||
telemetry.workspace = true
|
||||
ui.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta.workspace = true
|
||||
zeta2.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
copilot = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
use supermaven::{AccountStatus, Supermaven};
|
||||
use sweep_ai::SweepFeatureFlag;
|
||||
use ui::{
|
||||
Clickable, ContextMenu, ContextMenuEntry, DocumentationEdge, DocumentationSide, IconButton,
|
||||
IconButtonShape, Indicator, PopoverMenu, PopoverMenuHandle, ProgressBar, Tooltip, prelude::*,
|
||||
|
|
@ -39,6 +38,7 @@ use workspace::{
|
|||
};
|
||||
use zed_actions::OpenBrowser;
|
||||
use zeta::RateCompletions;
|
||||
use zeta2::SweepFeatureFlag;
|
||||
|
||||
actions!(
|
||||
edit_prediction,
|
||||
|
|
|
|||
|
|
@ -1,43 +0,0 @@
|
|||
[package]
|
||||
name = "sweep_ai"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
exclude = ["fixtures"]
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/sweep_ai.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
brotli.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
language.workspace = true
|
||||
project.workspace = true
|
||||
release_channel.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
http_client = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
reqwest_client = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-rust.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
|
@ -1 +0,0 @@
|
|||
../../LICENSE-GPL
|
||||
|
|
@ -1,784 +0,0 @@
|
|||
mod api;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use arrayvec::ArrayVec;
|
||||
use client::telemetry;
|
||||
use collections::HashMap;
|
||||
use feature_flags::FeatureFlag;
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, EntityId, Global, Task, WeakEntity};
|
||||
use http_client::{AsyncBody, Method};
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, EditPreview, Point, ToOffset as _, ToPoint, text_diff,
|
||||
};
|
||||
use project::{Project, ProjectPath};
|
||||
use release_channel::{AppCommitSha, AppVersion};
|
||||
use std::collections::{VecDeque, hash_map};
|
||||
use std::fmt::{self, Display};
|
||||
use std::mem;
|
||||
use std::{
|
||||
cmp,
|
||||
fmt::Write,
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::ResultExt;
|
||||
use util::rel_path::RelPath;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::api::{AutocompleteRequest, AutocompleteResponse, FileChunk};
|
||||
|
||||
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
|
||||
const MAX_EVENT_COUNT: usize = 6;
|
||||
|
||||
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
pub struct SweepFeatureFlag;
|
||||
|
||||
impl FeatureFlag for SweepFeatureFlag {
|
||||
const NAME: &str = "sweep-ai";
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SweepAiGlobal(Entity<SweepAi>);
|
||||
|
||||
impl Global for SweepAiGlobal {}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditPrediction {
|
||||
pub id: EditPredictionId,
|
||||
pub path: Arc<Path>,
|
||||
pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
pub snapshot: BufferSnapshot,
|
||||
pub edit_preview: EditPreview,
|
||||
}
|
||||
|
||||
impl EditPrediction {
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for EditPrediction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("EditPrediction")
|
||||
.field("path", &self.path)
|
||||
.field("edits", &self.edits)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct EditPredictionId(String);
|
||||
|
||||
impl Display for EditPredictionId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SweepAi {
|
||||
projects: HashMap<EntityId, SweepAiProject>,
|
||||
debug_info: Arc<str>,
|
||||
api_token: Option<String>,
|
||||
}
|
||||
|
||||
struct SweepAiProject {
|
||||
events: VecDeque<Event>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
}
|
||||
|
||||
impl SweepAi {
|
||||
pub fn global(cx: &mut App) -> Option<Entity<Self>> {
|
||||
cx.try_global::<SweepAiGlobal>()
|
||||
.map(|global| global.0.clone())
|
||||
}
|
||||
|
||||
pub fn register(cx: &mut App) -> Entity<Self> {
|
||||
Self::global(cx).unwrap_or_else(|| {
|
||||
let entity = cx.new(|cx| Self::new(cx));
|
||||
cx.set_global(SweepAiGlobal(entity.clone()));
|
||||
entity
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear_history(&mut self) {
|
||||
for sweep_ai_project in self.projects.values_mut() {
|
||||
sweep_ai_project.events.clear();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
|
||||
projects: HashMap::default(),
|
||||
debug_info: format!(
|
||||
"Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
|
||||
version = AppVersion::global(cx),
|
||||
sha = AppCommitSha::try_global(cx).map_or("unknown".to_string(), |sha| sha.full()),
|
||||
os = telemetry::os_name(),
|
||||
)
|
||||
.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_or_init_sweep_ai_project(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut SweepAiProject {
|
||||
let project_id = project.entity_id();
|
||||
match self.projects.entry(project_id) {
|
||||
hash_map::Entry::Occupied(entry) => entry.into_mut(),
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
cx.observe_release(project, move |this, _, _cx| {
|
||||
this.projects.remove(&project_id);
|
||||
})
|
||||
.detach();
|
||||
entry.insert(SweepAiProject {
|
||||
events: VecDeque::with_capacity(MAX_EVENT_COUNT),
|
||||
registered_buffers: HashMap::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_buffer(
|
||||
&mut self,
|
||||
buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
|
||||
Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
|
||||
}
|
||||
|
||||
fn register_buffer_impl<'a>(
|
||||
sweep_ai_project: &'a mut SweepAiProject,
|
||||
buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &'a mut RegisteredBuffer {
|
||||
let buffer_id = buffer.entity_id();
|
||||
match sweep_ai_project.registered_buffers.entry(buffer_id) {
|
||||
hash_map::Entry::Occupied(entry) => entry.into_mut(),
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let project_entity_id = project.entity_id();
|
||||
entry.insert(RegisteredBuffer {
|
||||
snapshot,
|
||||
_subscriptions: [
|
||||
cx.subscribe(buffer, {
|
||||
let project = project.downgrade();
|
||||
move |this, buffer, event, cx| {
|
||||
if let language::BufferEvent::Edited = event
|
||||
&& let Some(project) = project.upgrade()
|
||||
{
|
||||
this.report_changes_for_buffer(&buffer, &project, cx);
|
||||
}
|
||||
}
|
||||
}),
|
||||
cx.observe_release(buffer, move |this, _buffer, _cx| {
|
||||
let Some(sweep_ai_project) = this.projects.get_mut(&project_entity_id)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
sweep_ai_project.registered_buffers.remove(&buffer_id);
|
||||
}),
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_completion(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
recent_buffers: impl Iterator<Item = ProjectPath>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
let debug_info = self.debug_info.clone();
|
||||
let Some(api_token) = self.api_token.clone() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| "untitled".into())
|
||||
.into();
|
||||
|
||||
let project_file = project::File::from_dyn(snapshot.file());
|
||||
let repo_name = project_file
|
||||
.map(|file| file.worktree.read(cx).root_name_str())
|
||||
.unwrap_or("untitled")
|
||||
.into();
|
||||
let offset = position.to_offset(&snapshot);
|
||||
|
||||
let project_state = self.get_or_init_sweep_ai_project(project, cx);
|
||||
let events = project_state.events.clone();
|
||||
let http_client = cx.http_client();
|
||||
|
||||
let recent_buffer_snapshots = recent_buffers
|
||||
.filter_map(|project_path| {
|
||||
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if active_buffer == &buffer {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.read(cx).snapshot())
|
||||
}
|
||||
})
|
||||
.take(3)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = cx.background_spawn({
|
||||
let full_path = full_path.clone();
|
||||
async move {
|
||||
let text = snapshot.text();
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
|
||||
for event in events {
|
||||
writeln!(&mut recent_changes, "{event}")?;
|
||||
}
|
||||
|
||||
let file_chunks = recent_buffer_snapshots
|
||||
.into_iter()
|
||||
.map(|snapshot| {
|
||||
let end_point = language::Point::new(30, 0).min(snapshot.max_point());
|
||||
FileChunk {
|
||||
content: snapshot
|
||||
.text_for_range(language::Point::zero()..end_point)
|
||||
.collect(),
|
||||
file_path: snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_unix_str())
|
||||
.unwrap_or("untitled")
|
||||
.to_string(),
|
||||
start_line: 0,
|
||||
end_line: end_point.row as usize,
|
||||
timestamp: snapshot.file().and_then(|file| {
|
||||
Some(
|
||||
file.disk_state()
|
||||
.mtime()?
|
||||
.to_seconds_and_nanos_for_persistence()?
|
||||
.0,
|
||||
)
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
eprintln!("{recent_changes}");
|
||||
|
||||
let request_body = AutocompleteRequest {
|
||||
debug_info,
|
||||
repo_name,
|
||||
file_path: full_path.clone(),
|
||||
file_contents: text.clone(),
|
||||
original_file_contents: text,
|
||||
cursor_position: offset,
|
||||
recent_changes: recent_changes.clone(),
|
||||
changes_above_cursor: true,
|
||||
multiple_suggestions: false,
|
||||
branch: None,
|
||||
file_chunks,
|
||||
retrieval_chunks: vec![],
|
||||
recent_user_actions: vec![],
|
||||
// TODO
|
||||
privacy_mode_enabled: false,
|
||||
};
|
||||
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.uri(SWEEP_API_URL)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Content-Encoding", "br")
|
||||
.method(Method::POST)
|
||||
.body(body)?;
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
|
||||
let mut body: Vec<u8> = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&body),
|
||||
);
|
||||
};
|
||||
|
||||
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
|
||||
let old_text = snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
anyhow::Ok((response.autocomplete_id, edits, snapshot))
|
||||
}
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let (id, edits, old_snapshot) = result.await?;
|
||||
|
||||
if edits.is_empty() {
|
||||
return anyhow::Ok(None);
|
||||
}
|
||||
|
||||
let Some((edits, new_snapshot, preview_task)) =
|
||||
buffer.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
|
||||
let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
|
||||
edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
|
||||
.into();
|
||||
let preview_task = buffer.preview_edits(edits.clone(), cx);
|
||||
|
||||
Some((edits, new_snapshot, preview_task))
|
||||
})?
|
||||
else {
|
||||
return anyhow::Ok(None);
|
||||
};
|
||||
|
||||
let prediction = EditPrediction {
|
||||
id: EditPredictionId(id),
|
||||
path: full_path,
|
||||
edits,
|
||||
snapshot: new_snapshot,
|
||||
edit_preview: preview_task.await,
|
||||
};
|
||||
|
||||
anyhow::Ok(Some(prediction))
|
||||
})
|
||||
}
|
||||
|
||||
fn report_changes_for_buffer(
|
||||
&mut self,
|
||||
buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let sweep_ai_project = self.get_or_init_sweep_ai_project(project, cx);
|
||||
let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
|
||||
|
||||
let new_snapshot = buffer.read(cx).snapshot();
|
||||
if new_snapshot.version == registered_buffer.snapshot.version {
|
||||
return;
|
||||
}
|
||||
|
||||
let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
|
||||
let end_edit_anchor = new_snapshot
|
||||
.anchored_edits_since::<Point>(&old_snapshot.version)
|
||||
.last()
|
||||
.map(|(_, range)| range.end);
|
||||
let events = &mut sweep_ai_project.events;
|
||||
|
||||
if let Some(Event::BufferChange {
|
||||
new_snapshot: last_new_snapshot,
|
||||
end_edit_anchor: last_end_edit_anchor,
|
||||
..
|
||||
}) = events.back_mut()
|
||||
{
|
||||
let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
|
||||
== last_new_snapshot.remote_id()
|
||||
&& old_snapshot.version == last_new_snapshot.version;
|
||||
|
||||
let should_coalesce = is_next_snapshot_of_same_buffer
|
||||
&& end_edit_anchor
|
||||
.as_ref()
|
||||
.zip(last_end_edit_anchor.as_ref())
|
||||
.is_some_and(|(a, b)| {
|
||||
let a = a.to_point(&new_snapshot);
|
||||
let b = b.to_point(&new_snapshot);
|
||||
a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
|
||||
});
|
||||
|
||||
if should_coalesce {
|
||||
*last_end_edit_anchor = end_edit_anchor;
|
||||
*last_new_snapshot = new_snapshot;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if events.len() >= MAX_EVENT_COUNT {
|
||||
events.pop_front();
|
||||
}
|
||||
|
||||
events.push_back(Event::BufferChange {
|
||||
old_snapshot,
|
||||
new_snapshot,
|
||||
end_edit_anchor,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
struct RegisteredBuffer {
|
||||
snapshot: BufferSnapshot,
|
||||
_subscriptions: [gpui::Subscription; 2],
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Event {
|
||||
BufferChange {
|
||||
old_snapshot: BufferSnapshot,
|
||||
new_snapshot: BufferSnapshot,
|
||||
end_edit_anchor: Option<Anchor>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for Event {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Event::BufferChange {
|
||||
old_snapshot,
|
||||
new_snapshot,
|
||||
..
|
||||
} => {
|
||||
let old_path = old_snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_ref())
|
||||
.unwrap_or(RelPath::unix("untitled").unwrap());
|
||||
let new_path = new_snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_ref())
|
||||
.unwrap_or(RelPath::unix("untitled").unwrap());
|
||||
if old_path != new_path {
|
||||
// TODO confirm how to do this for sweep
|
||||
// writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
|
||||
}
|
||||
|
||||
let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
|
||||
if !diff.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"File: {}:\n{}\n",
|
||||
new_path.display(util::paths::PathStyle::Posix),
|
||||
diff
|
||||
)?
|
||||
}
|
||||
|
||||
fmt::Result::Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CurrentEditPrediction {
|
||||
buffer_id: EntityId,
|
||||
completion: EditPrediction,
|
||||
}
|
||||
|
||||
impl CurrentEditPrediction {
|
||||
fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
|
||||
if self.buffer_id != old_completion.buffer_id {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
|
||||
return true;
|
||||
};
|
||||
let Some(new_edits) = self.completion.interpolate(snapshot) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if old_edits.len() == 1 && new_edits.len() == 1 {
|
||||
let (old_range, old_text) = &old_edits[0];
|
||||
let (new_range, new_text) = &new_edits[0];
|
||||
new_range == old_range && new_text.starts_with(old_text.as_ref())
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PendingCompletion {
|
||||
id: usize,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
pub struct SweepAiEditPredictionProvider {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
sweep_ai: Entity<SweepAi>,
|
||||
pending_completions: ArrayVec<PendingCompletion, 2>,
|
||||
next_pending_completion_id: usize,
|
||||
current_completion: Option<CurrentEditPrediction>,
|
||||
last_request_timestamp: Instant,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl SweepAiEditPredictionProvider {
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
|
||||
pub fn new(
|
||||
sweep_ai: Entity<SweepAi>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sweep_ai,
|
||||
pending_completions: ArrayVec::new(),
|
||||
next_pending_completion_id: 0,
|
||||
current_completion: None,
|
||||
last_request_timestamp: Instant::now(),
|
||||
project,
|
||||
workspace,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl edit_prediction::EditPredictionProvider for SweepAiEditPredictionProvider {
|
||||
fn name() -> &'static str {
|
||||
"zed-predict"
|
||||
}
|
||||
|
||||
fn display_name() -> &'static str {
|
||||
"Zed's Edit Predictions"
|
||||
}
|
||||
|
||||
fn show_completions_in_menu() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn show_tab_accept_marker() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn is_enabled(
|
||||
&self,
|
||||
_buffer: &Entity<Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
self.sweep_ai.read(cx).api_token.is_some()
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
!self.pending_completions.is_empty()
|
||||
}
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
buffer: Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
_debounce: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(current_completion) = self.current_completion.as_ref() {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
if current_completion
|
||||
.completion
|
||||
.interpolate(&snapshot)
|
||||
.is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let pending_completion_id = self.next_pending_completion_id;
|
||||
self.next_pending_completion_id += 1;
|
||||
let last_request_timestamp = self.last_request_timestamp;
|
||||
|
||||
let project = self.project.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
|
||||
.checked_duration_since(Instant::now())
|
||||
{
|
||||
cx.background_executor().timer(timeout).await;
|
||||
}
|
||||
|
||||
let completion_request = this.update(cx, |this, cx| {
|
||||
this.last_request_timestamp = Instant::now();
|
||||
|
||||
this.sweep_ai.update(cx, |sweep_ai, cx| {
|
||||
let Some(recent_buffers) = workspace
|
||||
.read_with(cx, |workspace, cx| {
|
||||
workspace.recent_navigation_history_iter(cx)
|
||||
})
|
||||
.log_err()
|
||||
else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
sweep_ai.request_completion(
|
||||
&project,
|
||||
recent_buffers.map(move |(project_path, _)| project_path),
|
||||
&buffer,
|
||||
position,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
let completion = match completion_request {
|
||||
Ok(completion_request) => {
|
||||
let completion_request = completion_request.await;
|
||||
completion_request.map(|c| {
|
||||
c.map(|completion| CurrentEditPrediction {
|
||||
buffer_id: buffer.entity_id(),
|
||||
completion,
|
||||
})
|
||||
})
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
};
|
||||
|
||||
let Some(new_completion) = completion
|
||||
.context("edit prediction failed")
|
||||
.log_err()
|
||||
.flatten()
|
||||
else {
|
||||
this.update(cx, |this, cx| {
|
||||
if this.pending_completions[0].id == pending_completion_id {
|
||||
this.pending_completions.remove(0);
|
||||
} else {
|
||||
this.pending_completions.clear();
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if this.pending_completions[0].id == pending_completion_id {
|
||||
this.pending_completions.remove(0);
|
||||
} else {
|
||||
this.pending_completions.clear();
|
||||
}
|
||||
|
||||
if let Some(old_completion) = this.current_completion.as_ref() {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
if new_completion.should_replace_completion(old_completion, &snapshot) {
|
||||
this.current_completion = Some(new_completion);
|
||||
}
|
||||
} else {
|
||||
this.current_completion = Some(new_completion);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
// We always maintain at most two pending completions. When we already
|
||||
// have two, we replace the newest one.
|
||||
if self.pending_completions.len() <= 1 {
|
||||
self.pending_completions.push(PendingCompletion {
|
||||
id: pending_completion_id,
|
||||
_task: task,
|
||||
});
|
||||
} else if self.pending_completions.len() == 2 {
|
||||
self.pending_completions.pop();
|
||||
self.pending_completions.push(PendingCompletion {
|
||||
id: pending_completion_id,
|
||||
_task: task,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn cycle(
|
||||
&mut self,
|
||||
_buffer: Entity<Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_direction: edit_prediction::Direction,
|
||||
_cx: &mut Context<Self>,
|
||||
) {
|
||||
// Right now we don't support cycling.
|
||||
}
|
||||
|
||||
fn accept(&mut self, _cx: &mut Context<Self>) {
|
||||
self.pending_completions.clear();
|
||||
}
|
||||
|
||||
fn discard(&mut self, _cx: &mut Context<Self>) {
|
||||
self.pending_completions.clear();
|
||||
self.current_completion.take();
|
||||
}
|
||||
|
||||
fn suggest(
|
||||
&mut self,
|
||||
buffer: &Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<edit_prediction::EditPrediction> {
|
||||
let CurrentEditPrediction {
|
||||
buffer_id,
|
||||
completion,
|
||||
..
|
||||
} = self.current_completion.as_mut()?;
|
||||
|
||||
// Invalidate previous completion if it was generated for a different buffer.
|
||||
if *buffer_id != buffer.entity_id() {
|
||||
self.current_completion.take();
|
||||
return None;
|
||||
}
|
||||
|
||||
let buffer = buffer.read(cx);
|
||||
let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
|
||||
self.current_completion.take();
|
||||
return None;
|
||||
};
|
||||
|
||||
let cursor_row = cursor_position.to_point(buffer).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit =
|
||||
closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit =
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction::EditPrediction::Local {
|
||||
id: Some(completion.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(completion.edit_preview.clone()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -133,7 +133,6 @@ snippet_provider.workspace = true
|
|||
snippets_ui.workspace = true
|
||||
supermaven.workspace = true
|
||||
svg_preview.workspace = true
|
||||
sweep_ai.workspace = true
|
||||
sysinfo.workspace = true
|
||||
tab_switcher.workspace = true
|
||||
task.workspace = true
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ use language_models::MistralLanguageModelProvider;
|
|||
use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
|
||||
use std::{cell::RefCell, rc::Rc, sync::Arc};
|
||||
use supermaven::{Supermaven, SupermavenCompletionProvider};
|
||||
use sweep_ai::{SweepAiEditPredictionProvider, SweepFeatureFlag};
|
||||
use ui::Window;
|
||||
use zeta::ZetaEditPredictionProvider;
|
||||
use zeta2::SweepFeatureFlag;
|
||||
use zeta2::Zeta2FeatureFlag;
|
||||
|
||||
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||
|
|
@ -203,55 +203,41 @@ fn assign_edit_prediction_provider(
|
|||
let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
}
|
||||
EditPredictionProvider::Experimental(name) => {
|
||||
if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
|
||||
&& cx.has_flag::<SweepFeatureFlag>()
|
||||
{
|
||||
if let Some(project) = editor.project()
|
||||
&& let Some(workspace) = editor.workspace()
|
||||
{
|
||||
let sweep_ai = sweep_ai::SweepAi::register(cx);
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
{
|
||||
sweep_ai.update(cx, |sweep_ai, cx| {
|
||||
sweep_ai.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
let provider = cx.new(|_| {
|
||||
sweep_ai::SweepAiEditPredictionProvider::new(
|
||||
sweep_ai,
|
||||
workspace.downgrade(),
|
||||
project.clone(),
|
||||
)
|
||||
});
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
}
|
||||
} else {
|
||||
editor.set_edit_prediction_provider::<SweepAiEditPredictionProvider>(
|
||||
None, window, cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
EditPredictionProvider::Zed => {
|
||||
if user_store.read(cx).current_user().is_some() {
|
||||
value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
|
||||
if let Some(project) = editor.project() {
|
||||
let mut worktree = None;
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& let Some(file) = buffer.read(cx).file()
|
||||
{
|
||||
let id = file.worktree_id(cx);
|
||||
if let Some(inner_worktree) = editor
|
||||
.project()
|
||||
.and_then(|project| project.read(cx).worktree_for_id(id, cx))
|
||||
{
|
||||
worktree = Some(inner_worktree);
|
||||
}
|
||||
worktree = project.read(cx).worktree_for_id(id, cx);
|
||||
}
|
||||
|
||||
if let Some(project) = editor.project() {
|
||||
if let EditPredictionProvider::Experimental(name) = value
|
||||
&& name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
|
||||
&& cx.has_flag::<SweepFeatureFlag>()
|
||||
{
|
||||
let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
|
||||
let provider = cx.new(|cx| {
|
||||
zeta2::ZetaEditPredictionProvider::new(
|
||||
project.clone(),
|
||||
&client,
|
||||
&user_store,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(buffer) = &singleton_buffer
|
||||
&& buffer.read(cx).file().is_some()
|
||||
{
|
||||
zeta2.update(cx, |zeta, cx| {
|
||||
zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep);
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
} else if user_store.read(cx).current_user().is_some() {
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
let zeta = zeta2::Zeta::global(client, &user_store, cx);
|
||||
let provider = cx.new(|cx| {
|
||||
|
|
@ -268,6 +254,9 @@ fn assign_edit_prediction_provider(
|
|||
&& buffer.read(cx).file().is_some()
|
||||
{
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.set_edit_prediction_model(
|
||||
zeta2::ZetaEditPredictionModel::ZedCloud,
|
||||
);
|
||||
zeta.register_buffer(buffer, project, cx);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ eval-support = []
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
brotli.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ use language::ToPoint as _;
|
|||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{BufferEditPrediction, Zeta};
|
||||
use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
|
||||
|
||||
pub struct ZetaEditPredictionProvider {
|
||||
zeta: Entity<Zeta>,
|
||||
|
|
@ -85,9 +85,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
|||
&self,
|
||||
_buffer: &Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_cx: &App,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
true
|
||||
let zeta = self.zeta.read(cx);
|
||||
if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
|
||||
zeta.sweep_api_token.is_some()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
use std::fmt;
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::rel_path::RelPath;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct AutocompleteRequest {
|
||||
|
|
@ -88,3 +90,49 @@ pub struct AdditionalCompletion {
|
|||
pub logprobs: Option<serde_json::Value>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result {
|
||||
match event {
|
||||
crate::Event::BufferChange {
|
||||
old_snapshot,
|
||||
new_snapshot,
|
||||
..
|
||||
} => {
|
||||
let old_path = old_snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_ref())
|
||||
.unwrap_or(RelPath::unix("untitled").unwrap());
|
||||
let new_path = new_snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_ref())
|
||||
.unwrap_or(RelPath::unix("untitled").unwrap());
|
||||
if old_path != new_path {
|
||||
// TODO confirm how to do this for sweep
|
||||
// writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?;
|
||||
}
|
||||
|
||||
let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
|
||||
if !diff.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"File: {}:\n{}\n",
|
||||
new_path.display(util::paths::PathStyle::Posix),
|
||||
diff
|
||||
)?
|
||||
}
|
||||
|
||||
fmt::Result::Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn debug_info(cx: &gpui::App) -> Arc<str> {
|
||||
format!(
|
||||
"Zed v{version} ({sha}) - OS: {os} - Zed v{version}",
|
||||
version = release_channel::AppVersion::global(cx),
|
||||
sha = release_channel::AppCommitSha::try_global(cx)
|
||||
.map_or("unknown".to_string(), |sha| sha.full()),
|
||||
os = client::telemetry::os_name(),
|
||||
)
|
||||
.into()
|
||||
}
|
||||
|
|
@ -22,30 +22,31 @@ use gpui::{
|
|||
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
|
||||
http_client, prelude::*,
|
||||
};
|
||||
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
|
||||
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
|
||||
use language::{BufferSnapshot, OffsetRangeExt};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use open_ai::FunctionDefinition;
|
||||
use project::Project;
|
||||
use project::{Project, ProjectPath};
|
||||
use release_channel::AppVersion;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::collections::{VecDeque, hash_map};
|
||||
|
||||
use std::env;
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr as _;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{env, mem};
|
||||
use thiserror::Error;
|
||||
use util::rel_path::RelPathBuf;
|
||||
use util::{LogErrorFuture, TryFutureExt};
|
||||
use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
|
||||
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
|
||||
|
||||
pub mod assemble_excerpts;
|
||||
mod prediction;
|
||||
mod provider;
|
||||
pub mod retrieval_search;
|
||||
mod sweep_ai;
|
||||
pub mod udiff;
|
||||
mod xml_edits;
|
||||
|
||||
|
|
@ -55,8 +56,15 @@ pub use crate::prediction::EditPredictionId;
|
|||
pub use provider::ZetaEditPredictionProvider;
|
||||
|
||||
/// Maximum number of events to track.
|
||||
const MAX_EVENT_COUNT: usize = 16;
|
||||
const EVENT_COUNT_MAX_SWEEP: usize = 6;
|
||||
const EVENT_COUNT_MAX_ZETA: usize = 16;
|
||||
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
|
||||
|
||||
pub struct SweepFeatureFlag;
|
||||
|
||||
impl FeatureFlag for SweepFeatureFlag {
|
||||
const NAME: &str = "sweep-ai";
|
||||
}
|
||||
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
|
||||
max_bytes: 512,
|
||||
min_bytes: 128,
|
||||
|
|
@ -143,6 +151,15 @@ pub struct Zeta {
|
|||
debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
edit_prediction_model: ZetaEditPredictionModel,
|
||||
sweep_api_token: Option<String>,
|
||||
sweep_ai_debug_info: Arc<str>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum ZetaEditPredictionModel {
|
||||
ZedCloud,
|
||||
Sweep,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
@ -219,12 +236,14 @@ pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
|
|||
struct ZetaProject {
|
||||
syntax_index: Option<Entity<SyntaxIndex>>,
|
||||
events: VecDeque<Event>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
|
||||
refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
|
||||
refresh_context_debounce_task: Option<Task<Option<()>>>,
|
||||
refresh_context_timestamp: Option<Instant>,
|
||||
_subscription: gpui::Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -287,6 +306,7 @@ pub enum Event {
|
|||
BufferChange {
|
||||
old_snapshot: BufferSnapshot,
|
||||
new_snapshot: BufferSnapshot,
|
||||
end_edit_anchor: Option<Anchor>,
|
||||
timestamp: Instant,
|
||||
},
|
||||
}
|
||||
|
|
@ -381,9 +401,21 @@ impl Zeta {
|
|||
debug_tx: None,
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: None,
|
||||
edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
|
||||
sweep_api_token: None,
|
||||
sweep_ai_debug_info: sweep_ai::debug_info(cx),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
|
||||
if model == ZetaEditPredictionModel::Sweep {
|
||||
self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN")
|
||||
.context("No SWEEP_AI_TOKEN environment variable set")
|
||||
.log_err();
|
||||
}
|
||||
self.edit_prediction_model = model;
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
|
||||
self.eval_cache = Some(cache);
|
||||
|
|
@ -443,7 +475,7 @@ impl Zeta {
|
|||
self.user_store.read(cx).edit_prediction_usage()
|
||||
}
|
||||
|
||||
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
|
||||
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
|
||||
self.get_or_init_zeta_project(project, cx);
|
||||
}
|
||||
|
||||
|
|
@ -460,7 +492,7 @@ impl Zeta {
|
|||
fn get_or_init_zeta_project(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut ZetaProject {
|
||||
self.projects
|
||||
.entry(project.entity_id())
|
||||
|
|
@ -473,12 +505,31 @@ impl Zeta {
|
|||
None
|
||||
},
|
||||
events: VecDeque::new(),
|
||||
recent_paths: VecDeque::new(),
|
||||
registered_buffers: HashMap::default(),
|
||||
current_prediction: None,
|
||||
context: None,
|
||||
refresh_context_task: None,
|
||||
refresh_context_debounce_task: None,
|
||||
refresh_context_timestamp: None,
|
||||
_subscription: cx.subscribe(&project, |this, project, event, cx| {
|
||||
// TODO [zeta2] init with recent paths
|
||||
if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
|
||||
if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
|
||||
let path = project.read(cx).path_for_entry(*active_entry_id, cx);
|
||||
if let Some(path) = path {
|
||||
if let Some(ix) = zeta_project
|
||||
.recent_paths
|
||||
.iter()
|
||||
.position(|probe| probe == &path)
|
||||
{
|
||||
zeta_project.recent_paths.remove(ix);
|
||||
}
|
||||
zeta_project.recent_paths.push_front(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -525,66 +576,64 @@ impl Zeta {
|
|||
buffer: &Entity<Buffer>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> BufferSnapshot {
|
||||
let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
|
||||
let zeta_project = self.get_or_init_zeta_project(project, cx);
|
||||
let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
|
||||
) {
|
||||
let event_count_max = match self.edit_prediction_model {
|
||||
ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
|
||||
ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
|
||||
};
|
||||
|
||||
let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
|
||||
let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
|
||||
|
||||
let new_snapshot = buffer.read(cx).snapshot();
|
||||
if new_snapshot.version != registered_buffer.snapshot.version {
|
||||
let old_snapshot =
|
||||
std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
|
||||
Self::push_event(
|
||||
zeta_project,
|
||||
buffer_change_grouping_interval,
|
||||
Event::BufferChange {
|
||||
old_snapshot,
|
||||
new_snapshot: new_snapshot.clone(),
|
||||
timestamp: Instant::now(),
|
||||
},
|
||||
);
|
||||
if new_snapshot.version == registered_buffer.snapshot.version {
|
||||
return;
|
||||
}
|
||||
|
||||
new_snapshot
|
||||
}
|
||||
let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
|
||||
let end_edit_anchor = new_snapshot
|
||||
.anchored_edits_since::<Point>(&old_snapshot.version)
|
||||
.last()
|
||||
.map(|(_, range)| range.end);
|
||||
let events = &mut sweep_ai_project.events;
|
||||
|
||||
fn push_event(
|
||||
zeta_project: &mut ZetaProject,
|
||||
buffer_change_grouping_interval: Duration,
|
||||
event: Event,
|
||||
) {
|
||||
let events = &mut zeta_project.events;
|
||||
|
||||
if buffer_change_grouping_interval > Duration::ZERO
|
||||
&& let Some(Event::BufferChange {
|
||||
new_snapshot: last_new_snapshot,
|
||||
timestamp: last_timestamp,
|
||||
..
|
||||
}) = events.back_mut()
|
||||
if let Some(Event::BufferChange {
|
||||
new_snapshot: last_new_snapshot,
|
||||
end_edit_anchor: last_end_edit_anchor,
|
||||
..
|
||||
}) = events.back_mut()
|
||||
{
|
||||
// Coalesce edits for the same buffer when they happen one after the other.
|
||||
let Event::BufferChange {
|
||||
old_snapshot,
|
||||
new_snapshot,
|
||||
timestamp,
|
||||
} = &event;
|
||||
let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
|
||||
== last_new_snapshot.remote_id()
|
||||
&& old_snapshot.version == last_new_snapshot.version;
|
||||
|
||||
if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
|
||||
&& old_snapshot.remote_id() == last_new_snapshot.remote_id()
|
||||
&& old_snapshot.version == last_new_snapshot.version
|
||||
{
|
||||
*last_new_snapshot = new_snapshot.clone();
|
||||
*last_timestamp = *timestamp;
|
||||
let should_coalesce = is_next_snapshot_of_same_buffer
|
||||
&& end_edit_anchor
|
||||
.as_ref()
|
||||
.zip(last_end_edit_anchor.as_ref())
|
||||
.is_some_and(|(a, b)| {
|
||||
let a = a.to_point(&new_snapshot);
|
||||
let b = b.to_point(&new_snapshot);
|
||||
a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
|
||||
});
|
||||
|
||||
if should_coalesce {
|
||||
*last_end_edit_anchor = end_edit_anchor;
|
||||
*last_new_snapshot = new_snapshot;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if events.len() >= MAX_EVENT_COUNT {
|
||||
// These are halved instead of popping to improve prompt caching.
|
||||
events.drain(..MAX_EVENT_COUNT / 2);
|
||||
if events.len() >= event_count_max {
|
||||
events.pop_front();
|
||||
}
|
||||
|
||||
events.push_back(event);
|
||||
events.push_back(Event::BufferChange {
|
||||
old_snapshot,
|
||||
new_snapshot,
|
||||
end_edit_anchor,
|
||||
timestamp: Instant::now(),
|
||||
});
|
||||
}
|
||||
|
||||
fn current_prediction_for_buffer(
|
||||
|
|
@ -706,6 +755,203 @@ impl Zeta {
|
|||
active_buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
match self.edit_prediction_model {
|
||||
ZetaEditPredictionModel::ZedCloud => {
|
||||
self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
|
||||
}
|
||||
ZetaEditPredictionModel::Sweep => {
|
||||
self.request_prediction_with_sweep(project, active_buffer, position, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn request_prediction_with_sweep(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
let debug_info = self.sweep_ai_debug_info.clone();
|
||||
let Some(api_token) = self.sweep_api_token.clone() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| "untitled".into())
|
||||
.into();
|
||||
|
||||
let project_file = project::File::from_dyn(snapshot.file());
|
||||
let repo_name = project_file
|
||||
.map(|file| file.worktree.read(cx).root_name_str())
|
||||
.unwrap_or("untitled")
|
||||
.into();
|
||||
let offset = position.to_offset(&snapshot);
|
||||
|
||||
let project_state = self.get_or_init_zeta_project(project, cx);
|
||||
let events = project_state.events.clone();
|
||||
let recent_buffers = project_state.recent_paths.iter().cloned();
|
||||
let http_client = cx.http_client();
|
||||
|
||||
let recent_buffer_snapshots = recent_buffers
|
||||
.filter_map(|project_path| {
|
||||
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if active_buffer == &buffer {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.read(cx).snapshot())
|
||||
}
|
||||
})
|
||||
.take(3)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let text = snapshot.text();
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
for event in events {
|
||||
sweep_ai::write_event(event, &mut recent_changes).unwrap();
|
||||
}
|
||||
|
||||
let file_chunks = recent_buffer_snapshots
|
||||
.into_iter()
|
||||
.map(|snapshot| {
|
||||
let end_point = language::Point::new(30, 0).min(snapshot.max_point());
|
||||
sweep_ai::FileChunk {
|
||||
content: snapshot
|
||||
.text_for_range(language::Point::zero()..end_point)
|
||||
.collect(),
|
||||
file_path: snapshot
|
||||
.file()
|
||||
.map(|f| f.path().as_unix_str())
|
||||
.unwrap_or("untitled")
|
||||
.to_string(),
|
||||
start_line: 0,
|
||||
end_line: end_point.row as usize,
|
||||
timestamp: snapshot.file().and_then(|file| {
|
||||
Some(
|
||||
file.disk_state()
|
||||
.mtime()?
|
||||
.to_seconds_and_nanos_for_persistence()?
|
||||
.0,
|
||||
)
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request_body = sweep_ai::AutocompleteRequest {
|
||||
debug_info,
|
||||
repo_name,
|
||||
file_path: full_path.clone(),
|
||||
file_contents: text.clone(),
|
||||
original_file_contents: text,
|
||||
cursor_position: offset,
|
||||
recent_changes: recent_changes.clone(),
|
||||
changes_above_cursor: true,
|
||||
multiple_suggestions: false,
|
||||
branch: None,
|
||||
file_chunks,
|
||||
retrieval_chunks: vec![],
|
||||
recent_user_actions: vec![],
|
||||
// TODO
|
||||
privacy_mode_enabled: false,
|
||||
};
|
||||
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
const SWEEP_API_URL: &str =
|
||||
"https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.uri(SWEEP_API_URL)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Content-Encoding", "br")
|
||||
.method(Method::POST)
|
||||
.body(body)?;
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
|
||||
let mut body: Vec<u8> = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Request failed with status: {:?}\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&body),
|
||||
);
|
||||
};
|
||||
|
||||
let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
|
||||
let old_text = snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = language::text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
anyhow::Ok((response.autocomplete_id, edits, snapshot))
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let (id, edits, old_snapshot) = result.await?;
|
||||
|
||||
if edits.is_empty() {
|
||||
return anyhow::Ok(None);
|
||||
}
|
||||
|
||||
let Some((edits, new_snapshot, preview_task)) =
|
||||
buffer.read_with(cx, |buffer, cx| {
|
||||
let new_snapshot = buffer.snapshot();
|
||||
|
||||
let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
|
||||
edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
|
||||
.into();
|
||||
let preview_task = buffer.preview_edits(edits.clone(), cx);
|
||||
|
||||
Some((edits, new_snapshot, preview_task))
|
||||
})?
|
||||
else {
|
||||
return anyhow::Ok(None);
|
||||
};
|
||||
|
||||
let prediction = EditPrediction {
|
||||
id: EditPredictionId(id.into()),
|
||||
edits,
|
||||
snapshot: new_snapshot,
|
||||
edit_preview: preview_task.await,
|
||||
buffer,
|
||||
};
|
||||
|
||||
anyhow::Ok(Some(prediction))
|
||||
})
|
||||
}
|
||||
|
||||
fn request_prediction_with_zed_cloud(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<EditPrediction>>> {
|
||||
let project_state = self.projects.get(&project.entity_id());
|
||||
|
||||
|
|
@ -1653,7 +1899,7 @@ impl Zeta {
|
|||
pub fn wait_for_initial_indexing(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let zeta_project = self.get_or_init_zeta_project(project, cx);
|
||||
if let Some(syntax_index) = &zeta_project.syntax_index {
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ settings.workspace = true
|
|||
shellexpand.workspace = true
|
||||
smol.workspace = true
|
||||
soa-rs = "0.8.1"
|
||||
sweep_ai.workspace = true
|
||||
terminal_view.workspace = true
|
||||
toml.workspace = true
|
||||
util.workspace = true
|
||||
|
|
|
|||
|
|
@ -8,16 +8,15 @@ use anyhow::Result;
|
|||
use collections::HashSet;
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use sweep_ai::SweepAi;
|
||||
use util::ResultExt as _;
|
||||
use zeta2::{Zeta, udiff::DiffLine};
|
||||
|
||||
use crate::{
|
||||
EvaluateArguments, PredictionOptions, PredictionProvider,
|
||||
EvaluateArguments, PredictionOptions,
|
||||
example::{Example, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
paths::print_run_data_dir,
|
||||
predict::{PredictionDetails, perform_predict, setup_sweep, setup_zeta},
|
||||
predict::{PredictionDetails, perform_predict, setup_zeta},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
@ -46,46 +45,35 @@ pub async fn run_evaluate(
|
|||
let project = example.setup_project(&app_state, cx).await.unwrap();
|
||||
|
||||
let providers = (0..args.repetitions)
|
||||
.map(|_| {
|
||||
(
|
||||
setup_zeta(&project, &app_state, cx).unwrap(),
|
||||
if matches!(args.options.provider, PredictionProvider::Sweep) {
|
||||
Some(setup_sweep(&project, cx).unwrap())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
)
|
||||
})
|
||||
.map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
|
||||
let tasks =
|
||||
providers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(repetition_ix, (zeta, sweep))| {
|
||||
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
|
||||
let example = example.clone();
|
||||
let project = project.clone();
|
||||
let options = options.clone();
|
||||
let tasks = providers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(repetition_ix, zeta)| {
|
||||
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
|
||||
let example = example.clone();
|
||||
let project = project.clone();
|
||||
let options = options.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let name = example.name.clone();
|
||||
run_evaluate_one(
|
||||
example,
|
||||
repetition_ix,
|
||||
project,
|
||||
zeta,
|
||||
sweep,
|
||||
options,
|
||||
!args.skip_prediction,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| (err, name, repetition_ix))
|
||||
})
|
||||
});
|
||||
cx.spawn(async move |cx| {
|
||||
let name = example.name.clone();
|
||||
run_evaluate_one(
|
||||
example,
|
||||
repetition_ix,
|
||||
project,
|
||||
zeta,
|
||||
options,
|
||||
!args.skip_prediction,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| (err, name, repetition_ix))
|
||||
})
|
||||
});
|
||||
futures::future::join_all(tasks).await
|
||||
})
|
||||
});
|
||||
|
|
@ -177,7 +165,6 @@ pub async fn run_evaluate_one(
|
|||
repetition_ix: Option<u16>,
|
||||
project: Entity<Project>,
|
||||
zeta: Entity<Zeta>,
|
||||
sweep: Option<Entity<SweepAi>>,
|
||||
prediction_options: PredictionOptions,
|
||||
predict: bool,
|
||||
cx: &mut AsyncApp,
|
||||
|
|
@ -186,7 +173,6 @@ pub async fn run_evaluate_one(
|
|||
example.clone(),
|
||||
project,
|
||||
zeta,
|
||||
sweep,
|
||||
repetition_ix,
|
||||
prediction_options,
|
||||
cx,
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ pub struct EvaluateArguments {
|
|||
skip_prediction: bool,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
|
||||
enum PredictionProvider {
|
||||
#[default]
|
||||
Zeta2,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ use std::path::PathBuf;
|
|||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant};
|
||||
use sweep_ai::SweepAi;
|
||||
use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
|
||||
|
||||
pub async fn run_predict(
|
||||
|
|
@ -31,14 +30,9 @@ pub async fn run_predict(
|
|||
) {
|
||||
let example = NamedExample::load(args.example_path).unwrap();
|
||||
let project = example.setup_project(app_state, cx).await.unwrap();
|
||||
let zeta = setup_zeta(&project, app_state, cx).unwrap();
|
||||
let sweep = if matches!(args.options.provider, PredictionProvider::Sweep) {
|
||||
Some(setup_sweep(&project, cx).unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
let result = perform_predict(example, project, zeta, sweep, None, args.options, cx)
|
||||
let result = perform_predict(example, project, zeta, None, args.options, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
result.write(args.format, std::io::stdout()).unwrap();
|
||||
|
|
@ -47,6 +41,7 @@ pub async fn run_predict(
|
|||
}
|
||||
|
||||
pub fn setup_zeta(
|
||||
provider: PredictionProvider,
|
||||
project: &Entity<Project>,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
|
|
@ -54,6 +49,14 @@ pub fn setup_zeta(
|
|||
let zeta =
|
||||
cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
|
||||
|
||||
zeta.update(cx, |zeta, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud,
|
||||
PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep,
|
||||
};
|
||||
zeta.set_edit_prediction_model(model);
|
||||
})?;
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
|
|
@ -71,31 +74,10 @@ pub fn setup_zeta(
|
|||
anyhow::Ok(zeta)
|
||||
}
|
||||
|
||||
pub fn setup_sweep(project: &Entity<Project>, cx: &mut AsyncApp) -> Result<Entity<SweepAi>> {
|
||||
let sweep = cx.new(|cx| SweepAi::new(cx))?;
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
let sweep = sweep.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
sweep.update(cx, |sweep, cx| sweep.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})?
|
||||
.detach();
|
||||
|
||||
anyhow::Ok(sweep)
|
||||
}
|
||||
|
||||
pub async fn perform_predict(
|
||||
example: NamedExample,
|
||||
project: Entity<Project>,
|
||||
zeta: Entity<Zeta>,
|
||||
sweep: Option<Entity<SweepAi>>,
|
||||
repetition_ix: Option<u16>,
|
||||
options: PredictionOptions,
|
||||
cx: &mut AsyncApp,
|
||||
|
|
@ -147,194 +129,152 @@ pub async fn perform_predict(
|
|||
zeta.set_options(options);
|
||||
})?;
|
||||
|
||||
let prediction = match options.provider {
|
||||
crate::PredictionProvider::Zeta2 => {
|
||||
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
|
||||
let mut debug_task = gpui::Task::ready(Ok(()));
|
||||
|
||||
let debug_task = cx.background_spawn({
|
||||
let result = result.clone();
|
||||
async move {
|
||||
let mut start_time = None;
|
||||
let mut search_queries_generated_at = None;
|
||||
let mut search_queries_executed_at = None;
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
match event {
|
||||
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
|
||||
start_time = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_prompt.md"),
|
||||
&info.search_prompt,
|
||||
)?;
|
||||
}
|
||||
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
|
||||
search_queries_generated_at = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_queries.json"),
|
||||
serde_json::to_string_pretty(&info.search_queries).unwrap(),
|
||||
)?;
|
||||
}
|
||||
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
|
||||
search_queries_executed_at = Some(info.timestamp);
|
||||
}
|
||||
zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
|
||||
zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
|
||||
let prediction_started_at = Instant::now();
|
||||
start_time.get_or_insert(prediction_started_at);
|
||||
let prompt = request.local_prompt.unwrap_or_default();
|
||||
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
|
||||
if options.provider == crate::PredictionProvider::Zeta2 {
|
||||
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
|
||||
|
||||
{
|
||||
let mut result = result.lock().unwrap();
|
||||
result.prompt_len = prompt.chars().count();
|
||||
|
||||
for included_file in request.request.included_files {
|
||||
let insertions =
|
||||
vec![(request.request.cursor_point, CURSOR_MARKER)];
|
||||
result.excerpts.extend(included_file.excerpts.iter().map(
|
||||
|excerpt| {
|
||||
ActualExcerpt {
|
||||
path: included_file
|
||||
.path
|
||||
.components()
|
||||
.skip(1)
|
||||
.collect(),
|
||||
text: String::from(excerpt.text.as_ref()),
|
||||
}
|
||||
},
|
||||
));
|
||||
write_codeblock(
|
||||
&included_file.path,
|
||||
included_file.excerpts.iter(),
|
||||
if included_file.path == request.request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
included_file.max_row,
|
||||
false,
|
||||
&mut result.excerpts_text,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let response =
|
||||
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
|
||||
let response =
|
||||
zeta2::text_from_response(response).unwrap_or_default();
|
||||
let prediction_finished_at = Instant::now();
|
||||
fs::write(
|
||||
example_run_dir.join("prediction_response.md"),
|
||||
&response,
|
||||
)?;
|
||||
debug_task = cx.background_spawn({
|
||||
let result = result.clone();
|
||||
async move {
|
||||
let mut start_time = None;
|
||||
let mut search_queries_generated_at = None;
|
||||
let mut search_queries_executed_at = None;
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
match event {
|
||||
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
|
||||
start_time = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_prompt.md"),
|
||||
&info.search_prompt,
|
||||
)?;
|
||||
}
|
||||
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
|
||||
search_queries_generated_at = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_queries.json"),
|
||||
serde_json::to_string_pretty(&info.search_queries).unwrap(),
|
||||
)?;
|
||||
}
|
||||
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
|
||||
search_queries_executed_at = Some(info.timestamp);
|
||||
}
|
||||
zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
|
||||
zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
|
||||
let prediction_started_at = Instant::now();
|
||||
start_time.get_or_insert(prediction_started_at);
|
||||
let prompt = request.local_prompt.unwrap_or_default();
|
||||
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
|
||||
|
||||
{
|
||||
let mut result = result.lock().unwrap();
|
||||
result.generated_len = response.chars().count();
|
||||
result.prompt_len = prompt.chars().count();
|
||||
|
||||
if !options.use_expected_context {
|
||||
result.planning_search_time = Some(
|
||||
search_queries_generated_at.unwrap() - start_time.unwrap(),
|
||||
);
|
||||
result.running_search_time = Some(
|
||||
search_queries_executed_at.unwrap()
|
||||
- search_queries_generated_at.unwrap(),
|
||||
for included_file in request.request.included_files {
|
||||
let insertions =
|
||||
vec![(request.request.cursor_point, CURSOR_MARKER)];
|
||||
result.excerpts.extend(included_file.excerpts.iter().map(
|
||||
|excerpt| ActualExcerpt {
|
||||
path: included_file.path.components().skip(1).collect(),
|
||||
text: String::from(excerpt.text.as_ref()),
|
||||
},
|
||||
));
|
||||
write_codeblock(
|
||||
&included_file.path,
|
||||
included_file.excerpts.iter(),
|
||||
if included_file.path == request.request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
included_file.max_row,
|
||||
false,
|
||||
&mut result.excerpts_text,
|
||||
);
|
||||
}
|
||||
result.prediction_time =
|
||||
prediction_finished_at - prediction_started_at;
|
||||
result.total_time = prediction_finished_at - start_time.unwrap();
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
let response =
|
||||
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
|
||||
let response = zeta2::text_from_response(response).unwrap_or_default();
|
||||
let prediction_finished_at = Instant::now();
|
||||
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
|
||||
|
||||
let mut result = result.lock().unwrap();
|
||||
result.generated_len = response.chars().count();
|
||||
|
||||
if !options.use_expected_context {
|
||||
result.planning_search_time = Some(
|
||||
search_queries_generated_at.unwrap() - start_time.unwrap(),
|
||||
);
|
||||
result.running_search_time = Some(
|
||||
search_queries_executed_at.unwrap()
|
||||
- search_queries_generated_at.unwrap(),
|
||||
);
|
||||
}
|
||||
result.prediction_time = prediction_finished_at - prediction_started_at;
|
||||
result.total_time = prediction_finished_at - start_time.unwrap();
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
if options.use_expected_context {
|
||||
let context_excerpts_tasks = example
|
||||
.example
|
||||
.expected_context
|
||||
.iter()
|
||||
.flat_map(|section| {
|
||||
section.alternatives[0].excerpts.iter().map(|excerpt| {
|
||||
resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
|
||||
})
|
||||
if options.use_expected_context {
|
||||
let context_excerpts_tasks = example
|
||||
.example
|
||||
.expected_context
|
||||
.iter()
|
||||
.flat_map(|section| {
|
||||
section.alternatives[0].excerpts.iter().map(|excerpt| {
|
||||
resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let context_excerpts_vec =
|
||||
futures::future::try_join_all(context_excerpts_tasks).await?;
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let context_excerpts_vec =
|
||||
futures::future::try_join_all(context_excerpts_tasks).await?;
|
||||
|
||||
let mut context_excerpts = HashMap::default();
|
||||
for (buffer, mut excerpts) in context_excerpts_vec {
|
||||
context_excerpts
|
||||
.entry(buffer)
|
||||
.or_insert(Vec::new())
|
||||
.append(&mut excerpts);
|
||||
}
|
||||
|
||||
zeta.update(cx, |zeta, _cx| {
|
||||
zeta.set_context(project.clone(), context_excerpts)
|
||||
})?;
|
||||
} else {
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
|
||||
})?
|
||||
.await?;
|
||||
let mut context_excerpts = HashMap::default();
|
||||
for (buffer, mut excerpts) in context_excerpts_vec {
|
||||
context_excerpts
|
||||
.entry(buffer)
|
||||
.or_insert(Vec::new())
|
||||
.append(&mut excerpts);
|
||||
}
|
||||
|
||||
let prediction = zeta
|
||||
.update(cx, |zeta, cx| {
|
||||
zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
|
||||
})?
|
||||
.await?
|
||||
.map(|prediction| (prediction.buffer, prediction.snapshot, prediction.edits));
|
||||
|
||||
debug_task.await?;
|
||||
|
||||
prediction
|
||||
}
|
||||
crate::PredictionProvider::Sweep => sweep
|
||||
.unwrap()
|
||||
.update(cx, |sweep, cx| {
|
||||
let mut recent_paths = Vec::new();
|
||||
for path in zeta
|
||||
.read(cx)
|
||||
.history_for_project(&project)
|
||||
.rev()
|
||||
.filter_map(|event| event.project_path(cx))
|
||||
{
|
||||
if !recent_paths.contains(&path) {
|
||||
recent_paths.push(path);
|
||||
}
|
||||
}
|
||||
|
||||
sweep.request_completion(
|
||||
&project,
|
||||
recent_paths.into_iter(),
|
||||
&cursor_buffer,
|
||||
cursor_anchor,
|
||||
cx,
|
||||
)
|
||||
zeta.update(cx, |zeta, _cx| {
|
||||
zeta.set_context(project.clone(), context_excerpts)
|
||||
})?;
|
||||
} else {
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
|
||||
})?
|
||||
.await?
|
||||
.map(
|
||||
|sweep_ai::EditPrediction {
|
||||
edits, snapshot, ..
|
||||
}| { (cursor_buffer.clone(), snapshot, edits) },
|
||||
),
|
||||
};
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
let prediction = zeta
|
||||
.update(cx, |zeta, cx| {
|
||||
zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
debug_task.await?;
|
||||
|
||||
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
|
||||
|
||||
result.diff = prediction
|
||||
.map(|(buffer, snapshot, edits)| {
|
||||
let old_text = snapshot.text();
|
||||
let new_text = buffer
|
||||
.map(|prediction| {
|
||||
let old_text = prediction.snapshot.text();
|
||||
let new_text = prediction
|
||||
.buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
let branch = buffer.branch(cx);
|
||||
branch.update(cx, |branch, cx| {
|
||||
branch.edit(edits.iter().cloned(), None, cx);
|
||||
branch.edit(prediction.edits.iter().cloned(), None, cx);
|
||||
branch.text()
|
||||
})
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in a new issue