Add agent thread sharing (#46140)

Staff only ship for now

Here's the agent planning doc that guided this:
https://gist.github.com/mikayla-maki/c826b7997bd85b58273c1def9397940b

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Mikayla Maki 2026-01-06 12:49:51 -08:00 committed by GitHub
parent 8df27897e3
commit 3da926981c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 1132 additions and 13 deletions

5
Cargo.lock generated
View file

@ -3280,7 +3280,10 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
"agent",
"agent-client-protocol",
"agent_settings",
"agent_ui",
"anyhow",
"assistant_slash_command",
"assistant_text_thread",
@ -20662,6 +20665,8 @@ version = "0.219.0"
dependencies = [
"acp_tools",
"activity_indicator",
"agent",
"agent-client-protocol",
"agent_settings",
"agent_ui",
"agent_ui_v2",

6
Procfile.all Normal file
View file

@ -0,0 +1,6 @@
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
cloud: cd ../cloud; cargo make dev
dashboard: cd ../cloud/packages/dashboard; pnpm dev
website: cd ../zed.dev; pnpm dev --port=3000
livekit: livekit-server --dev
blob_store: ./script/run-local-minio

View file

@ -50,6 +50,63 @@ pub struct DbThread {
pub completion_mode: Option<CompletionMode>,
#[serde(default)]
pub profile: Option<AgentProfileId>,
#[serde(default)]
pub imported: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedThread {
pub title: SharedString,
pub messages: Vec<DbMessage>,
pub updated_at: DateTime<Utc>,
#[serde(default)]
pub model: Option<DbLanguageModel>,
#[serde(default)]
pub completion_mode: Option<CompletionMode>,
pub version: String,
}
impl SharedThread {
pub const VERSION: &'static str = "1.0.0";
pub fn from_db_thread(thread: &DbThread) -> Self {
Self {
title: thread.title.clone(),
messages: thread.messages.clone(),
updated_at: thread.updated_at,
model: thread.model.clone(),
completion_mode: thread.completion_mode,
version: Self::VERSION.to_string(),
}
}
pub fn to_db_thread(self) -> DbThread {
DbThread {
title: format!("🔗 {}", self.title).into(),
messages: self.messages,
updated_at: self.updated_at,
detailed_summary: None,
initial_project_snapshot: None,
cumulative_token_usage: Default::default(),
request_token_usage: Default::default(),
model: self.model,
completion_mode: self.completion_mode,
profile: None,
imported: true,
}
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
const COMPRESSION_LEVEL: i32 = 3;
let json = serde_json::to_vec(self)?;
let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
Ok(compressed)
}
pub fn from_bytes(data: &[u8]) -> Result<Self> {
let decompressed = zstd::decode_all(data)?;
Ok(serde_json::from_slice(&decompressed)?)
}
}
impl DbThread {
@ -209,6 +266,7 @@ impl DbThread {
model: thread.model,
completion_mode: thread.completion_mode,
profile: thread.profile,
imported: false,
})
}
}
@ -441,3 +499,45 @@ impl ThreadsDatabase {
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::TimeZone;
#[test]
fn test_shared_thread_roundtrip() {
let original = SharedThread {
title: "Test Thread".into(),
messages: vec![],
updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
model: None,
completion_mode: None,
version: SharedThread::VERSION.to_string(),
};
let bytes = original.to_bytes().expect("Failed to serialize");
let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
assert_eq!(restored.title, original.title);
assert_eq!(restored.version, original.version);
assert_eq!(restored.updated_at, original.updated_at);
}
#[test]
fn test_imported_flag_defaults_to_false() {
// Simulate deserializing a thread without the imported field (backwards compatibility).
let json = r#"{
"title": "Old Thread",
"messages": [],
"updated_at": "2024-01-01T00:00:00Z"
}"#;
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
assert!(
!db_thread.imported,
"Legacy threads without imported field should default to false"
);
}
}

View file

@ -175,6 +175,20 @@ impl HistoryStore {
})
}
pub fn save_thread(
&mut self,
id: acp::SessionId,
thread: crate::DbThread,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
database.save_thread(id, thread).await?;
this.update(cx, |this, cx| this.reload(cx))
})
}
pub fn delete_thread(
&mut self,
id: acp::SessionId,

View file

@ -44,7 +44,7 @@ pub struct SerializedThread {
pub profile: Option<AgentProfileId>,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
pub struct SerializedLanguageModel {
pub provider: String,
pub model: String,

View file

@ -622,6 +622,8 @@ pub struct Thread {
pub(crate) action_log: Entity<ActionLog>,
/// Tracks the last time files were read by the agent, to detect external modifications
pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
/// True if this thread was imported from a shared thread and can be synced.
imported: bool,
}
impl Thread {
@ -678,6 +680,7 @@ impl Thread {
project,
action_log,
file_read_times: HashMap::default(),
imported: false,
}
}
@ -685,6 +688,11 @@ impl Thread {
&self.id
}
/// Returns true if this thread was imported from a shared thread.
pub fn is_imported(&self) -> bool {
self.imported
}
pub fn replay(
&mut self,
cx: &mut Context<Self>,
@ -866,6 +874,7 @@ impl Thread {
prompt_capabilities_tx,
prompt_capabilities_rx,
file_read_times: HashMap::default(),
imported: db_thread.imported,
}
}
@ -885,6 +894,7 @@ impl Thread {
}),
completion_mode: Some(self.completion_mode),
profile: Some(self.profile_id.clone()),
imported: self.imported,
};
cx.background_spawn(async move {

View file

@ -5,7 +5,9 @@ use acp_thread::{
};
use acp_thread::{AgentConnection, Plan};
use action_log::{ActionLog, ActionLogTelemetry};
use agent::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer};
use agent::{
DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer, SharedThread,
};
use agent_client_protocol::{self as acp, PromptCapabilities};
use agent_servers::{AgentServer, AgentServerDelegate};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
@ -20,15 +22,16 @@ use editor::scroll::Autoscroll;
use editor::{
Editor, EditorEvent, EditorMode, MultiBuffer, PathKey, SelectionEffects, SizingBehavior,
};
use feature_flags::{AgentSharingFeatureFlag, FeatureFlagAppExt};
use file_icons::FileIcons;
use fs::Fs;
use futures::FutureExt as _;
use gpui::{
Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, CursorStyle,
EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset,
ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task, TextStyle,
TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div, ease_in_out,
linear_color_stop, linear_gradient, list, point, pulsating_between,
Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem,
CursorStyle, EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length,
ListOffset, ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task,
TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
ease_in_out, linear_color_stop, linear_gradient, list, point, pulsating_between,
};
use language::Buffer;
@ -52,7 +55,7 @@ use ui::{
WithScrollbar, prelude::*, right_click_menu,
};
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
use workspace::{CollaboratorId, NewTerminal, Workspace};
use workspace::{CollaboratorId, NewTerminal, Toast, Workspace, notifications::NotificationId};
use zed_actions::agent::{Chat, ToggleModelSelector};
use zed_actions::assistant::OpenRulesLibrary;
@ -935,6 +938,124 @@ impl AcpThreadView {
}
}
fn share_thread(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
let Some(thread) = self.as_native_thread(cx) else {
return;
};
let client = self.project.read(cx).client();
let workspace = self.workspace.clone();
let session_id = thread.read(cx).id().to_string();
let load_task = thread.read(cx).to_db(cx);
cx.spawn(async move |_this, cx| {
let db_thread = load_task.await;
let shared_thread = SharedThread::from_db_thread(&db_thread);
let thread_data = shared_thread.to_bytes()?;
let title = shared_thread.title.to_string();
client
.request(proto::ShareAgentThread {
session_id: session_id.clone(),
title,
thread_data,
})
.await?;
let share_url = client::zed_urls::shared_agent_thread_url(&session_id);
cx.update(|cx| {
if let Some(workspace) = workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
struct ThreadSharedToast;
workspace.show_toast(
Toast::new(
NotificationId::unique::<ThreadSharedToast>(),
"Thread shared!",
)
.on_click(
"Copy URL",
move |_window, cx| {
cx.write_to_clipboard(ClipboardItem::new_string(
share_url.clone(),
));
},
),
cx,
);
});
}
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
fn sync_thread(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if !self.is_imported_thread(cx) {
return;
}
let Some(thread) = self.as_native_thread(cx) else {
return;
};
let client = self.project.read(cx).client();
let history_store = self.history_store.clone();
let session_id = thread.read(cx).id().clone();
cx.spawn_in(window, async move |this, cx| {
let response = client
.request(proto::GetSharedAgentThread {
session_id: session_id.to_string(),
})
.await?;
let shared_thread = SharedThread::from_bytes(&response.thread_data)?;
let db_thread = shared_thread.to_db_thread();
history_store
.update(&mut cx.clone(), |store, cx| {
store.save_thread(session_id.clone(), db_thread, cx)
})?
.await?;
let thread_metadata = agent::DbThreadMetadata {
id: session_id,
title: format!("🔗 {}", response.title).into(),
updated_at: chrono::Utc::now(),
};
this.update_in(cx, |this, window, cx| {
this.resume_thread_metadata = Some(thread_metadata);
this.reset(window, cx);
})?;
this.update_in(cx, |this, _window, cx| {
if let Some(workspace) = this.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
struct ThreadSyncedToast;
workspace.show_toast(
Toast::new(
NotificationId::unique::<ThreadSyncedToast>(),
"Thread synced with latest version",
)
.autohide(),
cx,
);
});
}
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
pub fn expand_message_editor(
&mut self,
_: &ExpandMessageEditor,
@ -4904,6 +5025,13 @@ impl AcpThreadView {
.thread(acp_thread.session_id(), cx)
}
fn is_imported_thread(&self, cx: &App) -> bool {
let Some(thread) = self.as_native_thread(cx) else {
return false;
};
thread.read(cx).is_imported()
}
fn is_using_zed_ai_models(&self, cx: &App) -> bool {
self.as_native_thread(cx)
.and_then(|thread| thread.read(cx).model())
@ -5819,6 +5947,41 @@ impl AcpThreadView {
);
}
if cx.has_flag::<AgentSharingFeatureFlag>()
&& self.is_imported_thread(cx)
&& self
.project
.read(cx)
.client()
.status()
.borrow()
.is_connected()
{
let sync_button = IconButton::new("sync-thread", IconName::ArrowCircle)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text("Sync with source thread"))
.on_click(cx.listener(move |this, _, window, cx| {
this.sync_thread(window, cx);
}));
container = container.child(sync_button);
}
if cx.has_flag::<AgentSharingFeatureFlag>() && !self.is_imported_thread(cx) {
let share_button = IconButton::new("share-thread", IconName::ArrowUpRight)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text("Share Thread"))
.on_click(cx.listener(move |this, _, window, cx| {
this.share_thread(window, cx);
}));
container = container.child(share_button);
}
container
.child(open_as_markdown)
.child(scroll_to_recent_user_prompt)

View file

@ -720,10 +720,25 @@ impl AgentPanel {
&self.prompt_store
}
pub(crate) fn thread_store(&self) -> &Entity<HistoryStore> {
pub fn thread_store(&self) -> &Entity<HistoryStore> {
&self.history_store
}
pub fn open_thread(
&mut self,
thread: DbThreadMetadata,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.external_thread(
Some(crate::ExternalAgent::NativeAgent),
Some(thread),
None,
window,
cx,
);
}
pub(crate) fn context_server_registry(&self) -> &Entity<ContextServerRegistry> {
&self.context_server_registry
}

View file

@ -67,3 +67,7 @@ pub fn edit_prediction_docs(cx: &App) -> String {
server_url = server_url(cx)
)
}
pub fn shared_agent_thread_url(session_id: &str) -> String {
format!("zed://agent/shared/{}", session_id)
}

View file

@ -69,7 +69,10 @@ util.workspace = true
uuid.workspace = true
[dev-dependencies]
agent = { workspace = true, features = ["test-support"] }
agent-client-protocol.workspace = true
agent_settings.workspace = true
agent_ui = { workspace = true, features = ["test-support"] }
assistant_text_thread.workspace = true
assistant_slash_command.workspace = true
async-trait.workspace = true

View file

@ -460,3 +460,14 @@ CREATE TABLE IF NOT EXISTS "breakpoints" (
);
CREATE INDEX "index_breakpoints_on_project_id" ON "breakpoints" ("project_id");
CREATE TABLE IF NOT EXISTS "shared_threads" (
"id" TEXT PRIMARY KEY NOT NULL,
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"title" VARCHAR(512) NOT NULL,
"data" BLOB NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX "index_shared_threads_user_id" ON "shared_threads" ("user_id");

View file

@ -430,6 +430,15 @@ CREATE SEQUENCE public.servers_id_seq
ALTER SEQUENCE public.servers_id_seq OWNED BY public.servers.id;
CREATE TABLE public.shared_threads (
id uuid NOT NULL,
user_id integer NOT NULL,
title text NOT NULL,
data bytea NOT NULL,
created_at timestamp without time zone DEFAULT now() NOT NULL,
updated_at timestamp without time zone DEFAULT now() NOT NULL
);
CREATE TABLE public.user_features (
user_id integer NOT NULL,
feature_id integer NOT NULL
@ -630,6 +639,9 @@ ALTER TABLE ONLY public.rooms
ALTER TABLE ONLY public.servers
ADD CONSTRAINT servers_pkey PRIMARY KEY (id);
ALTER TABLE ONLY public.shared_threads
ADD CONSTRAINT shared_threads_pkey PRIMARY KEY (id);
ALTER TABLE ONLY public.user_features
ADD CONSTRAINT user_features_pkey PRIMARY KEY (user_id, feature_id);
@ -648,6 +660,8 @@ ALTER TABLE ONLY public.worktree_settings_files
ALTER TABLE ONLY public.worktrees
ADD CONSTRAINT worktrees_pkey PRIMARY KEY (project_id, id);
CREATE INDEX idx_shared_threads_user_id ON public.shared_threads USING btree (user_id);
CREATE INDEX index_access_tokens_user_id ON public.access_tokens USING btree (user_id);
CREATE INDEX index_breakpoints_on_project_id ON public.breakpoints USING btree (project_id);
@ -879,6 +893,9 @@ ALTER TABLE ONLY public.room_participants
ALTER TABLE ONLY public.rooms
ADD CONSTRAINT rooms_channel_id_fkey FOREIGN KEY (channel_id) REFERENCES public.channels(id) ON DELETE CASCADE;
ALTER TABLE ONLY public.shared_threads
ADD CONSTRAINT shared_threads_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id) ON DELETE CASCADE;
ALTER TABLE ONLY public.user_features
ADD CONSTRAINT user_features_feature_id_fkey FOREIGN KEY (feature_id) REFERENCES public.feature_flags(id) ON DELETE CASCADE;

View file

@ -2,6 +2,7 @@ use crate::Result;
use rpc::proto;
use sea_orm::{DbErr, entity::prelude::*};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[macro_export]
macro_rules! id_type {
@ -92,6 +93,39 @@ id_type!(ServerId);
id_type!(SignupId);
id_type!(UserId);
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, DeriveValueType)]
pub struct SharedThreadId(pub Uuid);
impl SharedThreadId {
pub fn from_proto(id: String) -> Option<Self> {
Uuid::parse_str(&id).ok().map(SharedThreadId)
}
pub fn to_proto(self) -> String {
self.0.to_string()
}
}
impl sea_orm::TryFromU64 for SharedThreadId {
fn try_from_u64(_n: u64) -> std::result::Result<Self, DbErr> {
Err(DbErr::ConvertFromU64(
"SharedThreadId uses UUID and cannot be converted from u64",
))
}
}
impl sea_orm::sea_query::Nullable for SharedThreadId {
fn null() -> Value {
Value::Uuid(None)
}
}
impl std::fmt::Display for SharedThreadId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.0.fmt(f)
}
}
/// ChannelRole gives you permissions for both channels and calls.
#[derive(
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,

View file

@ -10,4 +10,5 @@ pub mod notifications;
pub mod projects;
pub mod rooms;
pub mod servers;
pub mod shared_threads;
pub mod users;

View file

@ -0,0 +1,77 @@
use chrono::Utc;
use super::*;
use crate::db::tables::shared_thread;
impl Database {
pub async fn upsert_shared_thread(
&self,
id: SharedThreadId,
user_id: UserId,
title: &str,
data: Vec<u8>,
) -> Result<()> {
let title = title.to_string();
self.transaction(|tx| {
let title = title.clone();
let data = data.clone();
async move {
let now = Utc::now().naive_utc();
let existing = shared_thread::Entity::find_by_id(id).one(&*tx).await?;
match existing {
Some(existing) => {
if existing.user_id != user_id {
Err(anyhow!("Cannot update shared thread owned by another user"))?;
}
let mut active: shared_thread::ActiveModel = existing.into();
active.title = ActiveValue::Set(title);
active.data = ActiveValue::Set(data);
active.updated_at = ActiveValue::Set(now);
active.update(&*tx).await?;
}
None => {
shared_thread::ActiveModel {
id: ActiveValue::Set(id),
user_id: ActiveValue::Set(user_id),
title: ActiveValue::Set(title),
data: ActiveValue::Set(data),
created_at: ActiveValue::Set(now),
updated_at: ActiveValue::Set(now),
}
.insert(&*tx)
.await?;
}
}
Ok(())
}
})
.await
}
pub async fn get_shared_thread(
&self,
share_id: SharedThreadId,
) -> Result<Option<(shared_thread::Model, String)>> {
self.transaction(|tx| async move {
let Some(thread) = shared_thread::Entity::find_by_id(share_id)
.one(&*tx)
.await?
else {
return Ok(None);
};
let user = user::Entity::find_by_id(thread.user_id).one(&*tx).await?;
let username = user
.map(|u| u.github_login)
.unwrap_or_else(|| "Unknown".to_string());
Ok(Some((thread, username)))
})
.await
}
}

View file

@ -22,6 +22,7 @@ pub mod project_repository_statuses;
pub mod room;
pub mod room_participant;
pub mod server;
pub mod shared_thread;
pub mod user;
pub mod worktree;
pub mod worktree_diagnostic_summary;

View file

@ -0,0 +1,32 @@
use crate::db::{SharedThreadId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "shared_threads")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: SharedThreadId,
pub user_id: UserId,
pub title: String,
pub data: Vec<u8>,
pub created_at: DateTime,
pub updated_at: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -586,3 +586,121 @@ async fn test_fuzzy_search_users(cx: &mut gpui::TestAppContext) {
.collect::<Vec<_>>()
}
}
test_both_dbs!(
test_upsert_shared_thread,
test_upsert_shared_thread_postgres,
test_upsert_shared_thread_sqlite
);
async fn test_upsert_shared_thread(db: &Arc<Database>) {
use crate::db::SharedThreadId;
use uuid::Uuid;
let user_id = new_test_user(db, "user1@example.com").await;
let thread_id = SharedThreadId(Uuid::new_v4());
let title = "My Test Thread";
let data = b"test thread data".to_vec();
db.upsert_shared_thread(thread_id, user_id, title, data.clone())
.await
.unwrap();
let result = db.get_shared_thread(thread_id).await.unwrap();
assert!(result.is_some(), "Should find the shared thread");
let (thread, username) = result.unwrap();
assert_eq!(thread.title, title);
assert_eq!(thread.data, data);
assert_eq!(thread.user_id, user_id);
assert_eq!(username, "user1");
}
test_both_dbs!(
test_upsert_shared_thread_updates_existing,
test_upsert_shared_thread_updates_existing_postgres,
test_upsert_shared_thread_updates_existing_sqlite
);
async fn test_upsert_shared_thread_updates_existing(db: &Arc<Database>) {
use crate::db::SharedThreadId;
use uuid::Uuid;
let user_id = new_test_user(db, "user1@example.com").await;
let thread_id = SharedThreadId(Uuid::new_v4());
// Create initial thread.
db.upsert_shared_thread(
thread_id,
user_id,
"Original Title",
b"original data".to_vec(),
)
.await
.unwrap();
// Update the same thread.
db.upsert_shared_thread(
thread_id,
user_id,
"Updated Title",
b"updated data".to_vec(),
)
.await
.unwrap();
let result = db.get_shared_thread(thread_id).await.unwrap();
let (thread, _) = result.unwrap();
assert_eq!(thread.title, "Updated Title");
assert_eq!(thread.data, b"updated data".to_vec());
}
test_both_dbs!(
test_cannot_update_another_users_shared_thread,
test_cannot_update_another_users_shared_thread_postgres,
test_cannot_update_another_users_shared_thread_sqlite
);
async fn test_cannot_update_another_users_shared_thread(db: &Arc<Database>) {
use crate::db::SharedThreadId;
use uuid::Uuid;
let user1_id = new_test_user(db, "user1@example.com").await;
let user2_id = new_test_user(db, "user2@example.com").await;
let thread_id = SharedThreadId(Uuid::new_v4());
db.upsert_shared_thread(thread_id, user1_id, "User 1 Thread", b"user1 data".to_vec())
.await
.unwrap();
let result = db
.upsert_shared_thread(thread_id, user2_id, "User 2 Title", b"user2 data".to_vec())
.await;
assert!(
result.is_err(),
"Should not allow updating another user's thread"
);
}
test_both_dbs!(
test_get_nonexistent_shared_thread,
test_get_nonexistent_shared_thread_postgres,
test_get_nonexistent_shared_thread_sqlite
);
async fn test_get_nonexistent_shared_thread(db: &Arc<Database>) {
use crate::db::SharedThreadId;
use uuid::Uuid;
let result = db
.get_shared_thread(SharedThreadId(Uuid::new_v4()))
.await
.unwrap();
assert!(result.is_none(), "Should not find non-existent thread");
}

View file

@ -6,7 +6,8 @@ use crate::{
db::{
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
UserId,
},
executor::Executor,
};
@ -465,7 +466,9 @@ impl Server {
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
.add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
.add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
.add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
.add_request_handler(share_agent_thread)
.add_request_handler(get_shared_agent_thread);
Arc::new(server)
}
@ -4016,6 +4019,54 @@ fn project_left(project: &db::LeftProject, session: &Session) {
}
}
async fn share_agent_thread(
request: proto::ShareAgentThread,
response: Response<proto::ShareAgentThread>,
session: MessageContext,
) -> Result<()> {
let user_id = session.user_id();
let share_id = SharedThreadId::from_proto(request.session_id.clone())
.ok_or_else(|| anyhow!("Invalid session ID format"))?;
session
.db()
.await
.upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
.await?;
response.send(proto::Ack {})?;
Ok(())
}
async fn get_shared_agent_thread(
request: proto::GetSharedAgentThread,
response: Response<proto::GetSharedAgentThread>,
session: MessageContext,
) -> Result<()> {
let share_id = SharedThreadId::from_proto(request.session_id)
.ok_or_else(|| anyhow!("Invalid session ID format"))?;
let result = session.db().await.get_shared_thread(share_id).await?;
match result {
Some((thread, username)) => {
response.send(proto::GetSharedAgentThreadResponse {
title: thread.title,
thread_data: thread.data,
sharer_username: username,
created_at: thread.created_at.and_utc().to_rfc3339(),
})?;
}
None => {
return Err(anyhow!("Shared thread not found").into());
}
}
Ok(())
}
pub trait ResultExt {
type Ok;

View file

@ -2,6 +2,7 @@ use call::Room;
use client::ChannelId;
use gpui::{Entity, TestAppContext};
mod agent_sharing_tests;
mod channel_buffer_tests;
mod channel_guest_tests;
mod channel_tests;

View file

@ -0,0 +1,217 @@
use agent::SharedThread;
use gpui::{BackgroundExecutor, TestAppContext};
use rpc::proto;
use uuid::Uuid;
use crate::tests::TestServer;
#[gpui::test]
async fn test_share_and_retrieve_thread(
executor: BackgroundExecutor,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
executor.run_until_parked();
let session_id = Uuid::new_v4().to_string();
let original_thread = SharedThread {
title: "Shared Test Thread".into(),
messages: vec![],
updated_at: chrono::Utc::now(),
model: None,
completion_mode: None,
version: SharedThread::VERSION.to_string(),
};
let thread_data = original_thread
.to_bytes()
.expect("Failed to serialize thread");
client_a
.client()
.request(proto::ShareAgentThread {
session_id: session_id.clone(),
title: original_thread.title.to_string(),
thread_data,
})
.await
.expect("Failed to share thread");
let get_response = client_b
.client()
.request(proto::GetSharedAgentThread {
session_id: session_id.clone(),
})
.await
.expect("Failed to get shared thread");
let imported_shared_thread =
SharedThread::from_bytes(&get_response.thread_data).expect("Failed to deserialize thread");
assert_eq!(imported_shared_thread.title, original_thread.title);
assert_eq!(imported_shared_thread.version, SharedThread::VERSION);
let db_thread = imported_shared_thread.to_db_thread();
assert!(
db_thread.title.starts_with("🔗"),
"Imported thread title should have link prefix"
);
assert!(
db_thread.title.contains("Shared Test Thread"),
"Imported thread should preserve original title"
);
}
#[gpui::test]
async fn test_reshare_updates_existing_thread(
executor: BackgroundExecutor,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
executor.run_until_parked();
let session_id = Uuid::new_v4().to_string();
client_a
.client()
.request(proto::ShareAgentThread {
session_id: session_id.clone(),
title: "Original Title".to_string(),
thread_data: b"original data".to_vec(),
})
.await
.expect("Failed to share thread");
client_a
.client()
.request(proto::ShareAgentThread {
session_id: session_id.clone(),
title: "Updated Title".to_string(),
thread_data: b"updated data".to_vec(),
})
.await
.expect("Failed to re-share thread");
let get_response = client_b
.client()
.request(proto::GetSharedAgentThread {
session_id: session_id.clone(),
})
.await
.expect("Failed to get shared thread");
assert_eq!(get_response.title, "Updated Title");
assert_eq!(get_response.thread_data, b"updated data".to_vec());
}
#[gpui::test]
async fn test_get_nonexistent_thread(executor: BackgroundExecutor, cx: &mut TestAppContext) {
let mut server = TestServer::start(executor.clone()).await;
let client = server.create_client(cx, "user_a").await;
executor.run_until_parked();
let nonexistent_session_id = Uuid::new_v4().to_string();
let result = client
.client()
.request(proto::GetSharedAgentThread {
session_id: nonexistent_session_id,
})
.await;
assert!(result.is_err(), "Should fail for nonexistent thread");
}
#[gpui::test]
async fn test_sync_imported_thread(
executor: BackgroundExecutor,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
executor.run_until_parked();
let session_id = Uuid::new_v4().to_string();
// User A shares a thread with initial content.
let initial_thread = SharedThread {
title: "Initial Title".into(),
messages: vec![],
updated_at: chrono::Utc::now(),
model: None,
completion_mode: None,
version: SharedThread::VERSION.to_string(),
};
client_a
.client()
.request(proto::ShareAgentThread {
session_id: session_id.clone(),
title: initial_thread.title.to_string(),
thread_data: initial_thread.to_bytes().expect("Failed to serialize"),
})
.await
.expect("Failed to share thread");
// User B imports the thread.
let initial_response = client_b
.client()
.request(proto::GetSharedAgentThread {
session_id: session_id.clone(),
})
.await
.expect("Failed to get shared thread");
let initial_imported =
SharedThread::from_bytes(&initial_response.thread_data).expect("Failed to deserialize");
assert_eq!(initial_imported.title.as_ref(), "Initial Title");
// User A updates the shared thread.
let updated_thread = SharedThread {
title: "Updated Title".into(),
messages: vec![],
updated_at: chrono::Utc::now(),
model: None,
completion_mode: None,
version: SharedThread::VERSION.to_string(),
};
client_a
.client()
.request(proto::ShareAgentThread {
session_id: session_id.clone(),
title: updated_thread.title.to_string(),
thread_data: updated_thread.to_bytes().expect("Failed to serialize"),
})
.await
.expect("Failed to re-share thread");
// User B syncs the imported thread (fetches the latest version).
let synced_response = client_b
.client()
.request(proto::GetSharedAgentThread {
session_id: session_id.clone(),
})
.await
.expect("Failed to sync shared thread");
let synced_thread =
SharedThread::from_bytes(&synced_response.thread_data).expect("Failed to deserialize");
// The synced thread should have the updated title.
assert_eq!(synced_thread.title.as_ref(), "Updated Title");
}

View file

@ -29,3 +29,9 @@ pub struct AcpBetaFeatureFlag;
impl FeatureFlag for AcpBetaFeatureFlag {
const NAME: &'static str = "acp-beta";
}
pub struct AgentSharingFeatureFlag;
impl FeatureFlag for AgentSharingFeatureFlag {
const NAME: &'static str = "agent-sharing";
}

View file

@ -218,3 +218,20 @@ message NewExternalAgentVersionAvailable {
string name = 2;
string version = 3;
}
message ShareAgentThread {
string session_id = 1; // Client-generated UUID (acp::SessionId)
string title = 2;
bytes thread_data = 3;
}
message GetSharedAgentThread {
string session_id = 1; // UUID string
}
message GetSharedAgentThreadResponse {
string title = 1;
bytes thread_data = 2;
string sharer_username = 3;
string created_at = 4;
}

View file

@ -447,7 +447,11 @@ message Envelope {
GitRemoveRemote git_remove_remote = 403;
TrustWorktrees trust_worktrees = 404;
RestrictWorktrees restrict_worktrees = 405; // current max
RestrictWorktrees restrict_worktrees = 405;
ShareAgentThread share_agent_thread = 406;
GetSharedAgentThread get_shared_agent_thread = 407;
GetSharedAgentThreadResponse get_shared_agent_thread_response = 408; // current max
}
reserved 87 to 88;

View file

@ -342,7 +342,10 @@ messages!(
(RemoteStarted, Background),
(GitGetWorktrees, Background),
(GitWorktreesResponse, Background),
(GitCreateWorktree, Background)
(GitCreateWorktree, Background),
(ShareAgentThread, Foreground),
(GetSharedAgentThread, Foreground),
(GetSharedAgentThreadResponse, Foreground)
);
request_messages!(
@ -441,6 +444,8 @@ request_messages!(
(SendChannelMessage, SendChannelMessageResponse),
(SetChannelMemberRole, Ack),
(SetChannelVisibility, Ack),
(ShareAgentThread, Ack),
(GetSharedAgentThread, GetSharedAgentThreadResponse),
(ShareProject, ShareProjectResponse),
(SynchronizeBuffers, SynchronizeBuffersResponse),
(TaskContextForLocation, TaskContext),

View file

@ -50,6 +50,8 @@ required-features = ["visual-tests"]
[dependencies]
acp_tools.workspace = true
activity_indicator.workspace = true
agent.workspace = true
agent-client-protocol.workspace = true
agent_settings.workspace = true
agent_ui.workspace = true
agent_ui_v2.workspace = true

View file

@ -4,6 +4,8 @@
mod reliability;
mod zed;
use agent::{HistoryStore, SharedThread};
use agent_client_protocol;
use agent_ui::AgentPanel;
use anyhow::{Context as _, Error, Result};
use clap::Parser;
@ -33,6 +35,7 @@ use assets::Assets;
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use parking_lot::Mutex;
use project::{project_settings::ProjectSettings, trusted_worktrees};
use proto;
use recent_projects::{RemoteSettings, open_remote_project};
use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
use session::{AppSession, Session};
@ -837,6 +840,73 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut
})
.detach_and_log_err(cx);
}
OpenRequestKind::SharedAgentThread { session_id } => {
cx.spawn(async move |cx| {
let workspace =
workspace::get_any_active_workspace(app_state.clone(), cx.clone()).await?;
let (client, history_store) =
workspace.update(cx, |workspace, _window, cx| {
let client = workspace.project().read(cx).client();
let history_store: Option<gpui::Entity<HistoryStore>> = workspace
.panel::<AgentPanel>(cx)
.map(|panel| panel.read(cx).thread_store().clone());
(client, history_store)
})?;
let Some(history_store): Option<gpui::Entity<HistoryStore>> = history_store
else {
anyhow::bail!("Agent panel not available");
};
let response = client
.request(proto::GetSharedAgentThread {
session_id: session_id.clone(),
})
.await
.context("Failed to fetch shared thread")?;
let shared_thread = SharedThread::from_bytes(&response.thread_data)?;
let db_thread = shared_thread.to_db_thread();
let session_id = agent_client_protocol::SessionId::new(session_id);
history_store
.update(&mut cx.clone(), |store, cx| {
store.save_thread(session_id.clone(), db_thread, cx)
})?
.await?;
let thread_metadata = agent::DbThreadMetadata {
id: session_id,
title: format!("🔗 {}", response.title).into(),
updated_at: chrono::Utc::now(),
};
workspace.update(cx, |workspace, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| {
panel.open_thread(thread_metadata, window, cx);
});
panel.focus_handle(cx).focus(window, cx);
}
})?;
workspace.update(cx, |workspace, _window, cx| {
struct ImportedThreadToast;
workspace.show_toast(
Toast::new(
NotificationId::unique::<ImportedThreadToast>(),
format!("Imported shared thread from {}", response.sharer_username),
)
.autohide(),
cx,
);
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
OpenRequestKind::DockMenuAction { index } => {
cx.perform_dock_menu_action(index);
}

View file

@ -4,6 +4,7 @@ pub mod edit_prediction_registry;
pub(crate) mod mac_only_instance;
mod migrate;
mod open_listener;
mod open_url_modal;
mod quick_action_bar;
#[cfg(all(target_os = "macos", any(test, feature = "test-support")))]
pub mod visual_tests;
@ -141,6 +142,8 @@ actions!(
/// audio system (including yourself) on the current call in a tar file
/// in the current working directory.
CaptureRecentAudio,
/// Opens a prompt to enter a URL to open.
OpenUrlPrompt,
]
);
@ -823,6 +826,11 @@ fn register_actions(
..Default::default()
})
})
.register_action(|workspace, _: &OpenUrlPrompt, window, cx| {
workspace.toggle_modal(window, cx, |window, cx| {
open_url_modal::OpenUrlModal::new(window, cx)
});
})
.register_action(|workspace, action: &OpenBrowser, _window, cx| {
// Parse and validate the URL to ensure it's properly formatted
match url::Url::parse(&action.url) {

View file

@ -49,6 +49,9 @@ pub enum OpenRequestKind {
extension_id: String,
},
AgentPanel,
SharedAgentThread {
session_id: String,
},
DockMenuAction {
index: usize,
},
@ -107,6 +110,14 @@ impl OpenRequest {
});
} else if url == "zed://agent" {
this.kind = Some(OpenRequestKind::AgentPanel);
} else if let Some(session_id_str) = url.strip_prefix("zed://agent/shared/") {
if uuid::Uuid::parse_str(session_id_str).is_ok() {
this.kind = Some(OpenRequestKind::SharedAgentThread {
session_id: session_id_str.to_string(),
});
} else {
log::error!("Invalid session ID in URL: {}", session_id_str);
}
} else if let Some(schema_path) = url.strip_prefix("zed://schemas/") {
this.kind = Some(OpenRequestKind::BuiltinJsonSchema {
schema_path: schema_path.to_string(),

View file

@ -0,0 +1,116 @@
use editor::Editor;
use gpui::{AppContext as _, DismissEvent, Entity, EventEmitter, Focusable, ReadGlobal, Styled};
use ui::{
ActiveTheme, App, Color, Context, FluentBuilder, InteractiveElement, IntoElement, Label,
LabelCommon, LabelSize, ParentElement, Render, SharedString, StyledExt, Window, div, h_flex,
v_flex,
};
use workspace::ModalView;
use super::{OpenListener, RawOpenRequest};
pub struct OpenUrlModal {
editor: Entity<Editor>,
last_error: Option<SharedString>,
}
impl EventEmitter<DismissEvent> for OpenUrlModal {}
impl ModalView for OpenUrlModal {}
impl Focusable for OpenUrlModal {
fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
self.editor.focus_handle(cx)
}
}
impl OpenUrlModal {
pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
let editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_placeholder_text("zed://...", window, cx);
editor
});
Self {
editor,
last_error: None,
}
}
fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let url = self.editor.update(cx, |editor, cx| {
let text = editor.text(cx).trim().to_string();
editor.clear(window, cx);
text
});
if url.is_empty() {
cx.emit(DismissEvent);
return;
}
// Handle zed:// URLs internally.
if url.starts_with("zed://") || url.starts_with("zed-cli://") {
OpenListener::global(cx).open(RawOpenRequest {
urls: vec![url],
..Default::default()
});
cx.emit(DismissEvent);
return;
}
match url::Url::parse(&url) {
Ok(parsed_url) => {
cx.open_url(parsed_url.as_str());
cx.emit(DismissEvent);
}
Err(e) => {
self.last_error = Some(format!("Invalid URL: {}", e).into());
cx.notify();
}
}
}
}
impl Render for OpenUrlModal {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let theme = cx.theme();
v_flex()
.key_context("OpenUrlModal")
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::confirm))
.elevation_3(cx)
.w_96()
.overflow_hidden()
.child(
div()
.p_2()
.border_b_1()
.border_color(theme.colors().border_variant)
.child(self.editor.clone()),
)
.child(
h_flex()
.bg(theme.colors().editor_background)
.rounded_b_sm()
.w_full()
.p_2()
.gap_1()
.when_some(self.last_error.clone(), |this, error| {
this.child(Label::new(error).size(LabelSize::Small).color(Color::Error))
})
.when(self.last_error.is_none(), |this| {
this.child(
Label::new("Paste a URL to open.")
.color(Color::Muted)
.size(LabelSize::Small),
)
}),
)
}
}