mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-31 19:05:00 +07:00
Add context compaction markers
This commit is contained in:
parent
2ea99a81f1
commit
a4da81844e
7 changed files with 172 additions and 20 deletions
|
|
@ -182,6 +182,7 @@ pub enum AgentThreadEntry {
|
|||
AssistantMessage(AssistantMessage),
|
||||
ToolCall(ToolCall),
|
||||
CompletedPlan(Vec<PlanEntry>),
|
||||
ContextCompaction,
|
||||
}
|
||||
|
||||
impl AgentThreadEntry {
|
||||
|
|
@ -191,6 +192,7 @@ impl AgentThreadEntry {
|
|||
Self::AssistantMessage(message) => message.indented,
|
||||
Self::ToolCall(_) => false,
|
||||
Self::CompletedPlan(_) => false,
|
||||
Self::ContextCompaction => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -207,6 +209,7 @@ impl AgentThreadEntry {
|
|||
}
|
||||
md
|
||||
}
|
||||
Self::ContextCompaction => "--- Context Compacted ---\n\n".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1467,7 +1470,8 @@ impl AcpThread {
|
|||
}) => return true,
|
||||
AgentThreadEntry::ToolCall(_)
|
||||
| AgentThreadEntry::AssistantMessage(_)
|
||||
| AgentThreadEntry::CompletedPlan(_) => {}
|
||||
| AgentThreadEntry::CompletedPlan(_)
|
||||
| AgentThreadEntry::ContextCompaction => {}
|
||||
}
|
||||
}
|
||||
false
|
||||
|
|
@ -1495,7 +1499,8 @@ impl AcpThread {
|
|||
}
|
||||
AgentThreadEntry::ToolCall(_)
|
||||
| AgentThreadEntry::AssistantMessage(_)
|
||||
| AgentThreadEntry::CompletedPlan(_) => {}
|
||||
| AgentThreadEntry::CompletedPlan(_)
|
||||
| AgentThreadEntry::ContextCompaction => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1514,7 +1519,8 @@ impl AcpThread {
|
|||
}
|
||||
AgentThreadEntry::ToolCall(_)
|
||||
| AgentThreadEntry::AssistantMessage(_)
|
||||
| AgentThreadEntry::CompletedPlan(_) => {}
|
||||
| AgentThreadEntry::CompletedPlan(_)
|
||||
| AgentThreadEntry::ContextCompaction => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1525,9 +1531,9 @@ impl AcpThread {
|
|||
for entry in self.entries.iter().rev() {
|
||||
match entry {
|
||||
AgentThreadEntry::UserMessage(..) => return false,
|
||||
AgentThreadEntry::AssistantMessage(..) | AgentThreadEntry::CompletedPlan(..) => {
|
||||
continue;
|
||||
}
|
||||
AgentThreadEntry::AssistantMessage(..)
|
||||
| AgentThreadEntry::CompletedPlan(..)
|
||||
| AgentThreadEntry::ContextCompaction => continue,
|
||||
AgentThreadEntry::ToolCall(..) => return true,
|
||||
}
|
||||
}
|
||||
|
|
@ -1871,6 +1877,10 @@ impl AcpThread {
|
|||
cx.emit(AcpThreadEvent::NewEntry);
|
||||
}
|
||||
|
||||
pub fn push_context_compaction(&mut self, cx: &mut Context<Self>) {
|
||||
self.push_entry(AgentThreadEntry::ContextCompaction, cx);
|
||||
}
|
||||
|
||||
pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
self.connection.set_title(&self.session_id, cx).is_some()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1853,6 +1853,11 @@ impl NativeAgentConnection {
|
|||
thread.update_retry_status(status, cx)
|
||||
})?;
|
||||
}
|
||||
ThreadEvent::ContextCompaction => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_context_compaction(cx);
|
||||
})?;
|
||||
}
|
||||
ThreadEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse::new(stop_reason));
|
||||
|
|
|
|||
|
|
@ -123,11 +123,39 @@ enum RetryStrategy {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Message {
|
||||
User(UserMessage),
|
||||
Agent(AgentMessage),
|
||||
Resume,
|
||||
Compaction(CompactionInfo),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum CompactionInfo {
|
||||
Summary(SharedString),
|
||||
ProviderNative {
|
||||
provider: LanguageModelProviderId,
|
||||
items: Vec<serde_json::Value>,
|
||||
},
|
||||
}
|
||||
|
||||
impl CompactionInfo {
|
||||
fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
match self {
|
||||
Self::Summary(summary) => vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![format!(
|
||||
"The previous conversation was compacted. Use this summary as context:\n\n{}",
|
||||
summary
|
||||
)
|
||||
.into()],
|
||||
cache: false,
|
||||
reasoning_details: None,
|
||||
}],
|
||||
Self::ProviderNative { .. } => Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Message {
|
||||
|
|
@ -148,6 +176,7 @@ impl Message {
|
|||
}
|
||||
}
|
||||
Message::Agent(message) => message.to_request(),
|
||||
Message::Compaction(info) => info.to_request(),
|
||||
Message::Resume => vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Continue where you left off".into()],
|
||||
|
|
@ -162,12 +191,13 @@ impl Message {
|
|||
Message::User(message) => message.to_markdown(),
|
||||
Message::Agent(message) => message.to_markdown(),
|
||||
Message::Resume => "[resume]\n".into(),
|
||||
Message::Compaction(_) => "--- Context Compacted ---\n".into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn role(&self) -> Role {
|
||||
match self {
|
||||
Message::User(_) | Message::Resume => Role::User,
|
||||
Message::User(_) | Message::Resume | Message::Compaction(_) => Role::User,
|
||||
Message::Agent(_) => Role::Assistant,
|
||||
}
|
||||
}
|
||||
|
|
@ -688,6 +718,7 @@ pub enum ThreadEvent {
|
|||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
SubagentSpawned(acp::SessionId),
|
||||
Retry(acp_thread::RetryStatus),
|
||||
ContextCompaction,
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
|
@ -1225,6 +1256,7 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
Message::Resume => {}
|
||||
Message::Compaction(_) => stream.send_context_compaction(),
|
||||
}
|
||||
}
|
||||
rx
|
||||
|
|
@ -1834,7 +1866,7 @@ impl Thread {
|
|||
Message::User(message) => {
|
||||
self.request_token_usage.remove(&message.id);
|
||||
}
|
||||
Message::Agent(_) | Message::Resume => {}
|
||||
Message::Agent(_) | Message::Resume | Message::Compaction(_) => {}
|
||||
}
|
||||
}
|
||||
self.clear_summary();
|
||||
|
|
@ -2919,8 +2951,7 @@ impl Thread {
|
|||
.rev()
|
||||
.find_map(|message| match &**message {
|
||||
Message::User(user_message) => Some(user_message),
|
||||
Message::Agent(_) => None,
|
||||
Message::Resume => None,
|
||||
Message::Agent(_) | Message::Resume | Message::Compaction(_) => None,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -3225,7 +3256,7 @@ impl Thread {
|
|||
match &**message {
|
||||
Message::User(_) => markdown.push_str("## User\n\n"),
|
||||
Message::Agent(_) => markdown.push_str("## Assistant\n\n"),
|
||||
Message::Resume => {}
|
||||
Message::Resume | Message::Compaction(_) => {}
|
||||
}
|
||||
markdown.push_str(&message.to_markdown());
|
||||
}
|
||||
|
|
@ -3801,6 +3832,12 @@ impl ThreadEventStream {
|
|||
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
||||
}
|
||||
|
||||
fn send_context_compaction(&self) {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::ContextCompaction))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: acp::StopReason) {
|
||||
self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
|
||||
}
|
||||
|
|
@ -4543,6 +4580,75 @@ mod tests {
|
|||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_summary_compaction_renders_for_request_and_markdown() {
|
||||
let message = Message::Compaction(CompactionInfo::Summary("Older context".into()));
|
||||
|
||||
assert_eq!(message.role(), Role::User);
|
||||
assert_eq!(message.to_markdown(), "--- Context Compacted ---\n");
|
||||
|
||||
let request_messages = message.to_request();
|
||||
assert_eq!(request_messages.len(), 1);
|
||||
assert_eq!(request_messages[0].role, Role::User);
|
||||
assert!(!request_messages[0].cache);
|
||||
assert_eq!(request_messages[0].reasoning_details, None);
|
||||
assert_eq!(request_messages[0].content.len(), 1);
|
||||
let language_model::MessageContent::Text(text) = &request_messages[0].content[0] else {
|
||||
panic!("expected text summary context");
|
||||
};
|
||||
assert_eq!(
|
||||
text.as_str(),
|
||||
"The previous conversation was compacted. Use this summary as context:\n\nOlder context"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replay_emits_context_compaction(cx: &mut TestAppContext) {
|
||||
let (thread, _event_stream) = setup_thread_for_test(cx).await;
|
||||
let user_message_id = UserMessageId::new();
|
||||
|
||||
let mut replay_events = cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.messages.push(Arc::new(Message::User(UserMessage {
|
||||
id: user_message_id.clone(),
|
||||
content: vec![UserMessageContent::Text("before".to_string())].into(),
|
||||
})));
|
||||
thread
|
||||
.messages
|
||||
.push(Arc::new(Message::Compaction(CompactionInfo::Summary(
|
||||
"summary".into(),
|
||||
))));
|
||||
thread.messages.push(Arc::new(Message::Agent(AgentMessage {
|
||||
content: vec![AgentMessageContent::Text("after".to_string())],
|
||||
..Default::default()
|
||||
})));
|
||||
|
||||
thread.replay(cx)
|
||||
})
|
||||
});
|
||||
|
||||
let event = replay_events.next().await;
|
||||
assert!(
|
||||
matches!(
|
||||
&event,
|
||||
Some(Ok(ThreadEvent::UserMessage(UserMessage { id, .. }))) if id == &user_message_id
|
||||
),
|
||||
"expected replayed user message, got {event:?}"
|
||||
);
|
||||
|
||||
let event = replay_events.next().await;
|
||||
assert!(
|
||||
matches!(&event, Some(Ok(ThreadEvent::ContextCompaction))),
|
||||
"expected context compaction event, got {event:?}"
|
||||
);
|
||||
|
||||
let event = replay_events.next().await;
|
||||
assert!(
|
||||
matches!(&event, Some(Ok(ThreadEvent::AgentText(text))) if text == "after"),
|
||||
"expected replayed agent text, got {event:?}"
|
||||
);
|
||||
}
|
||||
|
||||
fn setup_parent_with_subagents(
|
||||
cx: &mut TestAppContext,
|
||||
parent: &Entity<Thread>,
|
||||
|
|
|
|||
|
|
@ -3621,6 +3621,7 @@ mod tests {
|
|||
acp_thread::AgentThreadEntry::AssistantMessage(_) => "assistant",
|
||||
acp_thread::AgentThreadEntry::ToolCall(_) => "tool_call",
|
||||
acp_thread::AgentThreadEntry::CompletedPlan(_) => "plan",
|
||||
acp_thread::AgentThreadEntry::ContextCompaction => "compaction",
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
|
|
|||
|
|
@ -5393,6 +5393,19 @@ impl ThreadView {
|
|||
AgentThreadEntry::CompletedPlan(entries) => {
|
||||
self.render_completed_plan(entries, window, cx)
|
||||
}
|
||||
AgentThreadEntry::ContextCompaction => h_flex()
|
||||
.id(("context_compaction", entry_ix))
|
||||
.px_5()
|
||||
.py_1()
|
||||
.gap_2()
|
||||
.child(Divider::horizontal())
|
||||
.child(
|
||||
Label::new("Context Compacted")
|
||||
.size(LabelSize::Custom(self.tool_name_font_size()))
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Divider::horizontal())
|
||||
.into_any(),
|
||||
};
|
||||
|
||||
let is_subagent_output = self.is_subagent()
|
||||
|
|
@ -6502,7 +6515,8 @@ impl ThreadView {
|
|||
}
|
||||
AgentThreadEntry::ToolCall(_)
|
||||
| AgentThreadEntry::AssistantMessage(_)
|
||||
| AgentThreadEntry::CompletedPlan(_) => {}
|
||||
| AgentThreadEntry::CompletedPlan(_)
|
||||
| AgentThreadEntry::ContextCompaction => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -232,6 +232,11 @@ impl EntryViewState {
|
|||
self.set_entry(index, Entry::CompletedPlan);
|
||||
}
|
||||
}
|
||||
AgentThreadEntry::ContextCompaction => {
|
||||
if !matches!(self.entries.get(index), Some(Entry::ContextCompaction)) {
|
||||
self.set_entry(index, Entry::ContextCompaction);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -252,7 +257,8 @@ impl EntryViewState {
|
|||
match entry {
|
||||
Entry::UserMessage { .. }
|
||||
| Entry::AssistantMessage { .. }
|
||||
| Entry::CompletedPlan => {}
|
||||
| Entry::CompletedPlan
|
||||
| Entry::ContextCompaction => {}
|
||||
Entry::ToolCall(ToolCallEntry { content, .. }) => {
|
||||
for view in content.values() {
|
||||
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
|
||||
|
|
@ -321,6 +327,7 @@ pub enum Entry {
|
|||
AssistantMessage(AssistantMessageEntry),
|
||||
ToolCall(ToolCallEntry),
|
||||
CompletedPlan,
|
||||
ContextCompaction,
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
|
|
@ -329,14 +336,17 @@ impl Entry {
|
|||
Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
|
||||
Self::AssistantMessage(message) => Some(message.focus_handle.clone()),
|
||||
Self::ToolCall(tool_call) => Some(tool_call.focus_handle.clone()),
|
||||
Self::CompletedPlan => None,
|
||||
Self::CompletedPlan | Self::ContextCompaction => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
|
||||
match self {
|
||||
Self::UserMessage(editor) => Some(editor),
|
||||
Self::AssistantMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
|
||||
Self::AssistantMessage(_)
|
||||
| Self::ToolCall(_)
|
||||
| Self::CompletedPlan
|
||||
| Self::ContextCompaction => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -363,7 +373,10 @@ impl Entry {
|
|||
) -> Option<ScrollHandle> {
|
||||
match self {
|
||||
Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
|
||||
Self::UserMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
|
||||
Self::UserMessage(_)
|
||||
| Self::ToolCall(_)
|
||||
| Self::CompletedPlan
|
||||
| Self::ContextCompaction => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -378,7 +391,10 @@ impl Entry {
|
|||
pub fn has_content(&self) -> bool {
|
||||
match self {
|
||||
Self::ToolCall(ToolCallEntry { content, .. }) => !content.is_empty(),
|
||||
Self::UserMessage(_) | Self::AssistantMessage(_) | Self::CompletedPlan => false,
|
||||
Self::UserMessage(_)
|
||||
| Self::AssistantMessage(_)
|
||||
| Self::CompletedPlan
|
||||
| Self::ContextCompaction => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -395,7 +411,7 @@ impl Focusable for Entry {
|
|||
Self::UserMessage(editor) => editor.read(cx).focus_handle(cx),
|
||||
Self::AssistantMessage(message) => message.focus_handle.clone(),
|
||||
Self::ToolCall(tool_call) => tool_call.focus_handle.clone(),
|
||||
Self::CompletedPlan => cx.focus_handle(),
|
||||
Self::CompletedPlan | Self::ContextCompaction => cx.focus_handle(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -371,7 +371,7 @@ pub struct LanguageModelId(pub SharedString);
|
|||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelName(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
pub struct LanguageModelProviderId(pub SharedString);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
|
|
|
|||
Loading…
Reference in a new issue