mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
copilot: Fix issue when switching between OpenAI and Anthropic models (#56655)
We were storing reasoning output inside `RedactedThinking` which causes issues when switching mid-turn from an OpenAI to an Anthropic model. This implementation fixes this by storing it inside `reasoning_details`, which matches our responses implementation in `open_ai.rs` See https://github.com/microsoft/vscode-copilot-chat/blob/main/src/platform/endpoint/node/responsesApi.ts For whatever reason the copilot chat extension sets `summary: []`, this is what our implementation does too Closes #56385 Release Notes: - Fixed an issue where the agent would error when using Copilot as a provider and switching between OpenAI and Anthropic models
This commit is contained in:
parent
5aeb8a7e0f
commit
1dc07b40b9
2 changed files with 245 additions and 25 deletions
|
|
@ -139,12 +139,7 @@ pub enum ResponseInputItem {
|
|||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
status: Option<ItemStatus>,
|
||||
},
|
||||
Reasoning {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
summary: Vec<ResponseReasoningItem>,
|
||||
encrypted_content: String,
|
||||
},
|
||||
Reasoning(ResponseReasoningInputItem),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
|
|
@ -162,7 +157,17 @@ pub struct IncompleteDetails {
|
|||
pub reason: Option<IncompleteReason>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ResponseReasoningInputItem {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub summary: Vec<ResponseReasoningItem>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub encrypted_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ResponseReasoningItem {
|
||||
#[serde(rename = "type")]
|
||||
pub kind: String,
|
||||
|
|
|
|||
|
|
@ -659,12 +659,14 @@ pub fn map_to_language_model_completion_events(
|
|||
|
||||
pub struct CopilotResponsesEventMapper {
|
||||
pending_stop_reason: Option<StopReason>,
|
||||
reasoning_items: Vec<copilot_responses::ResponseReasoningInputItem>,
|
||||
}
|
||||
|
||||
impl CopilotResponsesEventMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pending_stop_reason: None,
|
||||
reasoning_items: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -740,13 +742,13 @@ impl CopilotResponsesEventMapper {
|
|||
events
|
||||
}
|
||||
copilot_responses::ResponseOutputItem::Reasoning {
|
||||
id,
|
||||
summary,
|
||||
encrypted_content,
|
||||
..
|
||||
} => {
|
||||
let mut events = Vec::new();
|
||||
|
||||
if let Some(blocks) = summary {
|
||||
if let Some(blocks) = summary.as_ref() {
|
||||
let mut text = String::new();
|
||||
for block in blocks {
|
||||
text.push_str(&block.text);
|
||||
|
|
@ -759,8 +761,10 @@ impl CopilotResponsesEventMapper {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(data) = encrypted_content {
|
||||
events.push(Ok(LanguageModelCompletionEvent::RedactedThinking { data }));
|
||||
if let Some(reasoning_item) =
|
||||
reasoning_input_item_from_output(&id, encrypted_content)
|
||||
{
|
||||
events.extend(self.capture_reasoning_item(reasoning_item));
|
||||
}
|
||||
|
||||
events
|
||||
|
|
@ -769,6 +773,7 @@ impl CopilotResponsesEventMapper {
|
|||
|
||||
copilot_responses::StreamEvent::Completed { response } => {
|
||||
let mut events = Vec::new();
|
||||
events.extend(self.capture_reasoning_items_from_output(&response.output));
|
||||
if let Some(usage) = response.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||
|
|
@ -800,6 +805,7 @@ impl CopilotResponsesEventMapper {
|
|||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
events.extend(self.capture_reasoning_items_from_output(&response.output));
|
||||
if let Some(usage) = response.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||
|
|
@ -840,6 +846,116 @@ impl CopilotResponsesEventMapper {
|
|||
| copilot_responses::StreamEvent::Unknown => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn capture_reasoning_items_from_output(
|
||||
&mut self,
|
||||
output: &[copilot_responses::ResponseOutputItem],
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let mut events = Vec::new();
|
||||
for item in output {
|
||||
if let copilot_responses::ResponseOutputItem::Reasoning {
|
||||
id,
|
||||
summary: _,
|
||||
encrypted_content,
|
||||
} = item
|
||||
{
|
||||
if let Some(reasoning_item) =
|
||||
reasoning_input_item_from_output(&id, encrypted_content.clone())
|
||||
{
|
||||
events.extend(self.capture_reasoning_item(reasoning_item));
|
||||
}
|
||||
}
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
fn capture_reasoning_item(
|
||||
&mut self,
|
||||
reasoning_item: copilot_responses::ResponseReasoningInputItem,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
if self.reasoning_items.contains(&reasoning_item) {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
if let Some(id) = reasoning_item.id.as_ref()
|
||||
&& let Some(existing_reasoning_item) = self
|
||||
.reasoning_items
|
||||
.iter_mut()
|
||||
.find(|existing_reasoning_item| existing_reasoning_item.id.as_ref() == Some(id))
|
||||
{
|
||||
*existing_reasoning_item = reasoning_item;
|
||||
} else {
|
||||
self.reasoning_items.push(reasoning_item);
|
||||
}
|
||||
|
||||
self.emit_response_message_metadata()
|
||||
}
|
||||
|
||||
fn emit_response_message_metadata(
|
||||
&self,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let details = serde_json::to_value(CopilotResponseMessageMetadata {
|
||||
reasoning_items: self.reasoning_items.clone(),
|
||||
});
|
||||
|
||||
match details {
|
||||
Ok(details) => vec![Ok(LanguageModelCompletionEvent::ReasoningDetails(details))],
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct CopilotResponseMessageMetadata {
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
reasoning_items: Vec<copilot_responses::ResponseReasoningInputItem>,
|
||||
}
|
||||
|
||||
fn append_reasoning_details_to_response_items(
|
||||
reasoning_details: Option<&serde_json::Value>,
|
||||
replayed_reasoning_item_indexes: &mut HashMap<String, usize>,
|
||||
input_items: &mut Vec<copilot_responses::ResponseInputItem>,
|
||||
) {
|
||||
let Some(reasoning_details) = reasoning_details else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(metadata) =
|
||||
serde_json::from_value::<CopilotResponseMessageMetadata>(reasoning_details.clone()).ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
for mut reasoning_item in metadata.reasoning_items {
|
||||
reasoning_item.summary.clear();
|
||||
if let Some(id) = reasoning_item.id.as_ref() {
|
||||
if let Some(index) = replayed_reasoning_item_indexes.get(id) {
|
||||
input_items[*index] =
|
||||
copilot_responses::ResponseInputItem::Reasoning(reasoning_item);
|
||||
return;
|
||||
}
|
||||
|
||||
replayed_reasoning_item_indexes.insert(id.clone(), input_items.len());
|
||||
}
|
||||
|
||||
input_items.push(copilot_responses::ResponseInputItem::Reasoning(
|
||||
reasoning_item,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
fn reasoning_input_item_from_output(
|
||||
id: &str,
|
||||
encrypted_content: Option<String>,
|
||||
) -> Option<copilot_responses::ResponseReasoningInputItem> {
|
||||
if encrypted_content.is_none() {
|
||||
return None;
|
||||
}
|
||||
Some(copilot_responses::ResponseReasoningInputItem {
|
||||
id: Some(id.to_string()),
|
||||
summary: Vec::new(),
|
||||
encrypted_content,
|
||||
})
|
||||
}
|
||||
|
||||
fn into_copilot_chat(
|
||||
|
|
@ -1100,6 +1216,7 @@ fn into_copilot_responses(
|
|||
} = request;
|
||||
|
||||
let mut input_items: Vec<responses::ResponseInputItem> = Vec::new();
|
||||
let mut replayed_reasoning_item_indexes = HashMap::default();
|
||||
|
||||
for message in messages {
|
||||
match message.role {
|
||||
|
|
@ -1181,6 +1298,12 @@ fn into_copilot_responses(
|
|||
}
|
||||
|
||||
Role::Assistant => {
|
||||
append_reasoning_details_to_response_items(
|
||||
message.reasoning_details.as_ref(),
|
||||
&mut replayed_reasoning_item_indexes,
|
||||
&mut input_items,
|
||||
);
|
||||
|
||||
for content in &message.content {
|
||||
if let MessageContent::ToolUse(tool_use) = content {
|
||||
input_items.push(responses::ResponseInputItem::FunctionCall {
|
||||
|
|
@ -1193,16 +1316,6 @@ fn into_copilot_responses(
|
|||
}
|
||||
}
|
||||
|
||||
for content in &message.content {
|
||||
if let MessageContent::RedactedThinking(data) = content {
|
||||
input_items.push(responses::ResponseInputItem::Reasoning {
|
||||
id: None,
|
||||
summary: Vec::new(),
|
||||
encrypted_content: data.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
|
||||
for content in &message.content {
|
||||
match content {
|
||||
|
|
@ -1297,6 +1410,7 @@ mod tests {
|
|||
use super::*;
|
||||
use copilot_chat::responses;
|
||||
use futures::StreamExt;
|
||||
use serde_json::json;
|
||||
|
||||
fn map_events(events: Vec<responses::StreamEvent>) -> Vec<LanguageModelCompletionEvent> {
|
||||
futures::executor::block_on(async {
|
||||
|
|
@ -1310,6 +1424,37 @@ mod tests {
|
|||
})
|
||||
}
|
||||
|
||||
fn test_responses_model() -> CopilotChatModel {
|
||||
serde_json::from_value(json!({
|
||||
"billing": {
|
||||
"is_premium": false,
|
||||
"multiplier": 1.0
|
||||
},
|
||||
"capabilities": {
|
||||
"family": "test",
|
||||
"limits": {
|
||||
"max_context_window_tokens": 128000,
|
||||
"max_output_tokens": 4096
|
||||
},
|
||||
"supports": {
|
||||
"streaming": true,
|
||||
"tool_calls": true,
|
||||
"parallel_tool_calls": false,
|
||||
"vision": false
|
||||
},
|
||||
"type": "chat"
|
||||
},
|
||||
"id": "test-model",
|
||||
"is_chat_default": false,
|
||||
"is_chat_fallback": false,
|
||||
"model_picker_enabled": true,
|
||||
"name": "Test Model",
|
||||
"vendor": "OpenAI",
|
||||
"supported_endpoints": ["/responses"]
|
||||
}))
|
||||
.expect("valid test model")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_maps_text_and_usage() {
|
||||
let events = vec![
|
||||
|
|
@ -1435,10 +1580,80 @@ mod tests {
|
|||
mapped[0],
|
||||
LanguageModelCompletionEvent::Thinking { ref text, signature: None } if text == "Chain"
|
||||
));
|
||||
assert!(matches!(
|
||||
mapped[1],
|
||||
LanguageModelCompletionEvent::RedactedThinking { ref data } if data == "ENC"
|
||||
));
|
||||
match &mapped[1] {
|
||||
LanguageModelCompletionEvent::ReasoningDetails(details) => assert_eq!(
|
||||
details,
|
||||
&json!({
|
||||
"reasoning_items": [
|
||||
{
|
||||
"id": "r1",
|
||||
"summary": [],
|
||||
"encrypted_content": "ENC"
|
||||
}
|
||||
]
|
||||
})
|
||||
),
|
||||
other => panic!("expected reasoning details, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn into_copilot_responses_replays_reasoning_details() {
|
||||
let model = test_responses_model();
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![
|
||||
MessageContent::RedactedThinking("legacy-redacted".into()),
|
||||
MessageContent::Text("Done".into()),
|
||||
],
|
||||
cache: false,
|
||||
reasoning_details: Some(json!({
|
||||
"reasoning_items": [
|
||||
{
|
||||
"id": "r1",
|
||||
"summary": [
|
||||
{
|
||||
"type": "summary_text",
|
||||
"text": "Chain"
|
||||
}
|
||||
],
|
||||
"encrypted_content": "ENC"
|
||||
}
|
||||
]
|
||||
})),
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(into_copilot_responses(&model, request))
|
||||
.expect("serialized request");
|
||||
let input = serialized["input"].as_array().expect("input items");
|
||||
|
||||
assert_eq!(
|
||||
input.first(),
|
||||
Some(&json!({
|
||||
"type": "reasoning",
|
||||
"id": "r1",
|
||||
"summary": [],
|
||||
"encrypted_content": "ENC"
|
||||
}))
|
||||
);
|
||||
assert_eq!(
|
||||
input.get(1),
|
||||
Some(&json!({
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "Done"
|
||||
}
|
||||
],
|
||||
"status": "completed"
|
||||
}))
|
||||
);
|
||||
assert!(!serialized.to_string().contains("legacy-redacted"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
Loading…
Reference in a new issue