mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-31 19:05:00 +07:00
Add agent thread sharing (#46140)
Staff only ship for now Here's the agent planning doc that guided this: https://gist.github.com/mikayla-maki/c826b7997bd85b58273c1def9397940b Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
8df27897e3
commit
3da926981c
30 changed files with 1132 additions and 13 deletions
5
Cargo.lock
generated
5
Cargo.lock
generated
|
|
@ -3280,7 +3280,10 @@ dependencies = [
|
|||
name = "collab"
|
||||
version = "0.44.0"
|
||||
dependencies = [
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent_settings",
|
||||
"agent_ui",
|
||||
"anyhow",
|
||||
"assistant_slash_command",
|
||||
"assistant_text_thread",
|
||||
|
|
@ -20662,6 +20665,8 @@ version = "0.219.0"
|
|||
dependencies = [
|
||||
"acp_tools",
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent_settings",
|
||||
"agent_ui",
|
||||
"agent_ui_v2",
|
||||
|
|
|
|||
6
Procfile.all
Normal file
6
Procfile.all
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
|
||||
cloud: cd ../cloud; cargo make dev
|
||||
dashboard: cd ../cloud/packages/dashboard; pnpm dev
|
||||
website: cd ../zed.dev; pnpm dev --port=3000
|
||||
livekit: livekit-server --dev
|
||||
blob_store: ./script/run-local-minio
|
||||
|
|
@ -50,6 +50,63 @@ pub struct DbThread {
|
|||
pub completion_mode: Option<CompletionMode>,
|
||||
#[serde(default)]
|
||||
pub profile: Option<AgentProfileId>,
|
||||
#[serde(default)]
|
||||
pub imported: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SharedThread {
|
||||
pub title: SharedString,
|
||||
pub messages: Vec<DbMessage>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub model: Option<DbLanguageModel>,
|
||||
#[serde(default)]
|
||||
pub completion_mode: Option<CompletionMode>,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
impl SharedThread {
|
||||
pub const VERSION: &'static str = "1.0.0";
|
||||
|
||||
pub fn from_db_thread(thread: &DbThread) -> Self {
|
||||
Self {
|
||||
title: thread.title.clone(),
|
||||
messages: thread.messages.clone(),
|
||||
updated_at: thread.updated_at,
|
||||
model: thread.model.clone(),
|
||||
completion_mode: thread.completion_mode,
|
||||
version: Self::VERSION.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_db_thread(self) -> DbThread {
|
||||
DbThread {
|
||||
title: format!("🔗 {}", self.title).into(),
|
||||
messages: self.messages,
|
||||
updated_at: self.updated_at,
|
||||
detailed_summary: None,
|
||||
initial_project_snapshot: None,
|
||||
cumulative_token_usage: Default::default(),
|
||||
request_token_usage: Default::default(),
|
||||
model: self.model,
|
||||
completion_mode: self.completion_mode,
|
||||
profile: None,
|
||||
imported: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_bytes(&self) -> Result<Vec<u8>> {
|
||||
const COMPRESSION_LEVEL: i32 = 3;
|
||||
let json = serde_json::to_vec(self)?;
|
||||
let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
|
||||
Ok(compressed)
|
||||
}
|
||||
|
||||
pub fn from_bytes(data: &[u8]) -> Result<Self> {
|
||||
let decompressed = zstd::decode_all(data)?;
|
||||
Ok(serde_json::from_slice(&decompressed)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl DbThread {
|
||||
|
|
@ -209,6 +266,7 @@ impl DbThread {
|
|||
model: thread.model,
|
||||
completion_mode: thread.completion_mode,
|
||||
profile: thread.profile,
|
||||
imported: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -441,3 +499,45 @@ impl ThreadsDatabase {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::TimeZone;
|
||||
|
||||
#[test]
|
||||
fn test_shared_thread_roundtrip() {
|
||||
let original = SharedThread {
|
||||
title: "Test Thread".into(),
|
||||
messages: vec![],
|
||||
updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
||||
model: None,
|
||||
completion_mode: None,
|
||||
version: SharedThread::VERSION.to_string(),
|
||||
};
|
||||
|
||||
let bytes = original.to_bytes().expect("Failed to serialize");
|
||||
let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
|
||||
|
||||
assert_eq!(restored.title, original.title);
|
||||
assert_eq!(restored.version, original.version);
|
||||
assert_eq!(restored.updated_at, original.updated_at);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_imported_flag_defaults_to_false() {
|
||||
// Simulate deserializing a thread without the imported field (backwards compatibility).
|
||||
let json = r#"{
|
||||
"title": "Old Thread",
|
||||
"messages": [],
|
||||
"updated_at": "2024-01-01T00:00:00Z"
|
||||
}"#;
|
||||
|
||||
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
|
||||
|
||||
assert!(
|
||||
!db_thread.imported,
|
||||
"Legacy threads without imported field should default to false"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -175,6 +175,20 @@ impl HistoryStore {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn save_thread(
|
||||
&mut self,
|
||||
id: acp::SessionId,
|
||||
thread: crate::DbThread,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let database_future = ThreadsDatabase::connect(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let database = database_future.await.map_err(|err| anyhow!(err))?;
|
||||
database.save_thread(id, thread).await?;
|
||||
this.update(cx, |this, cx| this.reload(cx))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn delete_thread(
|
||||
&mut self,
|
||||
id: acp::SessionId,
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ pub struct SerializedThread {
|
|||
pub profile: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
|
||||
pub struct SerializedLanguageModel {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
|
|
|
|||
|
|
@ -622,6 +622,8 @@ pub struct Thread {
|
|||
pub(crate) action_log: Entity<ActionLog>,
|
||||
/// Tracks the last time files were read by the agent, to detect external modifications
|
||||
pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
|
||||
/// True if this thread was imported from a shared thread and can be synced.
|
||||
imported: bool,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
|
|
@ -678,6 +680,7 @@ impl Thread {
|
|||
project,
|
||||
action_log,
|
||||
file_read_times: HashMap::default(),
|
||||
imported: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -685,6 +688,11 @@ impl Thread {
|
|||
&self.id
|
||||
}
|
||||
|
||||
/// Returns true if this thread was imported from a shared thread.
|
||||
pub fn is_imported(&self) -> bool {
|
||||
self.imported
|
||||
}
|
||||
|
||||
pub fn replay(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
|
|
@ -866,6 +874,7 @@ impl Thread {
|
|||
prompt_capabilities_tx,
|
||||
prompt_capabilities_rx,
|
||||
file_read_times: HashMap::default(),
|
||||
imported: db_thread.imported,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -885,6 +894,7 @@ impl Thread {
|
|||
}),
|
||||
completion_mode: Some(self.completion_mode),
|
||||
profile: Some(self.profile_id.clone()),
|
||||
imported: self.imported,
|
||||
};
|
||||
|
||||
cx.background_spawn(async move {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ use acp_thread::{
|
|||
};
|
||||
use acp_thread::{AgentConnection, Plan};
|
||||
use action_log::{ActionLog, ActionLogTelemetry};
|
||||
use agent::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer};
|
||||
use agent::{
|
||||
DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer, SharedThread,
|
||||
};
|
||||
use agent_client_protocol::{self as acp, PromptCapabilities};
|
||||
use agent_servers::{AgentServer, AgentServerDelegate};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
|
|
@ -20,15 +22,16 @@ use editor::scroll::Autoscroll;
|
|||
use editor::{
|
||||
Editor, EditorEvent, EditorMode, MultiBuffer, PathKey, SelectionEffects, SizingBehavior,
|
||||
};
|
||||
use feature_flags::{AgentSharingFeatureFlag, FeatureFlagAppExt};
|
||||
use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::FutureExt as _;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, CursorStyle,
|
||||
EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset,
|
||||
ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task, TextStyle,
|
||||
TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div, ease_in_out,
|
||||
linear_color_stop, linear_gradient, list, point, pulsating_between,
|
||||
Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem,
|
||||
CursorStyle, EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length,
|
||||
ListOffset, ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task,
|
||||
TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
|
||||
ease_in_out, linear_color_stop, linear_gradient, list, point, pulsating_between,
|
||||
};
|
||||
use language::Buffer;
|
||||
|
||||
|
|
@ -52,7 +55,7 @@ use ui::{
|
|||
WithScrollbar, prelude::*, right_click_menu,
|
||||
};
|
||||
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
|
||||
use workspace::{CollaboratorId, NewTerminal, Workspace};
|
||||
use workspace::{CollaboratorId, NewTerminal, Toast, Workspace, notifications::NotificationId};
|
||||
use zed_actions::agent::{Chat, ToggleModelSelector};
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
|
||||
|
|
@ -935,6 +938,124 @@ impl AcpThreadView {
|
|||
}
|
||||
}
|
||||
|
||||
fn share_thread(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(thread) = self.as_native_thread(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let client = self.project.read(cx).client();
|
||||
let workspace = self.workspace.clone();
|
||||
let session_id = thread.read(cx).id().to_string();
|
||||
|
||||
let load_task = thread.read(cx).to_db(cx);
|
||||
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let db_thread = load_task.await;
|
||||
|
||||
let shared_thread = SharedThread::from_db_thread(&db_thread);
|
||||
let thread_data = shared_thread.to_bytes()?;
|
||||
let title = shared_thread.title.to_string();
|
||||
|
||||
client
|
||||
.request(proto::ShareAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
title,
|
||||
thread_data,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let share_url = client::zed_urls::shared_agent_thread_url(&session_id);
|
||||
|
||||
cx.update(|cx| {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
struct ThreadSharedToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<ThreadSharedToast>(),
|
||||
"Thread shared!",
|
||||
)
|
||||
.on_click(
|
||||
"Copy URL",
|
||||
move |_window, cx| {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(
|
||||
share_url.clone(),
|
||||
));
|
||||
},
|
||||
),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn sync_thread(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if !self.is_imported_thread(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(thread) = self.as_native_thread(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let client = self.project.read(cx).client();
|
||||
let history_store = self.history_store.clone();
|
||||
let session_id = thread.read(cx).id().clone();
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let response = client
|
||||
.request(proto::GetSharedAgentThread {
|
||||
session_id: session_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let shared_thread = SharedThread::from_bytes(&response.thread_data)?;
|
||||
|
||||
let db_thread = shared_thread.to_db_thread();
|
||||
|
||||
history_store
|
||||
.update(&mut cx.clone(), |store, cx| {
|
||||
store.save_thread(session_id.clone(), db_thread, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let thread_metadata = agent::DbThreadMetadata {
|
||||
id: session_id,
|
||||
title: format!("🔗 {}", response.title).into(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.resume_thread_metadata = Some(thread_metadata);
|
||||
this.reset(window, cx);
|
||||
})?;
|
||||
|
||||
this.update_in(cx, |this, _window, cx| {
|
||||
if let Some(workspace) = this.workspace.upgrade() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
struct ThreadSyncedToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<ThreadSyncedToast>(),
|
||||
"Thread synced with latest version",
|
||||
)
|
||||
.autohide(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn expand_message_editor(
|
||||
&mut self,
|
||||
_: &ExpandMessageEditor,
|
||||
|
|
@ -4904,6 +5025,13 @@ impl AcpThreadView {
|
|||
.thread(acp_thread.session_id(), cx)
|
||||
}
|
||||
|
||||
fn is_imported_thread(&self, cx: &App) -> bool {
|
||||
let Some(thread) = self.as_native_thread(cx) else {
|
||||
return false;
|
||||
};
|
||||
thread.read(cx).is_imported()
|
||||
}
|
||||
|
||||
fn is_using_zed_ai_models(&self, cx: &App) -> bool {
|
||||
self.as_native_thread(cx)
|
||||
.and_then(|thread| thread.read(cx).model())
|
||||
|
|
@ -5819,6 +5947,41 @@ impl AcpThreadView {
|
|||
);
|
||||
}
|
||||
|
||||
if cx.has_flag::<AgentSharingFeatureFlag>()
|
||||
&& self.is_imported_thread(cx)
|
||||
&& self
|
||||
.project
|
||||
.read(cx)
|
||||
.client()
|
||||
.status()
|
||||
.borrow()
|
||||
.is_connected()
|
||||
{
|
||||
let sync_button = IconButton::new("sync-thread", IconName::ArrowCircle)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Ignored)
|
||||
.tooltip(Tooltip::text("Sync with source thread"))
|
||||
.on_click(cx.listener(move |this, _, window, cx| {
|
||||
this.sync_thread(window, cx);
|
||||
}));
|
||||
|
||||
container = container.child(sync_button);
|
||||
}
|
||||
|
||||
if cx.has_flag::<AgentSharingFeatureFlag>() && !self.is_imported_thread(cx) {
|
||||
let share_button = IconButton::new("share-thread", IconName::ArrowUpRight)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Ignored)
|
||||
.tooltip(Tooltip::text("Share Thread"))
|
||||
.on_click(cx.listener(move |this, _, window, cx| {
|
||||
this.share_thread(window, cx);
|
||||
}));
|
||||
|
||||
container = container.child(share_button);
|
||||
}
|
||||
|
||||
container
|
||||
.child(open_as_markdown)
|
||||
.child(scroll_to_recent_user_prompt)
|
||||
|
|
|
|||
|
|
@ -720,10 +720,25 @@ impl AgentPanel {
|
|||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub(crate) fn thread_store(&self) -> &Entity<HistoryStore> {
|
||||
pub fn thread_store(&self) -> &Entity<HistoryStore> {
|
||||
&self.history_store
|
||||
}
|
||||
|
||||
pub fn open_thread(
|
||||
&mut self,
|
||||
thread: DbThreadMetadata,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.external_thread(
|
||||
Some(crate::ExternalAgent::NativeAgent),
|
||||
Some(thread),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn context_server_registry(&self) -> &Entity<ContextServerRegistry> {
|
||||
&self.context_server_registry
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,3 +67,7 @@ pub fn edit_prediction_docs(cx: &App) -> String {
|
|||
server_url = server_url(cx)
|
||||
)
|
||||
}
|
||||
|
||||
pub fn shared_agent_thread_url(session_id: &str) -> String {
|
||||
format!("zed://agent/shared/{}", session_id)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,10 @@ util.workspace = true
|
|||
uuid.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
agent = { workspace = true, features = ["test-support"] }
|
||||
agent-client-protocol.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agent_ui = { workspace = true, features = ["test-support"] }
|
||||
assistant_text_thread.workspace = true
|
||||
assistant_slash_command.workspace = true
|
||||
async-trait.workspace = true
|
||||
|
|
|
|||
|
|
@ -460,3 +460,14 @@ CREATE TABLE IF NOT EXISTS "breakpoints" (
|
|||
);
|
||||
|
||||
CREATE INDEX "index_breakpoints_on_project_id" ON "breakpoints" ("project_id");
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "shared_threads" (
|
||||
"id" TEXT PRIMARY KEY NOT NULL,
|
||||
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
|
||||
"title" VARCHAR(512) NOT NULL,
|
||||
"data" BLOB NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX "index_shared_threads_user_id" ON "shared_threads" ("user_id");
|
||||
|
|
|
|||
|
|
@ -430,6 +430,15 @@ CREATE SEQUENCE public.servers_id_seq
|
|||
|
||||
ALTER SEQUENCE public.servers_id_seq OWNED BY public.servers.id;
|
||||
|
||||
CREATE TABLE public.shared_threads (
|
||||
id uuid NOT NULL,
|
||||
user_id integer NOT NULL,
|
||||
title text NOT NULL,
|
||||
data bytea NOT NULL,
|
||||
created_at timestamp without time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp without time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE public.user_features (
|
||||
user_id integer NOT NULL,
|
||||
feature_id integer NOT NULL
|
||||
|
|
@ -630,6 +639,9 @@ ALTER TABLE ONLY public.rooms
|
|||
ALTER TABLE ONLY public.servers
|
||||
ADD CONSTRAINT servers_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY public.shared_threads
|
||||
ADD CONSTRAINT shared_threads_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY public.user_features
|
||||
ADD CONSTRAINT user_features_pkey PRIMARY KEY (user_id, feature_id);
|
||||
|
||||
|
|
@ -648,6 +660,8 @@ ALTER TABLE ONLY public.worktree_settings_files
|
|||
ALTER TABLE ONLY public.worktrees
|
||||
ADD CONSTRAINT worktrees_pkey PRIMARY KEY (project_id, id);
|
||||
|
||||
CREATE INDEX idx_shared_threads_user_id ON public.shared_threads USING btree (user_id);
|
||||
|
||||
CREATE INDEX index_access_tokens_user_id ON public.access_tokens USING btree (user_id);
|
||||
|
||||
CREATE INDEX index_breakpoints_on_project_id ON public.breakpoints USING btree (project_id);
|
||||
|
|
@ -879,6 +893,9 @@ ALTER TABLE ONLY public.room_participants
|
|||
ALTER TABLE ONLY public.rooms
|
||||
ADD CONSTRAINT rooms_channel_id_fkey FOREIGN KEY (channel_id) REFERENCES public.channels(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY public.shared_threads
|
||||
ADD CONSTRAINT shared_threads_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY public.user_features
|
||||
ADD CONSTRAINT user_features_feature_id_fkey FOREIGN KEY (feature_id) REFERENCES public.feature_flags(id) ON DELETE CASCADE;
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use crate::Result;
|
|||
use rpc::proto;
|
||||
use sea_orm::{DbErr, entity::prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! id_type {
|
||||
|
|
@ -92,6 +93,39 @@ id_type!(ServerId);
|
|||
id_type!(SignupId);
|
||||
id_type!(UserId);
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, DeriveValueType)]
|
||||
pub struct SharedThreadId(pub Uuid);
|
||||
|
||||
impl SharedThreadId {
|
||||
pub fn from_proto(id: String) -> Option<Self> {
|
||||
Uuid::parse_str(&id).ok().map(SharedThreadId)
|
||||
}
|
||||
|
||||
pub fn to_proto(self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl sea_orm::TryFromU64 for SharedThreadId {
|
||||
fn try_from_u64(_n: u64) -> std::result::Result<Self, DbErr> {
|
||||
Err(DbErr::ConvertFromU64(
|
||||
"SharedThreadId uses UUID and cannot be converted from u64",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl sea_orm::sea_query::Nullable for SharedThreadId {
|
||||
fn null() -> Value {
|
||||
Value::Uuid(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SharedThreadId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
/// ChannelRole gives you permissions for both channels and calls.
|
||||
#[derive(
|
||||
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
|
||||
|
|
|
|||
|
|
@ -10,4 +10,5 @@ pub mod notifications;
|
|||
pub mod projects;
|
||||
pub mod rooms;
|
||||
pub mod servers;
|
||||
pub mod shared_threads;
|
||||
pub mod users;
|
||||
|
|
|
|||
77
crates/collab/src/db/queries/shared_threads.rs
Normal file
77
crates/collab/src/db/queries/shared_threads.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
use chrono::Utc;
|
||||
|
||||
use super::*;
|
||||
use crate::db::tables::shared_thread;
|
||||
|
||||
impl Database {
|
||||
pub async fn upsert_shared_thread(
|
||||
&self,
|
||||
id: SharedThreadId,
|
||||
user_id: UserId,
|
||||
title: &str,
|
||||
data: Vec<u8>,
|
||||
) -> Result<()> {
|
||||
let title = title.to_string();
|
||||
self.transaction(|tx| {
|
||||
let title = title.clone();
|
||||
let data = data.clone();
|
||||
async move {
|
||||
let now = Utc::now().naive_utc();
|
||||
|
||||
let existing = shared_thread::Entity::find_by_id(id).one(&*tx).await?;
|
||||
|
||||
match existing {
|
||||
Some(existing) => {
|
||||
if existing.user_id != user_id {
|
||||
Err(anyhow!("Cannot update shared thread owned by another user"))?;
|
||||
}
|
||||
|
||||
let mut active: shared_thread::ActiveModel = existing.into();
|
||||
active.title = ActiveValue::Set(title);
|
||||
active.data = ActiveValue::Set(data);
|
||||
active.updated_at = ActiveValue::Set(now);
|
||||
active.update(&*tx).await?;
|
||||
}
|
||||
None => {
|
||||
shared_thread::ActiveModel {
|
||||
id: ActiveValue::Set(id),
|
||||
user_id: ActiveValue::Set(user_id),
|
||||
title: ActiveValue::Set(title),
|
||||
data: ActiveValue::Set(data),
|
||||
created_at: ActiveValue::Set(now),
|
||||
updated_at: ActiveValue::Set(now),
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_shared_thread(
|
||||
&self,
|
||||
share_id: SharedThreadId,
|
||||
) -> Result<Option<(shared_thread::Model, String)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let Some(thread) = shared_thread::Entity::find_by_id(share_id)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let user = user::Entity::find_by_id(thread.user_id).one(&*tx).await?;
|
||||
|
||||
let username = user
|
||||
.map(|u| u.github_login)
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
|
||||
Ok(Some((thread, username)))
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ pub mod project_repository_statuses;
|
|||
pub mod room;
|
||||
pub mod room_participant;
|
||||
pub mod server;
|
||||
pub mod shared_thread;
|
||||
pub mod user;
|
||||
pub mod worktree;
|
||||
pub mod worktree_diagnostic_summary;
|
||||
|
|
|
|||
32
crates/collab/src/db/tables/shared_thread.rs
Normal file
32
crates/collab/src/db/tables/shared_thread.rs
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
use crate::db::{SharedThreadId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "shared_threads")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub id: SharedThreadId,
|
||||
pub user_id: UserId,
|
||||
pub title: String,
|
||||
pub data: Vec<u8>,
|
||||
pub created_at: DateTime,
|
||||
pub updated_at: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
|
@ -586,3 +586,121 @@ async fn test_fuzzy_search_users(cx: &mut gpui::TestAppContext) {
|
|||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_upsert_shared_thread,
|
||||
test_upsert_shared_thread_postgres,
|
||||
test_upsert_shared_thread_sqlite
|
||||
);
|
||||
|
||||
async fn test_upsert_shared_thread(db: &Arc<Database>) {
|
||||
use crate::db::SharedThreadId;
|
||||
use uuid::Uuid;
|
||||
|
||||
let user_id = new_test_user(db, "user1@example.com").await;
|
||||
|
||||
let thread_id = SharedThreadId(Uuid::new_v4());
|
||||
let title = "My Test Thread";
|
||||
let data = b"test thread data".to_vec();
|
||||
|
||||
db.upsert_shared_thread(thread_id, user_id, title, data.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = db.get_shared_thread(thread_id).await.unwrap();
|
||||
assert!(result.is_some(), "Should find the shared thread");
|
||||
|
||||
let (thread, username) = result.unwrap();
|
||||
assert_eq!(thread.title, title);
|
||||
assert_eq!(thread.data, data);
|
||||
assert_eq!(thread.user_id, user_id);
|
||||
assert_eq!(username, "user1");
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_upsert_shared_thread_updates_existing,
|
||||
test_upsert_shared_thread_updates_existing_postgres,
|
||||
test_upsert_shared_thread_updates_existing_sqlite
|
||||
);
|
||||
|
||||
async fn test_upsert_shared_thread_updates_existing(db: &Arc<Database>) {
|
||||
use crate::db::SharedThreadId;
|
||||
use uuid::Uuid;
|
||||
|
||||
let user_id = new_test_user(db, "user1@example.com").await;
|
||||
|
||||
let thread_id = SharedThreadId(Uuid::new_v4());
|
||||
|
||||
// Create initial thread.
|
||||
db.upsert_shared_thread(
|
||||
thread_id,
|
||||
user_id,
|
||||
"Original Title",
|
||||
b"original data".to_vec(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Update the same thread.
|
||||
db.upsert_shared_thread(
|
||||
thread_id,
|
||||
user_id,
|
||||
"Updated Title",
|
||||
b"updated data".to_vec(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = db.get_shared_thread(thread_id).await.unwrap();
|
||||
let (thread, _) = result.unwrap();
|
||||
|
||||
assert_eq!(thread.title, "Updated Title");
|
||||
assert_eq!(thread.data, b"updated data".to_vec());
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_cannot_update_another_users_shared_thread,
|
||||
test_cannot_update_another_users_shared_thread_postgres,
|
||||
test_cannot_update_another_users_shared_thread_sqlite
|
||||
);
|
||||
|
||||
async fn test_cannot_update_another_users_shared_thread(db: &Arc<Database>) {
|
||||
use crate::db::SharedThreadId;
|
||||
use uuid::Uuid;
|
||||
|
||||
let user1_id = new_test_user(db, "user1@example.com").await;
|
||||
let user2_id = new_test_user(db, "user2@example.com").await;
|
||||
|
||||
let thread_id = SharedThreadId(Uuid::new_v4());
|
||||
|
||||
db.upsert_shared_thread(thread_id, user1_id, "User 1 Thread", b"user1 data".to_vec())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = db
|
||||
.upsert_shared_thread(thread_id, user2_id, "User 2 Title", b"user2 data".to_vec())
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Should not allow updating another user's thread"
|
||||
);
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_get_nonexistent_shared_thread,
|
||||
test_get_nonexistent_shared_thread_postgres,
|
||||
test_get_nonexistent_shared_thread_sqlite
|
||||
);
|
||||
|
||||
async fn test_get_nonexistent_shared_thread(db: &Arc<Database>) {
|
||||
use crate::db::SharedThreadId;
|
||||
use uuid::Uuid;
|
||||
|
||||
let result = db
|
||||
.get_shared_thread(SharedThreadId(Uuid::new_v4()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.is_none(), "Should not find non-existent thread");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ use crate::{
|
|||
db::{
|
||||
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
|
||||
InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
|
||||
RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
|
||||
RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
|
||||
UserId,
|
||||
},
|
||||
executor::Executor,
|
||||
};
|
||||
|
|
@ -465,7 +466,9 @@ impl Server {
|
|||
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
||||
.add_message_handler(update_context)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
|
||||
.add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
|
||||
.add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
|
||||
.add_request_handler(share_agent_thread)
|
||||
.add_request_handler(get_shared_agent_thread);
|
||||
|
||||
Arc::new(server)
|
||||
}
|
||||
|
|
@ -4016,6 +4019,54 @@ fn project_left(project: &db::LeftProject, session: &Session) {
|
|||
}
|
||||
}
|
||||
|
||||
async fn share_agent_thread(
|
||||
request: proto::ShareAgentThread,
|
||||
response: Response<proto::ShareAgentThread>,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let user_id = session.user_id();
|
||||
|
||||
let share_id = SharedThreadId::from_proto(request.session_id.clone())
|
||||
.ok_or_else(|| anyhow!("Invalid session ID format"))?;
|
||||
|
||||
session
|
||||
.db()
|
||||
.await
|
||||
.upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
|
||||
.await?;
|
||||
|
||||
response.send(proto::Ack {})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_shared_agent_thread(
|
||||
request: proto::GetSharedAgentThread,
|
||||
response: Response<proto::GetSharedAgentThread>,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let share_id = SharedThreadId::from_proto(request.session_id)
|
||||
.ok_or_else(|| anyhow!("Invalid session ID format"))?;
|
||||
|
||||
let result = session.db().await.get_shared_thread(share_id).await?;
|
||||
|
||||
match result {
|
||||
Some((thread, username)) => {
|
||||
response.send(proto::GetSharedAgentThreadResponse {
|
||||
title: thread.title,
|
||||
thread_data: thread.data,
|
||||
sharer_username: username,
|
||||
created_at: thread.created_at.and_utc().to_rfc3339(),
|
||||
})?;
|
||||
}
|
||||
None => {
|
||||
return Err(anyhow!("Shared thread not found").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub trait ResultExt {
|
||||
type Ok;
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use call::Room;
|
|||
use client::ChannelId;
|
||||
use gpui::{Entity, TestAppContext};
|
||||
|
||||
mod agent_sharing_tests;
|
||||
mod channel_buffer_tests;
|
||||
mod channel_guest_tests;
|
||||
mod channel_tests;
|
||||
|
|
|
|||
217
crates/collab/src/tests/agent_sharing_tests.rs
Normal file
217
crates/collab/src/tests/agent_sharing_tests.rs
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
use agent::SharedThread;
|
||||
use gpui::{BackgroundExecutor, TestAppContext};
|
||||
use rpc::proto;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tests::TestServer;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_share_and_retrieve_thread(
|
||||
executor: BackgroundExecutor,
|
||||
cx_a: &mut TestAppContext,
|
||||
cx_b: &mut TestAppContext,
|
||||
) {
|
||||
let mut server = TestServer::start(executor.clone()).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
executor.run_until_parked();
|
||||
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
|
||||
let original_thread = SharedThread {
|
||||
title: "Shared Test Thread".into(),
|
||||
messages: vec![],
|
||||
updated_at: chrono::Utc::now(),
|
||||
model: None,
|
||||
completion_mode: None,
|
||||
version: SharedThread::VERSION.to_string(),
|
||||
};
|
||||
|
||||
let thread_data = original_thread
|
||||
.to_bytes()
|
||||
.expect("Failed to serialize thread");
|
||||
|
||||
client_a
|
||||
.client()
|
||||
.request(proto::ShareAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
title: original_thread.title.to_string(),
|
||||
thread_data,
|
||||
})
|
||||
.await
|
||||
.expect("Failed to share thread");
|
||||
|
||||
let get_response = client_b
|
||||
.client()
|
||||
.request(proto::GetSharedAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to get shared thread");
|
||||
|
||||
let imported_shared_thread =
|
||||
SharedThread::from_bytes(&get_response.thread_data).expect("Failed to deserialize thread");
|
||||
|
||||
assert_eq!(imported_shared_thread.title, original_thread.title);
|
||||
assert_eq!(imported_shared_thread.version, SharedThread::VERSION);
|
||||
|
||||
let db_thread = imported_shared_thread.to_db_thread();
|
||||
|
||||
assert!(
|
||||
db_thread.title.starts_with("🔗"),
|
||||
"Imported thread title should have link prefix"
|
||||
);
|
||||
assert!(
|
||||
db_thread.title.contains("Shared Test Thread"),
|
||||
"Imported thread should preserve original title"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_reshare_updates_existing_thread(
|
||||
executor: BackgroundExecutor,
|
||||
cx_a: &mut TestAppContext,
|
||||
cx_b: &mut TestAppContext,
|
||||
) {
|
||||
let mut server = TestServer::start(executor.clone()).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
executor.run_until_parked();
|
||||
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
|
||||
client_a
|
||||
.client()
|
||||
.request(proto::ShareAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
title: "Original Title".to_string(),
|
||||
thread_data: b"original data".to_vec(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to share thread");
|
||||
|
||||
client_a
|
||||
.client()
|
||||
.request(proto::ShareAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
title: "Updated Title".to_string(),
|
||||
thread_data: b"updated data".to_vec(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to re-share thread");
|
||||
|
||||
let get_response = client_b
|
||||
.client()
|
||||
.request(proto::GetSharedAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to get shared thread");
|
||||
|
||||
assert_eq!(get_response.title, "Updated Title");
|
||||
assert_eq!(get_response.thread_data, b"updated data".to_vec());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_get_nonexistent_thread(executor: BackgroundExecutor, cx: &mut TestAppContext) {
|
||||
let mut server = TestServer::start(executor.clone()).await;
|
||||
let client = server.create_client(cx, "user_a").await;
|
||||
|
||||
executor.run_until_parked();
|
||||
|
||||
let nonexistent_session_id = Uuid::new_v4().to_string();
|
||||
|
||||
let result = client
|
||||
.client()
|
||||
.request(proto::GetSharedAgentThread {
|
||||
session_id: nonexistent_session_id,
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err(), "Should fail for nonexistent thread");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_sync_imported_thread(
|
||||
executor: BackgroundExecutor,
|
||||
cx_a: &mut TestAppContext,
|
||||
cx_b: &mut TestAppContext,
|
||||
) {
|
||||
let mut server = TestServer::start(executor.clone()).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
executor.run_until_parked();
|
||||
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
|
||||
// User A shares a thread with initial content.
|
||||
let initial_thread = SharedThread {
|
||||
title: "Initial Title".into(),
|
||||
messages: vec![],
|
||||
updated_at: chrono::Utc::now(),
|
||||
model: None,
|
||||
completion_mode: None,
|
||||
version: SharedThread::VERSION.to_string(),
|
||||
};
|
||||
|
||||
client_a
|
||||
.client()
|
||||
.request(proto::ShareAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
title: initial_thread.title.to_string(),
|
||||
thread_data: initial_thread.to_bytes().expect("Failed to serialize"),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to share thread");
|
||||
|
||||
// User B imports the thread.
|
||||
let initial_response = client_b
|
||||
.client()
|
||||
.request(proto::GetSharedAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to get shared thread");
|
||||
|
||||
let initial_imported =
|
||||
SharedThread::from_bytes(&initial_response.thread_data).expect("Failed to deserialize");
|
||||
assert_eq!(initial_imported.title.as_ref(), "Initial Title");
|
||||
|
||||
// User A updates the shared thread.
|
||||
let updated_thread = SharedThread {
|
||||
title: "Updated Title".into(),
|
||||
messages: vec![],
|
||||
updated_at: chrono::Utc::now(),
|
||||
model: None,
|
||||
completion_mode: None,
|
||||
version: SharedThread::VERSION.to_string(),
|
||||
};
|
||||
|
||||
client_a
|
||||
.client()
|
||||
.request(proto::ShareAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
title: updated_thread.title.to_string(),
|
||||
thread_data: updated_thread.to_bytes().expect("Failed to serialize"),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to re-share thread");
|
||||
|
||||
// User B syncs the imported thread (fetches the latest version).
|
||||
let synced_response = client_b
|
||||
.client()
|
||||
.request(proto::GetSharedAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("Failed to sync shared thread");
|
||||
|
||||
let synced_thread =
|
||||
SharedThread::from_bytes(&synced_response.thread_data).expect("Failed to deserialize");
|
||||
|
||||
// The synced thread should have the updated title.
|
||||
assert_eq!(synced_thread.title.as_ref(), "Updated Title");
|
||||
}
|
||||
|
|
@ -29,3 +29,9 @@ pub struct AcpBetaFeatureFlag;
|
|||
impl FeatureFlag for AcpBetaFeatureFlag {
|
||||
const NAME: &'static str = "acp-beta";
|
||||
}
|
||||
|
||||
pub struct AgentSharingFeatureFlag;
|
||||
|
||||
impl FeatureFlag for AgentSharingFeatureFlag {
|
||||
const NAME: &'static str = "agent-sharing";
|
||||
}
|
||||
|
|
|
|||
|
|
@ -218,3 +218,20 @@ message NewExternalAgentVersionAvailable {
|
|||
string name = 2;
|
||||
string version = 3;
|
||||
}
|
||||
|
||||
message ShareAgentThread {
|
||||
string session_id = 1; // Client-generated UUID (acp::SessionId)
|
||||
string title = 2;
|
||||
bytes thread_data = 3;
|
||||
}
|
||||
|
||||
message GetSharedAgentThread {
|
||||
string session_id = 1; // UUID string
|
||||
}
|
||||
|
||||
message GetSharedAgentThreadResponse {
|
||||
string title = 1;
|
||||
bytes thread_data = 2;
|
||||
string sharer_username = 3;
|
||||
string created_at = 4;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -447,7 +447,11 @@ message Envelope {
|
|||
GitRemoveRemote git_remove_remote = 403;
|
||||
|
||||
TrustWorktrees trust_worktrees = 404;
|
||||
RestrictWorktrees restrict_worktrees = 405; // current max
|
||||
RestrictWorktrees restrict_worktrees = 405;
|
||||
|
||||
ShareAgentThread share_agent_thread = 406;
|
||||
GetSharedAgentThread get_shared_agent_thread = 407;
|
||||
GetSharedAgentThreadResponse get_shared_agent_thread_response = 408; // current max
|
||||
}
|
||||
|
||||
reserved 87 to 88;
|
||||
|
|
|
|||
|
|
@ -342,7 +342,10 @@ messages!(
|
|||
(RemoteStarted, Background),
|
||||
(GitGetWorktrees, Background),
|
||||
(GitWorktreesResponse, Background),
|
||||
(GitCreateWorktree, Background)
|
||||
(GitCreateWorktree, Background),
|
||||
(ShareAgentThread, Foreground),
|
||||
(GetSharedAgentThread, Foreground),
|
||||
(GetSharedAgentThreadResponse, Foreground)
|
||||
);
|
||||
|
||||
request_messages!(
|
||||
|
|
@ -441,6 +444,8 @@ request_messages!(
|
|||
(SendChannelMessage, SendChannelMessageResponse),
|
||||
(SetChannelMemberRole, Ack),
|
||||
(SetChannelVisibility, Ack),
|
||||
(ShareAgentThread, Ack),
|
||||
(GetSharedAgentThread, GetSharedAgentThreadResponse),
|
||||
(ShareProject, ShareProjectResponse),
|
||||
(SynchronizeBuffers, SynchronizeBuffersResponse),
|
||||
(TaskContextForLocation, TaskContext),
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ required-features = ["visual-tests"]
|
|||
[dependencies]
|
||||
acp_tools.workspace = true
|
||||
activity_indicator.workspace = true
|
||||
agent.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agent_ui.workspace = true
|
||||
agent_ui_v2.workspace = true
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
mod reliability;
|
||||
mod zed;
|
||||
|
||||
use agent::{HistoryStore, SharedThread};
|
||||
use agent_client_protocol;
|
||||
use agent_ui::AgentPanel;
|
||||
use anyhow::{Context as _, Error, Result};
|
||||
use clap::Parser;
|
||||
|
|
@ -33,6 +35,7 @@ use assets::Assets;
|
|||
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||
use parking_lot::Mutex;
|
||||
use project::{project_settings::ProjectSettings, trusted_worktrees};
|
||||
use proto;
|
||||
use recent_projects::{RemoteSettings, open_remote_project};
|
||||
use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
|
||||
use session::{AppSession, Session};
|
||||
|
|
@ -837,6 +840,73 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut
|
|||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
OpenRequestKind::SharedAgentThread { session_id } => {
|
||||
cx.spawn(async move |cx| {
|
||||
let workspace =
|
||||
workspace::get_any_active_workspace(app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let (client, history_store) =
|
||||
workspace.update(cx, |workspace, _window, cx| {
|
||||
let client = workspace.project().read(cx).client();
|
||||
let history_store: Option<gpui::Entity<HistoryStore>> = workspace
|
||||
.panel::<AgentPanel>(cx)
|
||||
.map(|panel| panel.read(cx).thread_store().clone());
|
||||
(client, history_store)
|
||||
})?;
|
||||
|
||||
let Some(history_store): Option<gpui::Entity<HistoryStore>> = history_store
|
||||
else {
|
||||
anyhow::bail!("Agent panel not available");
|
||||
};
|
||||
|
||||
let response = client
|
||||
.request(proto::GetSharedAgentThread {
|
||||
session_id: session_id.clone(),
|
||||
})
|
||||
.await
|
||||
.context("Failed to fetch shared thread")?;
|
||||
|
||||
let shared_thread = SharedThread::from_bytes(&response.thread_data)?;
|
||||
let db_thread = shared_thread.to_db_thread();
|
||||
let session_id = agent_client_protocol::SessionId::new(session_id);
|
||||
|
||||
history_store
|
||||
.update(&mut cx.clone(), |store, cx| {
|
||||
store.save_thread(session_id.clone(), db_thread, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let thread_metadata = agent::DbThreadMetadata {
|
||||
id: session_id,
|
||||
title: format!("🔗 {}", response.title).into(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
workspace.update(cx, |workspace, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.open_thread(thread_metadata, window, cx);
|
||||
});
|
||||
panel.focus_handle(cx).focus(window, cx);
|
||||
}
|
||||
})?;
|
||||
|
||||
workspace.update(cx, |workspace, _window, cx| {
|
||||
struct ImportedThreadToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<ImportedThreadToast>(),
|
||||
format!("Imported shared thread from {}", response.sharer_username),
|
||||
)
|
||||
.autohide(),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
OpenRequestKind::DockMenuAction { index } => {
|
||||
cx.perform_dock_menu_action(index);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ pub mod edit_prediction_registry;
|
|||
pub(crate) mod mac_only_instance;
|
||||
mod migrate;
|
||||
mod open_listener;
|
||||
mod open_url_modal;
|
||||
mod quick_action_bar;
|
||||
#[cfg(all(target_os = "macos", any(test, feature = "test-support")))]
|
||||
pub mod visual_tests;
|
||||
|
|
@ -141,6 +142,8 @@ actions!(
|
|||
/// audio system (including yourself) on the current call in a tar file
|
||||
/// in the current working directory.
|
||||
CaptureRecentAudio,
|
||||
/// Opens a prompt to enter a URL to open.
|
||||
OpenUrlPrompt,
|
||||
]
|
||||
);
|
||||
|
||||
|
|
@ -823,6 +826,11 @@ fn register_actions(
|
|||
..Default::default()
|
||||
})
|
||||
})
|
||||
.register_action(|workspace, _: &OpenUrlPrompt, window, cx| {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
open_url_modal::OpenUrlModal::new(window, cx)
|
||||
});
|
||||
})
|
||||
.register_action(|workspace, action: &OpenBrowser, _window, cx| {
|
||||
// Parse and validate the URL to ensure it's properly formatted
|
||||
match url::Url::parse(&action.url) {
|
||||
|
|
|
|||
|
|
@ -49,6 +49,9 @@ pub enum OpenRequestKind {
|
|||
extension_id: String,
|
||||
},
|
||||
AgentPanel,
|
||||
SharedAgentThread {
|
||||
session_id: String,
|
||||
},
|
||||
DockMenuAction {
|
||||
index: usize,
|
||||
},
|
||||
|
|
@ -107,6 +110,14 @@ impl OpenRequest {
|
|||
});
|
||||
} else if url == "zed://agent" {
|
||||
this.kind = Some(OpenRequestKind::AgentPanel);
|
||||
} else if let Some(session_id_str) = url.strip_prefix("zed://agent/shared/") {
|
||||
if uuid::Uuid::parse_str(session_id_str).is_ok() {
|
||||
this.kind = Some(OpenRequestKind::SharedAgentThread {
|
||||
session_id: session_id_str.to_string(),
|
||||
});
|
||||
} else {
|
||||
log::error!("Invalid session ID in URL: {}", session_id_str);
|
||||
}
|
||||
} else if let Some(schema_path) = url.strip_prefix("zed://schemas/") {
|
||||
this.kind = Some(OpenRequestKind::BuiltinJsonSchema {
|
||||
schema_path: schema_path.to_string(),
|
||||
|
|
|
|||
116
crates/zed/src/zed/open_url_modal.rs
Normal file
116
crates/zed/src/zed/open_url_modal.rs
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
use editor::Editor;
|
||||
use gpui::{AppContext as _, DismissEvent, Entity, EventEmitter, Focusable, ReadGlobal, Styled};
|
||||
use ui::{
|
||||
ActiveTheme, App, Color, Context, FluentBuilder, InteractiveElement, IntoElement, Label,
|
||||
LabelCommon, LabelSize, ParentElement, Render, SharedString, StyledExt, Window, div, h_flex,
|
||||
v_flex,
|
||||
};
|
||||
use workspace::ModalView;
|
||||
|
||||
use super::{OpenListener, RawOpenRequest};
|
||||
|
||||
pub struct OpenUrlModal {
|
||||
editor: Entity<Editor>,
|
||||
last_error: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for OpenUrlModal {}
|
||||
impl ModalView for OpenUrlModal {}
|
||||
|
||||
impl Focusable for OpenUrlModal {
|
||||
fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
|
||||
self.editor.focus_handle(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenUrlModal {
|
||||
pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::single_line(window, cx);
|
||||
editor.set_placeholder_text("zed://...", window, cx);
|
||||
editor
|
||||
});
|
||||
|
||||
Self {
|
||||
editor,
|
||||
last_error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let url = self.editor.update(cx, |editor, cx| {
|
||||
let text = editor.text(cx).trim().to_string();
|
||||
editor.clear(window, cx);
|
||||
text
|
||||
});
|
||||
|
||||
if url.is_empty() {
|
||||
cx.emit(DismissEvent);
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle zed:// URLs internally.
|
||||
if url.starts_with("zed://") || url.starts_with("zed-cli://") {
|
||||
OpenListener::global(cx).open(RawOpenRequest {
|
||||
urls: vec![url],
|
||||
..Default::default()
|
||||
});
|
||||
cx.emit(DismissEvent);
|
||||
return;
|
||||
}
|
||||
|
||||
match url::Url::parse(&url) {
|
||||
Ok(parsed_url) => {
|
||||
cx.open_url(parsed_url.as_str());
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
Err(e) => {
|
||||
self.last_error = Some(format!("Invalid URL: {}", e).into());
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for OpenUrlModal {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let theme = cx.theme();
|
||||
|
||||
v_flex()
|
||||
.key_context("OpenUrlModal")
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(Self::confirm))
|
||||
.elevation_3(cx)
|
||||
.w_96()
|
||||
.overflow_hidden()
|
||||
.child(
|
||||
div()
|
||||
.p_2()
|
||||
.border_b_1()
|
||||
.border_color(theme.colors().border_variant)
|
||||
.child(self.editor.clone()),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.bg(theme.colors().editor_background)
|
||||
.rounded_b_sm()
|
||||
.w_full()
|
||||
.p_2()
|
||||
.gap_1()
|
||||
.when_some(self.last_error.clone(), |this, error| {
|
||||
this.child(Label::new(error).size(LabelSize::Small).color(Color::Error))
|
||||
})
|
||||
.when(self.last_error.is_none(), |this| {
|
||||
this.child(
|
||||
Label::new("Paste a URL to open.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue