mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Persist token count and scroll position across agent restarts (#50620)
Release Notes: - Token counts and scroll position are restored when loading a previous agent thread
This commit is contained in:
parent
c1cbcb612d
commit
832782f6b3
7 changed files with 178 additions and 13 deletions
|
|
@ -972,6 +972,8 @@ pub struct AcpThread {
|
|||
had_error: bool,
|
||||
/// The user's unsent prompt text, persisted so it can be restored when reloading the thread.
|
||||
draft_prompt: Option<Vec<acp::ContentBlock>>,
|
||||
/// The initial scroll position for the thread view, set during session registration.
|
||||
ui_scroll_position: Option<gpui::ListOffset>,
|
||||
}
|
||||
|
||||
impl From<&AcpThread> for ActionLogTelemetry {
|
||||
|
|
@ -1210,6 +1212,7 @@ impl AcpThread {
|
|||
pending_terminal_exit: HashMap::default(),
|
||||
had_error: false,
|
||||
draft_prompt: None,
|
||||
ui_scroll_position: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1229,6 +1232,14 @@ impl AcpThread {
|
|||
self.draft_prompt = prompt;
|
||||
}
|
||||
|
||||
pub fn ui_scroll_position(&self) -> Option<gpui::ListOffset> {
|
||||
self.ui_scroll_position
|
||||
}
|
||||
|
||||
pub fn set_ui_scroll_position(&mut self, position: Option<gpui::ListOffset>) {
|
||||
self.ui_scroll_position = position;
|
||||
}
|
||||
|
||||
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
|
||||
&self.connection
|
||||
}
|
||||
|
|
|
|||
|
|
@ -352,6 +352,8 @@ impl NativeAgent {
|
|||
let parent_session_id = thread.parent_thread_id();
|
||||
let title = thread.title();
|
||||
let draft_prompt = thread.draft_prompt().map(Vec::from);
|
||||
let scroll_position = thread.ui_scroll_position();
|
||||
let token_usage = thread.latest_token_usage();
|
||||
let project = thread.project.clone();
|
||||
let action_log = thread.action_log.clone();
|
||||
let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
|
||||
|
|
@ -367,6 +369,8 @@ impl NativeAgent {
|
|||
cx,
|
||||
);
|
||||
acp_thread.set_draft_prompt(draft_prompt);
|
||||
acp_thread.set_ui_scroll_position(scroll_position);
|
||||
acp_thread.update_token_usage(token_usage, cx);
|
||||
acp_thread
|
||||
});
|
||||
|
||||
|
|
@ -1917,7 +1921,9 @@ mod internal_tests {
|
|||
use gpui::TestAppContext;
|
||||
use indoc::formatdoc;
|
||||
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
||||
use language_model::{LanguageModelProviderId, LanguageModelProviderName};
|
||||
use language_model::{
|
||||
LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
|
||||
};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::{path, rel_path::rel_path};
|
||||
|
|
@ -2549,6 +2555,13 @@ mod internal_tests {
|
|||
cx.run_until_parked();
|
||||
|
||||
model.send_last_completion_stream_text_chunk("Lorem.");
|
||||
model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||
language_model::TokenUsage {
|
||||
input_tokens: 150,
|
||||
output_tokens: 75,
|
||||
..Default::default()
|
||||
},
|
||||
));
|
||||
model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
summary_model
|
||||
|
|
@ -2587,6 +2600,12 @@ mod internal_tests {
|
|||
acp_thread.update(cx, |thread, _cx| {
|
||||
thread.set_draft_prompt(Some(draft_blocks.clone()));
|
||||
});
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.set_ui_scroll_position(Some(gpui::ListOffset {
|
||||
item_ix: 5,
|
||||
offset_in_item: gpui::px(12.5),
|
||||
}));
|
||||
});
|
||||
thread.update(cx, |_thread, cx| cx.notify());
|
||||
cx.run_until_parked();
|
||||
|
||||
|
|
@ -2632,6 +2651,24 @@ mod internal_tests {
|
|||
acp_thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
|
||||
});
|
||||
|
||||
// Ensure token usage survived the round-trip.
|
||||
acp_thread.read_with(cx, |thread, _| {
|
||||
let usage = thread
|
||||
.token_usage()
|
||||
.expect("token usage should be restored after reload");
|
||||
assert_eq!(usage.input_tokens, 150);
|
||||
assert_eq!(usage.output_tokens, 75);
|
||||
});
|
||||
|
||||
// Ensure scroll position survived the round-trip.
|
||||
acp_thread.read_with(cx, |thread, _| {
|
||||
let scroll = thread
|
||||
.ui_scroll_position()
|
||||
.expect("scroll position should be restored after reload");
|
||||
assert_eq!(scroll.item_ix, 5);
|
||||
assert_eq!(scroll.offset_in_item, gpui::px(12.5));
|
||||
});
|
||||
}
|
||||
|
||||
fn thread_entries(
|
||||
|
|
|
|||
|
|
@ -66,6 +66,14 @@ pub struct DbThread {
|
|||
pub thinking_effort: Option<String>,
|
||||
#[serde(default)]
|
||||
pub draft_prompt: Option<Vec<acp::ContentBlock>>,
|
||||
#[serde(default)]
|
||||
pub ui_scroll_position: Option<SerializedScrollPosition>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SerializedScrollPosition {
|
||||
pub item_ix: usize,
|
||||
pub offset_in_item: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -108,6 +116,7 @@ impl SharedThread {
|
|||
thinking_enabled: false,
|
||||
thinking_effort: None,
|
||||
draft_prompt: None,
|
||||
ui_scroll_position: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -286,6 +295,7 @@ impl DbThread {
|
|||
thinking_enabled: false,
|
||||
thinking_effort: None,
|
||||
draft_prompt: None,
|
||||
ui_scroll_position: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -637,6 +647,7 @@ mod tests {
|
|||
thinking_enabled: false,
|
||||
thinking_effort: None,
|
||||
draft_prompt: None,
|
||||
ui_scroll_position: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -841,4 +852,53 @@ mod tests {
|
|||
assert_eq!(threads.len(), 1);
|
||||
assert!(threads[0].folder_paths.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scroll_position_defaults_to_none() {
|
||||
let json = r#"{
|
||||
"title": "Old Thread",
|
||||
"messages": [],
|
||||
"updated_at": "2024-01-01T00:00:00Z"
|
||||
}"#;
|
||||
|
||||
let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
|
||||
|
||||
assert!(
|
||||
db_thread.ui_scroll_position.is_none(),
|
||||
"Legacy threads without scroll_position field should default to None"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
|
||||
let database = ThreadsDatabase::new(cx.executor()).unwrap();
|
||||
|
||||
let thread_id = session_id("thread-with-scroll");
|
||||
|
||||
let mut thread = make_thread(
|
||||
"Thread With Scroll",
|
||||
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
|
||||
);
|
||||
thread.ui_scroll_position = Some(SerializedScrollPosition {
|
||||
item_ix: 42,
|
||||
offset_in_item: 13.5,
|
||||
});
|
||||
|
||||
database
|
||||
.save_thread(thread_id.clone(), thread, PathList::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let loaded = database
|
||||
.load_thread(thread_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("thread should exist");
|
||||
|
||||
let scroll = loaded
|
||||
.ui_scroll_position
|
||||
.expect("scroll_position should be restored");
|
||||
assert_eq!(scroll.item_ix, 42);
|
||||
assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -901,6 +901,7 @@ pub struct Thread {
|
|||
subagent_context: Option<SubagentContext>,
|
||||
/// The user's unsent prompt text, persisted so it can be restored when reloading the thread.
|
||||
draft_prompt: Option<Vec<acp::ContentBlock>>,
|
||||
ui_scroll_position: Option<gpui::ListOffset>,
|
||||
/// Weak references to running subagent threads for cancellation propagation
|
||||
running_subagents: Vec<WeakEntity<Thread>>,
|
||||
}
|
||||
|
|
@ -1017,6 +1018,7 @@ impl Thread {
|
|||
imported: false,
|
||||
subagent_context: None,
|
||||
draft_prompt: None,
|
||||
ui_scroll_position: None,
|
||||
running_subagents: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
|
@ -1233,6 +1235,10 @@ impl Thread {
|
|||
imported: db_thread.imported,
|
||||
subagent_context: db_thread.subagent_context,
|
||||
draft_prompt: db_thread.draft_prompt,
|
||||
ui_scroll_position: db_thread.ui_scroll_position.map(|sp| gpui::ListOffset {
|
||||
item_ix: sp.item_ix,
|
||||
offset_in_item: gpui::px(sp.offset_in_item),
|
||||
}),
|
||||
running_subagents: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
|
@ -1258,6 +1264,12 @@ impl Thread {
|
|||
thinking_enabled: self.thinking_enabled,
|
||||
thinking_effort: self.thinking_effort.clone(),
|
||||
draft_prompt: self.draft_prompt.clone(),
|
||||
ui_scroll_position: self.ui_scroll_position.map(|lo| {
|
||||
crate::db::SerializedScrollPosition {
|
||||
item_ix: lo.item_ix,
|
||||
offset_in_item: lo.offset_in_item.as_f32(),
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
cx.background_spawn(async move {
|
||||
|
|
@ -1307,6 +1319,14 @@ impl Thread {
|
|||
self.draft_prompt = prompt;
|
||||
}
|
||||
|
||||
pub fn ui_scroll_position(&self) -> Option<gpui::ListOffset> {
|
||||
self.ui_scroll_position
|
||||
}
|
||||
|
||||
pub fn set_ui_scroll_position(&mut self, position: Option<gpui::ListOffset>) {
|
||||
self.ui_scroll_position = position;
|
||||
}
|
||||
|
||||
pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
|
||||
self.model.as_ref()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,6 +146,7 @@ mod tests {
|
|||
thinking_enabled: false,
|
||||
thinking_effort: None,
|
||||
draft_prompt: None,
|
||||
ui_scroll_position: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -845,6 +845,10 @@ impl ConnectionView {
|
|||
);
|
||||
});
|
||||
|
||||
if let Some(scroll_position) = thread.read(cx).ui_scroll_position() {
|
||||
list_state.scroll_to(scroll_position);
|
||||
}
|
||||
|
||||
AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx);
|
||||
|
||||
let connection = thread.read(cx).connection().clone();
|
||||
|
|
|
|||
|
|
@ -248,7 +248,8 @@ pub struct ThreadView {
|
|||
pub resumed_without_history: bool,
|
||||
pub resume_thread_metadata: Option<AgentSessionInfo>,
|
||||
pub _cancel_task: Option<Task<()>>,
|
||||
_draft_save_task: Option<Task<()>>,
|
||||
_save_task: Option<Task<()>>,
|
||||
_draft_resolve_task: Option<Task<()>>,
|
||||
pub skip_queue_processing_count: usize,
|
||||
pub user_interrupted_generation: bool,
|
||||
pub can_fast_track_queue: bool,
|
||||
|
|
@ -396,7 +397,7 @@ impl ThreadView {
|
|||
} else {
|
||||
Some(editor.update(cx, |editor, cx| editor.draft_contents(cx)))
|
||||
};
|
||||
this._draft_save_task = Some(cx.spawn(async move |this, cx| {
|
||||
this._draft_resolve_task = Some(cx.spawn(async move |this, cx| {
|
||||
let draft = if let Some(task) = draft_contents_task {
|
||||
let blocks = task.await.ok().filter(|b| !b.is_empty());
|
||||
blocks
|
||||
|
|
@ -407,15 +408,7 @@ impl ThreadView {
|
|||
this.thread.update(cx, |thread, _cx| {
|
||||
thread.set_draft_prompt(draft);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
cx.background_executor()
|
||||
.timer(SERIALIZATION_THROTTLE_TIME)
|
||||
.await;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(thread) = this.as_native_thread(cx) {
|
||||
thread.update(cx, |_thread, cx| cx.notify());
|
||||
}
|
||||
this.schedule_save(cx);
|
||||
})
|
||||
.ok();
|
||||
}));
|
||||
|
|
@ -471,7 +464,8 @@ impl ThreadView {
|
|||
is_loading_contents: false,
|
||||
new_server_version_available: None,
|
||||
_cancel_task: None,
|
||||
_draft_save_task: None,
|
||||
_save_task: None,
|
||||
_draft_resolve_task: None,
|
||||
skip_queue_processing_count: 0,
|
||||
user_interrupted_generation: false,
|
||||
can_fast_track_queue: false,
|
||||
|
|
@ -487,12 +481,50 @@ impl ThreadView {
|
|||
_history_subscription: history_subscription,
|
||||
show_codex_windows_warning,
|
||||
};
|
||||
let list_state_for_scroll = this.list_state.clone();
|
||||
let thread_view = cx.entity().downgrade();
|
||||
this.list_state
|
||||
.set_scroll_handler(move |_event, _window, cx| {
|
||||
let list_state = list_state_for_scroll.clone();
|
||||
let thread_view = thread_view.clone();
|
||||
// N.B. We must defer because the scroll handler is called while the
|
||||
// ListState's RefCell is mutably borrowed. Reading logical_scroll_top()
|
||||
// directly would panic from a double borrow.
|
||||
cx.defer(move |cx| {
|
||||
let scroll_top = list_state.logical_scroll_top();
|
||||
let _ = thread_view.update(cx, |this, cx| {
|
||||
if let Some(thread) = this.as_native_thread(cx) {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.set_ui_scroll_position(Some(scroll_top));
|
||||
});
|
||||
}
|
||||
this.schedule_save(cx);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if should_auto_submit {
|
||||
this.send(window, cx);
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
/// Schedule a throttled save of the thread state (draft prompt, scroll position, etc.).
|
||||
/// Multiple calls within `SERIALIZATION_THROTTLE_TIME` are coalesced into a single save.
|
||||
fn schedule_save(&mut self, cx: &mut Context<Self>) {
|
||||
self._save_task = Some(cx.spawn(async move |this, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SERIALIZATION_THROTTLE_TIME)
|
||||
.await;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(thread) = this.as_native_thread(cx) {
|
||||
thread.update(cx, |_thread, cx| cx.notify());
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}));
|
||||
}
|
||||
|
||||
pub fn handle_message_editor_event(
|
||||
&mut self,
|
||||
_editor: &Entity<MessageEditor>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue