mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
1072 lines
37 KiB
Rust
1072 lines
37 KiB
Rust
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
|
use acp_thread::UserMessageId;
|
|
use agent_client_protocol::schema as acp;
|
|
use agent_settings::AgentProfileId;
|
|
use anyhow::{Result, anyhow};
|
|
use chrono::{DateTime, Utc};
|
|
use collections::{HashMap, IndexMap};
|
|
use futures::{FutureExt, future::Shared};
|
|
use gpui::{BackgroundExecutor, Global, Task};
|
|
use indoc::indoc;
|
|
use language_model::Speed;
|
|
use parking_lot::Mutex;
|
|
use serde::{Deserialize, Serialize};
|
|
use sqlez::{
|
|
bindable::{Bind, Column},
|
|
connection::Connection,
|
|
statement::Statement,
|
|
};
|
|
use std::{io::ErrorKind, path::PathBuf, sync::Arc};
|
|
use ui::{App, SharedString};
|
|
use util::path_list::PathList;
|
|
use zed_env_vars::ZED_STATELESS;
|
|
|
|
pub type DbMessage = crate::Message;
|
|
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
|
|
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct DbThreadMetadata {
|
|
pub id: acp::SessionId,
|
|
pub parent_session_id: Option<acp::SessionId>,
|
|
pub title: SharedString,
|
|
pub updated_at: DateTime<Utc>,
|
|
pub created_at: Option<DateTime<Utc>>,
|
|
/// The workspace folder paths this thread was created against, sorted
|
|
/// lexicographically. Used for grouping threads by project in the sidebar.
|
|
pub folder_paths: PathList,
|
|
}
|
|
|
|
impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
|
|
fn from(meta: &DbThreadMetadata) -> Self {
|
|
Self {
|
|
session_id: meta.id.clone(),
|
|
work_dirs: Some(meta.folder_paths.clone()),
|
|
title: Some(meta.title.clone()),
|
|
updated_at: Some(meta.updated_at),
|
|
created_at: meta.created_at,
|
|
meta: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct DbThread {
|
|
pub title: SharedString,
|
|
pub messages: Vec<Arc<DbMessage>>,
|
|
pub updated_at: DateTime<Utc>,
|
|
#[serde(default)]
|
|
pub detailed_summary: Option<SharedString>,
|
|
#[serde(default)]
|
|
pub initial_project_snapshot: Option<Arc<crate::ProjectSnapshot>>,
|
|
#[serde(default)]
|
|
pub cumulative_token_usage: language_model::TokenUsage,
|
|
#[serde(default)]
|
|
pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
|
|
#[serde(default)]
|
|
pub model: Option<DbLanguageModel>,
|
|
#[serde(default)]
|
|
pub profile: Option<AgentProfileId>,
|
|
#[serde(default)]
|
|
pub subagent_context: Option<crate::SubagentContext>,
|
|
#[serde(default)]
|
|
pub speed: Option<Speed>,
|
|
#[serde(default)]
|
|
pub thinking_enabled: bool,
|
|
#[serde(default)]
|
|
pub thinking_effort: Option<String>,
|
|
#[serde(default)]
|
|
pub draft_prompt: Option<Vec<acp::ContentBlock>>,
|
|
#[serde(default)]
|
|
pub ui_scroll_position: Option<SerializedScrollPosition>,
|
|
#[serde(default)]
|
|
pub sandboxed_terminal_temp_dir: Option<PathBuf>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
|
pub struct SerializedScrollPosition {
|
|
pub item_ix: usize,
|
|
pub offset_in_item: f32,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SharedThread {
|
|
pub title: SharedString,
|
|
pub messages: Vec<Arc<DbMessage>>,
|
|
pub updated_at: DateTime<Utc>,
|
|
#[serde(default)]
|
|
pub model: Option<DbLanguageModel>,
|
|
pub version: String,
|
|
}
|
|
|
|
impl SharedThread {
|
|
pub const VERSION: &'static str = "1.0.0";
|
|
|
|
pub fn from_db_thread(thread: &DbThread) -> Self {
|
|
Self {
|
|
title: thread.title.clone(),
|
|
messages: thread.messages.clone(),
|
|
updated_at: thread.updated_at,
|
|
model: thread.model.clone(),
|
|
version: Self::VERSION.to_string(),
|
|
}
|
|
}
|
|
|
|
pub fn to_db_thread(self) -> DbThread {
|
|
DbThread {
|
|
title: format!("🔗 {}", self.title).into(),
|
|
messages: self.messages,
|
|
updated_at: self.updated_at,
|
|
detailed_summary: None,
|
|
initial_project_snapshot: None,
|
|
cumulative_token_usage: Default::default(),
|
|
request_token_usage: Default::default(),
|
|
model: self.model,
|
|
profile: None,
|
|
subagent_context: None,
|
|
speed: None,
|
|
thinking_enabled: false,
|
|
thinking_effort: None,
|
|
draft_prompt: None,
|
|
ui_scroll_position: None,
|
|
sandboxed_terminal_temp_dir: None,
|
|
}
|
|
}
|
|
|
|
pub fn to_bytes(&self) -> Result<Vec<u8>> {
|
|
const COMPRESSION_LEVEL: i32 = 3;
|
|
let json = serde_json::to_vec(self)?;
|
|
let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
|
|
Ok(compressed)
|
|
}
|
|
|
|
pub fn from_bytes(data: &[u8]) -> Result<Self> {
|
|
let decompressed = zstd::decode_all(data)?;
|
|
Ok(serde_json::from_slice(&decompressed)?)
|
|
}
|
|
}
|
|
|
|
impl DbThread {
|
|
pub const VERSION: &'static str = "0.3.0";
|
|
|
|
pub fn from_json(json: &[u8]) -> Result<Self> {
|
|
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
|
match saved_thread_json.get("version") {
|
|
Some(serde_json::Value::String(version)) => match version.as_str() {
|
|
Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
|
|
_ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(
|
|
json,
|
|
)?),
|
|
},
|
|
_ => {
|
|
Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result<Self> {
|
|
let mut messages = Vec::new();
|
|
let mut request_token_usage = HashMap::default();
|
|
|
|
let mut last_user_message_id = None;
|
|
for (ix, msg) in thread.messages.into_iter().enumerate() {
|
|
let message = match msg.role {
|
|
language_model::Role::User => {
|
|
let mut content = Vec::new();
|
|
|
|
// Convert segments to content
|
|
for segment in msg.segments {
|
|
match segment {
|
|
crate::legacy_thread::SerializedMessageSegment::Text { text } => {
|
|
content.push(UserMessageContent::Text(text));
|
|
}
|
|
crate::legacy_thread::SerializedMessageSegment::Thinking {
|
|
text,
|
|
..
|
|
} => {
|
|
// User messages don't have thinking segments, but handle gracefully
|
|
content.push(UserMessageContent::Text(text));
|
|
}
|
|
crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
|
|
..
|
|
} => {
|
|
// User messages don't have redacted thinking, skip.
|
|
}
|
|
}
|
|
}
|
|
|
|
// If no content was added, add context as text if available
|
|
if content.is_empty() && !msg.context.is_empty() {
|
|
content.push(UserMessageContent::Text(msg.context));
|
|
}
|
|
|
|
let id = UserMessageId::new();
|
|
last_user_message_id = Some(id.clone());
|
|
|
|
crate::Message::User(UserMessage {
|
|
// MessageId from old format can't be meaningfully converted, so generate a new one
|
|
id,
|
|
content: Arc::from(content),
|
|
})
|
|
}
|
|
language_model::Role::Assistant => {
|
|
let mut content = Vec::new();
|
|
|
|
// Convert segments to content
|
|
for segment in msg.segments {
|
|
match segment {
|
|
crate::legacy_thread::SerializedMessageSegment::Text { text } => {
|
|
content.push(AgentMessageContent::Text(text));
|
|
}
|
|
crate::legacy_thread::SerializedMessageSegment::Thinking {
|
|
text,
|
|
signature,
|
|
} => {
|
|
content.push(AgentMessageContent::Thinking { text, signature });
|
|
}
|
|
crate::legacy_thread::SerializedMessageSegment::RedactedThinking {
|
|
data,
|
|
} => {
|
|
content.push(AgentMessageContent::RedactedThinking(data));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert tool uses
|
|
let mut tool_names_by_id = HashMap::default();
|
|
for tool_use in msg.tool_uses {
|
|
tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
|
|
content.push(AgentMessageContent::ToolUse(
|
|
language_model::LanguageModelToolUse {
|
|
id: tool_use.id,
|
|
name: tool_use.name.into(),
|
|
raw_input: serde_json::to_string(&tool_use.input)
|
|
.unwrap_or_default(),
|
|
input: tool_use.input,
|
|
is_input_complete: true,
|
|
thought_signature: None,
|
|
},
|
|
));
|
|
}
|
|
|
|
// Convert tool results
|
|
let mut tool_results = IndexMap::default();
|
|
for tool_result in msg.tool_results {
|
|
let name = tool_names_by_id
|
|
.remove(&tool_result.tool_use_id)
|
|
.unwrap_or_else(|| SharedString::from("unknown"));
|
|
tool_results.insert(
|
|
tool_result.tool_use_id.clone(),
|
|
language_model::LanguageModelToolResult {
|
|
tool_use_id: tool_result.tool_use_id,
|
|
tool_name: name.into(),
|
|
is_error: tool_result.is_error,
|
|
content: vec![tool_result.content],
|
|
output: tool_result.output,
|
|
},
|
|
);
|
|
}
|
|
|
|
if let Some(last_user_message_id) = &last_user_message_id
|
|
&& let Some(token_usage) = thread.request_token_usage.get(ix).copied()
|
|
{
|
|
request_token_usage.insert(last_user_message_id.clone(), token_usage);
|
|
}
|
|
|
|
crate::Message::Agent(AgentMessage {
|
|
content,
|
|
tool_results,
|
|
reasoning_details: None,
|
|
})
|
|
}
|
|
language_model::Role::System => {
|
|
// Skip system messages as they're not supported in the new format
|
|
continue;
|
|
}
|
|
};
|
|
|
|
messages.push(Arc::new(message));
|
|
}
|
|
|
|
Ok(Self {
|
|
title: thread.summary,
|
|
messages,
|
|
updated_at: thread.updated_at,
|
|
detailed_summary: match thread.detailed_summary_state {
|
|
crate::legacy_thread::DetailedSummaryState::NotGenerated
|
|
| crate::legacy_thread::DetailedSummaryState::Generating => None,
|
|
crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text),
|
|
},
|
|
initial_project_snapshot: thread.initial_project_snapshot,
|
|
cumulative_token_usage: thread.cumulative_token_usage,
|
|
request_token_usage,
|
|
model: thread.model,
|
|
profile: thread.profile,
|
|
subagent_context: None,
|
|
speed: None,
|
|
thinking_enabled: false,
|
|
thinking_effort: None,
|
|
draft_prompt: None,
|
|
ui_scroll_position: None,
|
|
sandboxed_terminal_temp_dir: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub enum DataType {
|
|
#[serde(rename = "json")]
|
|
Json,
|
|
#[serde(rename = "zstd")]
|
|
Zstd,
|
|
}
|
|
|
|
impl Bind for DataType {
|
|
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
|
let value = match self {
|
|
DataType::Json => "json",
|
|
DataType::Zstd => "zstd",
|
|
};
|
|
value.bind(statement, start_index)
|
|
}
|
|
}
|
|
|
|
impl Column for DataType {
|
|
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
|
let (value, next_index) = String::column(statement, start_index)?;
|
|
let data_type = match value.as_str() {
|
|
"json" => DataType::Json,
|
|
"zstd" => DataType::Zstd,
|
|
_ => anyhow::bail!("Unknown data type: {}", value),
|
|
};
|
|
Ok((data_type, next_index))
|
|
}
|
|
}
|
|
|
|
pub(crate) struct ThreadsDatabase {
|
|
executor: BackgroundExecutor,
|
|
connection: Arc<Mutex<Connection>>,
|
|
}
|
|
|
|
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
|
|
|
|
impl Global for GlobalThreadsDatabase {}
|
|
|
|
impl ThreadsDatabase {
|
|
pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
|
|
if cx.has_global::<GlobalThreadsDatabase>() {
|
|
return cx.global::<GlobalThreadsDatabase>().0.clone();
|
|
}
|
|
let executor = cx.background_executor().clone();
|
|
let task = executor
|
|
.spawn({
|
|
let executor = executor.clone();
|
|
async move {
|
|
match ThreadsDatabase::new(executor) {
|
|
Ok(db) => Ok(Arc::new(db)),
|
|
Err(err) => Err(Arc::new(err)),
|
|
}
|
|
}
|
|
})
|
|
.shared();
|
|
|
|
cx.set_global(GlobalThreadsDatabase(task.clone()));
|
|
task
|
|
}
|
|
|
|
pub fn new(executor: BackgroundExecutor) -> Result<Self> {
|
|
let connection = if *ZED_STATELESS {
|
|
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
|
} else if cfg!(any(feature = "test-support", test)) {
|
|
// rust stores the name of the test on the current thread.
|
|
// We use this to automatically create a database that will
|
|
// be shared within the test (for the test_retrieve_old_thread)
|
|
// but not with concurrent tests.
|
|
let thread = std::thread::current();
|
|
let test_name = thread.name();
|
|
Connection::open_memory(Some(&format!(
|
|
"THREAD_FALLBACK_{}",
|
|
test_name.unwrap_or_default()
|
|
)))
|
|
} else {
|
|
let threads_dir = paths::data_dir().join("threads");
|
|
std::fs::create_dir_all(&threads_dir)?;
|
|
let sqlite_path = threads_dir.join("threads.db");
|
|
Connection::open_file(&sqlite_path.to_string_lossy())
|
|
};
|
|
|
|
connection.exec(indoc! {"
|
|
CREATE TABLE IF NOT EXISTS threads (
|
|
id TEXT PRIMARY KEY,
|
|
summary TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL,
|
|
data_type TEXT NOT NULL,
|
|
data BLOB NOT NULL
|
|
)
|
|
"})?()
|
|
.map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
|
|
|
|
if let Ok(mut s) = connection.exec(indoc! {"
|
|
ALTER TABLE threads ADD COLUMN parent_id TEXT
|
|
"})
|
|
{
|
|
s().ok();
|
|
}
|
|
|
|
if let Ok(mut s) = connection.exec(indoc! {"
|
|
ALTER TABLE threads ADD COLUMN folder_paths TEXT;
|
|
ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
|
|
"})
|
|
{
|
|
s().ok();
|
|
}
|
|
|
|
if let Ok(mut s) = connection.exec(indoc! {"
|
|
ALTER TABLE threads ADD COLUMN created_at TEXT;
|
|
"})
|
|
{
|
|
if s().is_ok() {
|
|
connection.exec(indoc! {"
|
|
UPDATE threads SET created_at = updated_at WHERE created_at IS NULL
|
|
"})?()?;
|
|
}
|
|
}
|
|
|
|
let db = Self {
|
|
executor,
|
|
connection: Arc::new(Mutex::new(connection)),
|
|
};
|
|
|
|
Ok(db)
|
|
}
|
|
|
|
fn save_thread_sync(
|
|
connection: &Arc<Mutex<Connection>>,
|
|
id: acp::SessionId,
|
|
thread: DbThread,
|
|
folder_paths: &PathList,
|
|
) -> Result<()> {
|
|
const COMPRESSION_LEVEL: i32 = 3;
|
|
|
|
#[derive(Serialize)]
|
|
struct SerializedThread {
|
|
#[serde(flatten)]
|
|
thread: DbThread,
|
|
version: &'static str,
|
|
}
|
|
|
|
let title = thread.title.to_string();
|
|
let updated_at = thread.updated_at.to_rfc3339();
|
|
let parent_id = thread
|
|
.subagent_context
|
|
.as_ref()
|
|
.map(|ctx| ctx.parent_thread_id.0.clone());
|
|
let serialized_folder_paths = folder_paths.serialize();
|
|
let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
|
|
if folder_paths.is_empty() {
|
|
(None, None)
|
|
} else {
|
|
(
|
|
Some(serialized_folder_paths.paths),
|
|
Some(serialized_folder_paths.order),
|
|
)
|
|
};
|
|
let json_data = serde_json::to_string(&SerializedThread {
|
|
thread,
|
|
version: DbThread::VERSION,
|
|
})?;
|
|
|
|
let connection = connection.lock();
|
|
|
|
let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
|
|
let data_type = DataType::Zstd;
|
|
let data = compressed;
|
|
|
|
// Use the thread's updated_at as created_at for new threads.
|
|
// This ensures the creation time reflects when the thread was conceptually
|
|
// created, not when it was saved to the database.
|
|
let created_at = updated_at.clone();
|
|
|
|
let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
|
|
INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
parent_id = excluded.parent_id,
|
|
folder_paths = excluded.folder_paths,
|
|
folder_paths_order = excluded.folder_paths_order,
|
|
summary = excluded.summary,
|
|
updated_at = excluded.updated_at,
|
|
data_type = excluded.data_type,
|
|
data = excluded.data
|
|
"})?;
|
|
|
|
insert((
|
|
id.0,
|
|
parent_id,
|
|
folder_paths_str,
|
|
folder_paths_order_str,
|
|
title,
|
|
updated_at,
|
|
data_type,
|
|
data,
|
|
created_at,
|
|
))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
|
|
let connection = self.connection.clone();
|
|
|
|
self.executor.spawn(async move {
|
|
let connection = connection.lock();
|
|
|
|
let mut select = connection
|
|
.select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, Option<String>)>(indoc! {"
|
|
SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at FROM threads ORDER BY updated_at DESC, created_at DESC
|
|
"})?;
|
|
|
|
let rows = select(())?;
|
|
let mut threads = Vec::new();
|
|
|
|
for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, created_at) in rows {
|
|
let folder_paths = folder_paths
|
|
.map(|paths| {
|
|
PathList::deserialize(&util::path_list::SerializedPathList {
|
|
paths,
|
|
order: folder_paths_order.unwrap_or_default(),
|
|
})
|
|
})
|
|
.unwrap_or_default();
|
|
let created_at = created_at
|
|
.as_deref()
|
|
.map(DateTime::parse_from_rfc3339)
|
|
.transpose()?
|
|
.map(|dt| dt.with_timezone(&Utc));
|
|
|
|
threads.push(DbThreadMetadata {
|
|
id: acp::SessionId::new(id),
|
|
parent_session_id: parent_id.map(acp::SessionId::new),
|
|
title: summary.into(),
|
|
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
|
created_at,
|
|
folder_paths,
|
|
});
|
|
}
|
|
|
|
Ok(threads)
|
|
})
|
|
}
|
|
|
|
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
|
|
let connection = self.connection.clone();
|
|
|
|
self.executor.spawn(async move {
|
|
let connection = connection.lock();
|
|
let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
|
|
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
|
|
"})?;
|
|
|
|
let rows = select(id.0)?;
|
|
if let Some((data_type, data)) = rows.into_iter().next() {
|
|
Ok(Some(Self::deserialize_thread(data_type, data)?))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
})
|
|
}
|
|
|
|
pub fn save_thread(
|
|
&self,
|
|
id: acp::SessionId,
|
|
thread: DbThread,
|
|
folder_paths: PathList,
|
|
) -> Task<Result<()>> {
|
|
let connection = self.connection.clone();
|
|
|
|
self.executor
|
|
.spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
|
|
}
|
|
|
|
fn deserialize_thread(data_type: DataType, data: Vec<u8>) -> Result<DbThread> {
|
|
let json_data = match data_type {
|
|
DataType::Zstd => {
|
|
let decompressed = zstd::decode_all(&data[..])?;
|
|
String::from_utf8(decompressed)?
|
|
}
|
|
DataType::Json => String::from_utf8(data)?,
|
|
};
|
|
DbThread::from_json(json_data.as_bytes())
|
|
}
|
|
|
|
fn sandboxed_terminal_temp_dir(data_type: DataType, data: Vec<u8>) -> Option<PathBuf> {
|
|
match Self::deserialize_thread(data_type, data) {
|
|
Ok(thread) => thread.sandboxed_terminal_temp_dir,
|
|
Err(error) => {
|
|
log::warn!("failed to deserialize thread before deleting it: {error:#}");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
fn remove_sandboxed_terminal_temp_dir(temp_dir: PathBuf) {
|
|
match std::fs::remove_dir_all(&temp_dir) {
|
|
Ok(()) => {}
|
|
Err(error) if error.kind() == ErrorKind::NotFound => {}
|
|
Err(error) => {
|
|
log::warn!(
|
|
"failed to remove sandboxed terminal temp directory {}: {error}",
|
|
temp_dir.display()
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
|
|
let connection = self.connection.clone();
|
|
|
|
self.executor.spawn(async move {
|
|
let sandboxed_terminal_temp_dir = {
|
|
let connection = connection.lock();
|
|
|
|
let mut select =
|
|
connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
|
|
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
|
|
"})?;
|
|
|
|
let sandboxed_terminal_temp_dir = select(id.0.clone())?
|
|
.into_iter()
|
|
.next()
|
|
.and_then(|(data_type, data)| {
|
|
Self::sandboxed_terminal_temp_dir(data_type, data)
|
|
});
|
|
|
|
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
|
|
DELETE FROM threads WHERE id = ?
|
|
"})?;
|
|
|
|
delete(id.0)?;
|
|
|
|
sandboxed_terminal_temp_dir
|
|
};
|
|
|
|
if let Some(temp_dir) = sandboxed_terminal_temp_dir {
|
|
Self::remove_sandboxed_terminal_temp_dir(temp_dir);
|
|
}
|
|
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn delete_threads(&self) -> Task<Result<()>> {
|
|
let connection = self.connection.clone();
|
|
|
|
self.executor.spawn(async move {
|
|
let sandboxed_terminal_temp_dirs = {
|
|
let connection = connection.lock();
|
|
|
|
let mut select = connection.select_bound::<(), (DataType, Vec<u8>)>(indoc! {"
|
|
SELECT data_type, data FROM threads
|
|
"})?;
|
|
|
|
let sandboxed_terminal_temp_dirs = select(())?
|
|
.into_iter()
|
|
.filter_map(|(data_type, data)| {
|
|
Self::sandboxed_terminal_temp_dir(data_type, data)
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let mut delete = connection.exec_bound::<()>(indoc! {"
|
|
DELETE FROM threads
|
|
"})?;
|
|
|
|
delete(())?;
|
|
|
|
sandboxed_terminal_temp_dirs
|
|
};
|
|
|
|
for temp_dir in sandboxed_terminal_temp_dirs {
|
|
Self::remove_sandboxed_terminal_temp_dir(temp_dir);
|
|
}
|
|
|
|
Ok(())
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use chrono::{DateTime, TimeZone, Utc};
|
|
use collections::HashMap;
|
|
use gpui::TestAppContext;
|
|
use std::sync::Arc;
|
|
|
|
#[test]
|
|
fn test_shared_thread_roundtrip() {
|
|
let original = SharedThread {
|
|
title: "Test Thread".into(),
|
|
messages: vec![],
|
|
updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
model: None,
|
|
version: SharedThread::VERSION.to_string(),
|
|
};
|
|
|
|
let bytes = original.to_bytes().expect("Failed to serialize");
|
|
let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
|
|
|
|
assert_eq!(restored.title, original.title);
|
|
assert_eq!(restored.version, original.version);
|
|
assert_eq!(restored.updated_at, original.updated_at);
|
|
}
|
|
|
|
fn session_id(value: &str) -> acp::SessionId {
|
|
acp::SessionId::new(Arc::<str>::from(value))
|
|
}
|
|
|
|
fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
|
|
DbThread {
|
|
title: title.to_string().into(),
|
|
messages: Vec::new(),
|
|
updated_at,
|
|
detailed_summary: None,
|
|
initial_project_snapshot: None,
|
|
cumulative_token_usage: Default::default(),
|
|
request_token_usage: HashMap::default(),
|
|
model: None,
|
|
profile: None,
|
|
subagent_context: None,
|
|
speed: None,
|
|
thinking_enabled: false,
|
|
thinking_effort: None,
|
|
draft_prompt: None,
|
|
ui_scroll_position: None,
|
|
sandboxed_terminal_temp_dir: None,
|
|
}
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_list_threads_orders_by_created_at(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
|
|
let older_id = session_id("thread-a");
|
|
let newer_id = session_id("thread-b");
|
|
|
|
let older_thread = make_thread(
|
|
"Thread A",
|
|
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
);
|
|
let newer_thread = make_thread(
|
|
"Thread B",
|
|
Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
|
|
);
|
|
|
|
database
|
|
.save_thread(older_id.clone(), older_thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
database
|
|
.save_thread(newer_id.clone(), newer_thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
|
|
let entries = database.list_threads().await.unwrap();
|
|
assert_eq!(entries.len(), 2);
|
|
assert_eq!(entries[0].id, newer_id);
|
|
assert_eq!(entries[1].id, older_id);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_save_thread_replaces_metadata(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
|
|
let thread_id = session_id("thread-a");
|
|
let original_thread = make_thread(
|
|
"Thread A",
|
|
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
);
|
|
let updated_thread = make_thread(
|
|
"Thread B",
|
|
Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
|
|
);
|
|
|
|
database
|
|
.save_thread(thread_id.clone(), original_thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
database
|
|
.save_thread(thread_id.clone(), updated_thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
|
|
let entries = database.list_threads().await.unwrap();
|
|
assert_eq!(entries.len(), 1);
|
|
assert_eq!(entries[0].id, thread_id);
|
|
assert_eq!(entries[0].title.as_ref(), "Thread B");
|
|
assert_eq!(
|
|
entries[0].updated_at,
|
|
Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap()
|
|
);
|
|
assert!(
|
|
entries[0].created_at.is_some(),
|
|
"created_at should be populated"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_subagent_context_defaults_to_none() {
|
|
let json = r#"{
|
|
"title": "Old Thread",
|
|
"messages": [],
|
|
"updated_at": "2024-01-01T00:00:00Z"
|
|
}"#;
|
|
|
|
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
|
|
|
|
assert!(
|
|
db_thread.subagent_context.is_none(),
|
|
"Legacy threads without subagent_context should default to None"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_draft_prompt_defaults_to_none() {
|
|
let json = r#"{
|
|
"title": "Old Thread",
|
|
"messages": [],
|
|
"updated_at": "2024-01-01T00:00:00Z"
|
|
}"#;
|
|
|
|
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
|
|
|
|
assert!(
|
|
db_thread.draft_prompt.is_none(),
|
|
"Legacy threads without draft_prompt field should default to None"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_sandboxed_terminal_temp_dir_defaults_to_none() {
|
|
let json = r#"{
|
|
"title": "Old Thread",
|
|
"messages": [],
|
|
"updated_at": "2024-01-01T00:00:00Z"
|
|
}"#;
|
|
|
|
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
|
|
|
|
assert!(
|
|
db_thread.sandboxed_terminal_temp_dir.is_none(),
|
|
"Legacy threads without sandboxed_terminal_temp_dir should default to None"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_sandboxed_terminal_temp_dir_roundtrips_through_save_load(
|
|
cx: &mut TestAppContext,
|
|
) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
let thread_id = session_id("sandbox-temp-dir-thread");
|
|
let temp_dir = tempfile::Builder::new()
|
|
.prefix("zed-agent-terminal-test-")
|
|
.tempdir()
|
|
.unwrap()
|
|
.keep();
|
|
let mut thread = make_thread(
|
|
"Sandbox Temp Dir Thread",
|
|
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
);
|
|
thread.sandboxed_terminal_temp_dir = Some(temp_dir.clone());
|
|
|
|
database
|
|
.save_thread(thread_id.clone(), thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
|
|
let loaded = database
|
|
.load_thread(thread_id)
|
|
.await
|
|
.unwrap()
|
|
.expect("thread should exist");
|
|
assert_eq!(loaded.sandboxed_terminal_temp_dir, Some(temp_dir.clone()));
|
|
std::fs::remove_dir_all(temp_dir).unwrap();
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_delete_thread_removes_sandboxed_terminal_temp_dir(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
let thread_id = session_id("sandbox-temp-dir-delete-thread");
|
|
let temp_dir = tempfile::Builder::new()
|
|
.prefix("zed-agent-terminal-test-")
|
|
.tempdir()
|
|
.unwrap()
|
|
.keep();
|
|
std::fs::write(temp_dir.join("sentinel"), b"content").unwrap();
|
|
let mut thread = make_thread(
|
|
"Sandbox Temp Dir Delete Thread",
|
|
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
);
|
|
thread.sandboxed_terminal_temp_dir = Some(temp_dir.clone());
|
|
|
|
database
|
|
.save_thread(thread_id.clone(), thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
database.delete_thread(thread_id).await.unwrap();
|
|
|
|
assert!(!temp_dir.exists());
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subagent_context_roundtrips_through_save_load(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
|
|
let parent_id = session_id("parent-thread");
|
|
let child_id = session_id("child-thread");
|
|
|
|
let mut child_thread = make_thread(
|
|
"Subagent Thread",
|
|
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
);
|
|
child_thread.subagent_context = Some(crate::SubagentContext {
|
|
parent_thread_id: parent_id.clone(),
|
|
depth: 2,
|
|
});
|
|
|
|
database
|
|
.save_thread(child_id.clone(), child_thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
|
|
let loaded = database
|
|
.load_thread(child_id)
|
|
.await
|
|
.unwrap()
|
|
.expect("thread should exist");
|
|
|
|
let context = loaded
|
|
.subagent_context
|
|
.expect("subagent_context should be restored");
|
|
assert_eq!(context.parent_thread_id, parent_id);
|
|
assert_eq!(context.depth, 2);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_non_subagent_thread_has_no_subagent_context(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
|
|
let thread_id = session_id("regular-thread");
|
|
let thread = make_thread(
|
|
"Regular Thread",
|
|
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
);
|
|
|
|
database
|
|
.save_thread(thread_id.clone(), thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
|
|
let loaded = database
|
|
.load_thread(thread_id)
|
|
.await
|
|
.unwrap()
|
|
.expect("thread should exist");
|
|
|
|
assert!(
|
|
loaded.subagent_context.is_none(),
|
|
"Regular threads should have no subagent_context"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
|
|
let thread_id = session_id("folder-thread");
|
|
let thread = make_thread(
|
|
"Folder Thread",
|
|
Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
|
|
);
|
|
|
|
let folder_paths = PathList::new(&[
|
|
std::path::PathBuf::from("/home/user/project-a"),
|
|
std::path::PathBuf::from("/home/user/project-b"),
|
|
]);
|
|
|
|
database
|
|
.save_thread(thread_id.clone(), thread, folder_paths.clone())
|
|
.await
|
|
.unwrap();
|
|
|
|
let threads = database.list_threads().await.unwrap();
|
|
assert_eq!(threads.len(), 1);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
|
|
let thread_id = session_id("no-folder-thread");
|
|
let thread = make_thread(
|
|
"No Folder Thread",
|
|
Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
|
|
);
|
|
|
|
database
|
|
.save_thread(thread_id.clone(), thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
|
|
let threads = database.list_threads().await.unwrap();
|
|
assert_eq!(threads.len(), 1);
|
|
}
|
|
|
|
#[test]
|
|
fn test_scroll_position_defaults_to_none() {
|
|
let json = r#"{
|
|
"title": "Old Thread",
|
|
"messages": [],
|
|
"updated_at": "2024-01-01T00:00:00Z"
|
|
}"#;
|
|
|
|
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
|
|
|
|
assert!(
|
|
db_thread.ui_scroll_position.is_none(),
|
|
"Legacy threads without scroll_position field should default to None"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
|
|
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
|
|
|
let thread_id = session_id("thread-with-scroll");
|
|
|
|
let mut thread = make_thread(
|
|
"Thread With Scroll",
|
|
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
|
);
|
|
thread.ui_scroll_position = Some(SerializedScrollPosition {
|
|
item_ix: 42,
|
|
offset_in_item: 13.5,
|
|
});
|
|
|
|
database
|
|
.save_thread(thread_id.clone(), thread, PathList::default())
|
|
.await
|
|
.unwrap();
|
|
|
|
let loaded = database
|
|
.load_thread(thread_id)
|
|
.await
|
|
.unwrap()
|
|
.expect("thread should exist");
|
|
|
|
let scroll = loaded
|
|
.ui_scroll_position
|
|
.expect("scroll_position should be restored");
|
|
assert_eq!(scroll.item_ix, 42);
|
|
assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
|
|
}
|
|
}
|