Merge branch 'main' into feature/branch-specific-open-files

This commit is contained in:
Niall O'Brien 2026-05-29 12:34:35 +01:00 committed by GitHub
commit 1add778d60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 351 additions and 668 deletions

2
Cargo.lock generated
View file

@ -234,6 +234,7 @@ dependencies = [
"agent_settings",
"agent_skills",
"anyhow",
"assets",
"async-channel 2.5.0",
"async-io",
"chrono",
@ -290,6 +291,7 @@ dependencies = [
"tempfile",
"text",
"theme",
"theme_settings",
"thiserror 2.0.17",
"ui",
"unindent",

View file

@ -78,6 +78,7 @@ zed_env_vars.workspace = true
zstd.workspace = true
[dev-dependencies]
assets.workspace = true
async-io.workspace = true
agent_servers = { workspace = true, "features" = ["test-support"] }
client = { workspace = true, "features" = ["test-support"] }
@ -103,6 +104,7 @@ reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }
theme = { workspace = true, "features" = ["test-support"] }
theme_settings.workspace = true
unindent = { workspace = true }

View file

@ -1,4 +1,5 @@
use std::{
any::Any,
future::Future,
path::Path,
sync::Arc,
@ -14,26 +15,40 @@ use agent_settings::{AgentSettings, ToolRules};
use criterion::{
BatchSize, BenchmarkId, Criterion, Throughput, black_box, criterion_group, criterion_main,
};
use futures::{pin_mut, task::noop_waker};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext, UpdateGlobal as _};
use editor::{Editor, EditorStyle};
use futures::{StreamExt as _, pin_mut, task::noop_waker};
use gpui::{
AnyWindowHandle, AppContext as _, BackgroundExecutor, Entity, Focusable as _, TestAppContext,
UpdateGlobal as _,
};
use language::{FakeLspAdapter, rust_lang};
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project};
use prompt_store::ProjectContext;
use rand::{Rng as _, SeedableRng as _, rngs::StdRng};
use serde_json::{Value, json};
use settings::{Settings as _, SettingsStore};
use ui::IntoElement as _;
const SEED: u64 = 0x5EED_5EED;
const OLD_TEXT_CHUNK_SIZE: usize = 512;
const NEW_TEXT_CHUNK_SIZE: usize = 512;
const FILE_PROJECT_PATH: &str = "root/src/workspace_snapshot.rs";
const FILE_ABS_PATH: &str = "/root/src/workspace_snapshot.rs";
#[derive(Clone)]
struct EditOp {
old_text: String,
new_text: String,
}
#[derive(Clone)]
struct EditFixture {
name: &'static str,
old_file_text: String,
expected_file_text: String,
old_text: String,
new_text: String,
edits: Vec<EditOp>,
}
struct BenchmarkHarness {
@ -43,6 +58,12 @@ struct BenchmarkHarness {
partial_payloads: Vec<Value>,
final_payload: Value,
expected_file_text: String,
editor: Option<Entity<Editor>>,
window: Option<AnyWindowHandle>,
// Keeps the LSP buffer-registration handle and the fake language server alive
// for the lifetime of the benchmark so `didChange`/diagnostics keep flowing
// while edits are applied.
keep_alive: Vec<Box<dyn Any>>,
}
impl Drop for BenchmarkHarness {
@ -50,19 +71,18 @@ impl Drop for BenchmarkHarness {
// Release our handles to the entities first.
self.edit_tool.take();
self.thread.take();
self.editor.take();
self.keep_alive.clear();
if let Some(cx) = self.cx.take() {
// `ActionLog` holds buffers strongly via `tracked_buffers`, and spawns a background
// diff-maintenance task that also captures a strong `Entity<Buffer>`. Releasing the
// last handle to the action log only marks its entity for deferred release; the
// entity's value (and the buffer handles inside) is not actually dropped until
// `flush_effects` runs `release_dropped_entities`. Even then, the cancelled task's
// captured handle does not drop until the executor pumps the cancellation through.
//
// Without this two-step teardown, GPUI's test leak detector panics on
// `TestAppContext` drop because the buffer still appears alive. See
// `ActionLog::track_buffer_internal` and `LeakDetector::drop` in
// `crates/gpui/src/app/entity_map.rs`.
if let Some(mut cx) = self.cx.take() {
// Close the editor window so the editor entity and the buffer handles
// it holds are released, then pump the executor so cancelled editor /
// action-log background tasks drop their captured handles before the
// leak detector runs on `TestAppContext` drop.
if let Some(window) = self.window.take() {
cx.update_window(window, |_, window, _| window.remove_window())
.ok();
}
cx.update(|_| {});
cx.executor().run_until_parked();
cx.quit();
@ -76,9 +96,10 @@ fn edit_file_tool_streaming(c: &mut Criterion) {
group.sample_size(10);
for fixture in fixtures {
group.throughput(Throughput::Bytes(fixture.new_text.len() as u64));
let new_bytes: usize = fixture.edits.iter().map(|edit| edit.new_text.len()).sum();
group.throughput(Throughput::Bytes(new_bytes as u64));
group.bench_with_input(
BenchmarkId::new(fixture.name, fixture.old_text.len()),
BenchmarkId::new(fixture.name, fixture.old_file_text.len()),
&fixture,
|bench, fixture| {
bench.iter_batched(
@ -107,26 +128,168 @@ fn edit_file_tool_streaming(c: &mut Criterion) {
fn setup_harness(fixture: EditFixture) -> BenchmarkHarness {
let mut cx = init_context();
let executor = cx.executor();
let (edit_tool, thread) = block_on_executor(
let parts = block_on_executor(
&executor,
setup_edit_tool(&mut cx, fixture.old_file_text.clone()),
setup_editor_and_tool(&mut cx, fixture.old_file_text.clone()),
);
let partial_payloads = streamed_partial_payloads(&fixture.old_text, &fixture.new_text);
// Let the LSP handshake, initial parse, and first layout settle before timing.
cx.executor().run_until_parked();
let partial_payloads = streamed_partial_payloads(&fixture.edits);
let final_payload = json!({
"path": "root/src/workspace_snapshot.rs",
"edits": [{
"old_text": fixture.old_text,
"new_text": fixture.new_text,
}],
"path": FILE_PROJECT_PATH,
"edits": fixture
.edits
.iter()
.map(|edit| json!({ "old_text": edit.old_text, "new_text": edit.new_text }))
.collect::<Vec<_>>(),
});
BenchmarkHarness {
cx: Some(cx),
edit_tool: Some(edit_tool),
thread: Some(thread),
edit_tool: Some(parts.edit_tool),
thread: Some(parts.thread),
partial_payloads,
final_payload,
expected_file_text: fixture.expected_file_text,
editor: Some(parts.editor),
window: Some(parts.window),
keep_alive: parts.keep_alive,
}
}
struct HarnessParts {
edit_tool: Arc<EditFileTool>,
thread: Entity<Thread>,
editor: Entity<Editor>,
window: AnyWindowHandle,
keep_alive: Vec<Box<dyn Any>>,
}
/// Builds a project + edit tool, opens the target buffer in an editor view inside
/// a window, and attaches a fake Rust language server. This mirrors the real app:
/// the edited file is open in a pane with a language server, so each buffer edit
/// drives the editor's observer cascade (matching brackets, code actions, outline,
/// bracket colorization), a tree-sitter reparse, and an LSP `didChange` +
/// diagnostics round-trip — the costs that dominate a real agent edit.
async fn setup_editor_and_tool(cx: &mut TestAppContext, file_text: String) -> HarnessParts {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"src": {
"workspace_snapshot.rs": file_text,
},
}),
)
.await;
let project = Project::test(fs, [Path::new("/root")], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
language_registry.add(rust_lang());
let mut fake_servers = language_registry.register_fake_lsp(
"Rust",
FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
text_document_sync: Some(lsp::TextDocumentSyncCapability::Kind(
lsp::TextDocumentSyncKind::INCREMENTAL,
)),
..Default::default()
},
..Default::default()
},
);
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let action_log: Entity<ActionLog> =
thread.read_with(cx, |thread, _cx| thread.action_log().clone());
let edit_tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
action_log,
language_registry,
));
// Open the same buffer the tool will edit and register it with the language
// servers so edits produce `didChange` notifications.
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(FILE_ABS_PATH, cx)
})
.await
.expect("failed to open buffer");
let lsp_handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
let fake_server = fake_servers
.next()
.await
.expect("fake language server should start");
// Publish diagnostics on every edit, mirroring a real server reacting to
// `didChange`, so the editor's diagnostics path runs per edit.
let server = fake_server.clone();
fake_server.handle_notification::<lsp::notification::DidChangeTextDocument, _>(
move |params, _cx| {
server.notify::<lsp::notification::PublishDiagnostics>(lsp::PublishDiagnosticsParams {
uri: params.text_document.uri.clone(),
version: Some(params.text_document.version),
diagnostics: vec![lsp::Diagnostic {
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 1)),
severity: Some(lsp::DiagnosticSeverity::WARNING),
message: "bench diagnostic".to_string(),
..Default::default()
}],
});
},
);
// Attach an editor view in a window and lay it out once so the viewport-gated
// observers (bracket colorization, selection highlights) have a visible range.
let window = cx.add_window(|window, cx| {
let mut editor = Editor::for_buffer(buffer.clone(), Some(project.clone()), window, cx);
editor.set_style(EditorStyle::default(), window, cx);
window.focus(&editor.focus_handle(cx), cx);
editor
});
let editor = window.root(cx).expect("window should have an editor root");
let window: AnyWindowHandle = window.into();
// Lay out and paint a real frame so the editor establishes a viewport (this
// is what makes the viewport-gated observers like bracket colorization run).
{
let mut visual_cx = gpui::VisualTestContext::from_window(window, &*cx);
visual_cx.draw(
gpui::point(gpui::px(0.0), gpui::px(0.0)),
gpui::size(gpui::px(1024.0), gpui::px(768.0)),
|_, _| editor.clone().into_any_element(),
);
}
let keep_alive: Vec<Box<dyn Any>> = vec![
Box::new(lsp_handle),
Box::new(fake_server),
Box::new(fake_servers),
Box::new(buffer),
];
HarnessParts {
edit_tool,
thread,
editor,
window,
keep_alive,
}
}
@ -135,6 +298,9 @@ fn init_context() -> TestAppContext {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
assets::Assets.load_test_fonts(cx);
theme_settings::init(theme::LoadThemes::JustBase, cx);
editor::init(cx);
SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| {
store.update_user_settings(cx, |settings| {
settings
@ -142,6 +308,7 @@ fn init_context() -> TestAppContext {
.all_languages
.defaults
.ensure_final_newline_on_save = Some(false);
settings.project.all_languages.defaults.colorize_brackets = Some(true);
});
});
@ -161,48 +328,6 @@ fn init_context() -> TestAppContext {
cx
}
async fn setup_edit_tool(
cx: &mut TestAppContext,
file_text: String,
) -> (Arc<EditFileTool>, Entity<Thread>) {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"src": {
"workspace_snapshot.rs": file_text,
},
}),
)
.await;
let project = Project::test(fs, [Path::new("/root")], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
Some(model),
cx,
)
});
let action_log: Entity<ActionLog> =
thread.read_with(cx, |thread, _cx| thread.action_log().clone());
let edit_tool = Arc::new(EditFileTool::new(
project,
thread.downgrade(),
action_log,
language_registry,
));
(edit_tool, thread)
}
fn run_streamed_edit(harness: &mut BenchmarkHarness) -> EditFileToolOutput {
let (mut sender, input): (_, ToolInput<EditFileToolInput>) = ToolInput::test();
for payload in &harness.partial_payloads {
@ -247,33 +372,36 @@ fn block_on_executor<R>(executor: &BackgroundExecutor, future: impl Future<Outpu
panic!("future did not complete while running edit_file_tool benchmark");
}
fn streamed_partial_payloads(old_text: &str, new_text: &str) -> Vec<Value> {
let path = "root/src/workspace_snapshot.rs";
let mut payloads = Vec::new();
/// Builds the streamed partial payloads for a (possibly multi-edit) session,
/// mirroring how the agent reveals one edit at a time: earlier edits stay
/// complete in the array while the current edit streams its `old_text` then its
/// `new_text` in chunks.
fn streamed_partial_payloads(edits: &[EditOp]) -> Vec<Value> {
let path = FILE_PROJECT_PATH;
let mut payloads = vec![json!({ "path": path }), json!({ "path": path })];
payloads.push(json!({ "path": path }));
payloads.push(json!({ "path": path }));
for index in 0..edits.len() {
let completed: Vec<Value> = edits[..index]
.iter()
.map(|edit| json!({ "old_text": edit.old_text, "new_text": edit.new_text }))
.collect();
let edit = &edits[index];
for old_end in chunk_ends(old_text, OLD_TEXT_CHUNK_SIZE) {
payloads.push(json!({
"path": path,
"edits": [{ "old_text": &old_text[..old_end] }],
}));
}
for old_end in chunk_ends(&edit.old_text, OLD_TEXT_CHUNK_SIZE) {
let mut arr = completed.clone();
arr.push(json!({ "old_text": &edit.old_text[..old_end] }));
payloads.push(json!({ "path": path, "edits": arr }));
}
payloads.push(json!({
"path": path,
"edits": [{ "old_text": old_text, "new_text": "" }],
}));
let mut arr = completed.clone();
arr.push(json!({ "old_text": edit.old_text, "new_text": "" }));
payloads.push(json!({ "path": path, "edits": arr }));
for new_end in chunk_ends(new_text, NEW_TEXT_CHUNK_SIZE) {
payloads.push(json!({
"path": path,
"edits": [{
"old_text": old_text,
"new_text": &new_text[..new_end],
}],
}));
for new_end in chunk_ends(&edit.new_text, NEW_TEXT_CHUNK_SIZE) {
let mut arr = completed.clone();
arr.push(json!({ "old_text": edit.old_text, "new_text": &edit.new_text[..new_end] }));
payloads.push(json!({ "path": path, "edits": arr }));
}
}
payloads
@ -326,6 +454,7 @@ fn fixtures() -> Vec<EditFixture> {
EditPattern::InsertHelperBlocks { every_nth_line: 9 },
SEED + 3,
),
make_large_multi_edit_fixture("large_multi_edit", 80, 16, SEED + 4),
]
}
@ -375,11 +504,106 @@ fn make_fixture(
name,
old_file_text,
expected_file_text,
old_text,
new_text,
edits: vec![EditOp { old_text, new_text }],
}
}
fn make_large_multi_edit_fixture(
name: &'static str,
function_count: usize,
edit_count: usize,
seed: u64,
) -> EditFixture {
const HEADER_LINES: usize = 10;
const FUNCTION_LINES: usize = 12;
const FUNCTION_BODY_LINES: usize = 11;
let mut rng = StdRng::seed_from_u64(seed);
let old_lines = random_rust_module(&mut rng, function_count);
let old_file_text = old_lines.join("\n");
let step = (function_count / edit_count).max(1);
let mut picks: Vec<usize> = (0..edit_count)
.map(|k| (k * step).min(function_count - 1))
.collect();
picks.dedup();
let replacements: Vec<(usize, Vec<String>)> = picks
.iter()
.map(|&function_index| {
(
function_index,
large_function_lines(&mut rng, function_index),
)
})
.collect();
let edits = replacements
.iter()
.map(|(function_index, new_function)| {
let start = HEADER_LINES + function_index * FUNCTION_LINES;
let end = start + FUNCTION_BODY_LINES;
EditOp {
old_text: old_lines[start..end].join("\n"),
new_text: new_function.join("\n"),
}
})
.collect();
let mut new_lines = old_lines;
for (function_index, new_function) in replacements.iter().rev() {
let start = HEADER_LINES + function_index * FUNCTION_LINES;
let end = start + FUNCTION_BODY_LINES;
new_lines.splice(start..end, new_function.iter().cloned());
}
let expected_file_text = new_lines.join("\n");
EditFixture {
name,
old_file_text,
expected_file_text,
edits,
}
}
fn large_function_lines(rng: &mut StdRng, index: usize) -> Vec<String> {
let function_name = identifier(rng, index + 40_000);
let argument_name = identifier(rng, index + 41_000);
let mut lines = vec![
format!(
" pub fn {function_name}(&mut self, {argument_name}: usize) -> Result<usize> {{"
),
format!(" let mut accumulator = {argument_name};"),
];
let body_lines = rng.random_range(30..42);
for body_index in 0..body_lines {
let local_name = identifier(rng, index + 50_000 + body_index);
let multiplier = rng.random_range(2..19);
let offset = rng.random_range(1..256);
match body_index % 4 {
0 => lines.push(format!(
" let {local_name} = accumulator.saturating_mul({multiplier}).saturating_add({offset});"
)),
1 => lines.push(format!(
" accumulator = {local_name}.saturating_sub(self.version % {offset}.max(1));"
)),
2 => lines.push(format!(
" if {local_name} % {multiplier} == 0 {{ accumulator = accumulator.saturating_add({local_name}); }}"
)),
_ => lines.push(format!(
" self.buffers.insert(\"{local_name}\".to_string(), accumulator);"
)),
}
}
lines.push(" self.version = self.version.saturating_add(accumulator);".to_string());
lines.push(" Ok(accumulator)".to_string());
lines.push(" }".to_string());
lines
}
fn edit_range(lines: &[String], pattern: &EditPattern) -> std::ops::Range<usize> {
let mut range = match pattern {
EditPattern::LocalizedRewrite {

View file

@ -4069,63 +4069,6 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_send_retry_on_http_send_error(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.expect("thread send should start");
cx.run_until_parked();
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::HttpSend {
provider: LanguageModelProviderName::new("OpenAI"),
error: anyhow::anyhow!("response headers timed out after 10s"),
});
fake_model.end_last_completion_stream();
cx.executor().advance_clock(BASE_RETRY_DELAY);
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Recovered!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
ThreadEvent::Stop(..) => break,
_ => {}
}
}
assert_eq!(retry_events.len(), 1);
assert!(matches!(
retry_events[0],
acp_thread::RetryStatus { attempt: 1, .. }
));
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello!
## Assistant
Recovered!
"}
)
});
}
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;

View file

@ -620,21 +620,14 @@ impl EditPipeline {
log::debug!("new_text_chunk: done=true, final_text='{}'", final_text);
if !final_text.is_empty() {
let char_ops = streaming_diff.push_new(&final_text);
apply_char_operations(
&char_ops,
buffer,
&original_snapshot,
&mut edit_cursor,
&context.action_log,
cx,
);
}
let remaining_ops = streaming_diff.finish();
let mut char_ops = if final_text.is_empty() {
Vec::new()
} else {
streaming_diff.push_new(&final_text)
};
char_ops.extend(streaming_diff.finish());
apply_char_operations(
&remaining_ops,
&char_ops,
buffer,
&original_snapshot,
&mut edit_cursor,
@ -902,16 +895,17 @@ fn apply_char_operations(
action_log: &Entity<ActionLog>,
cx: &mut AsyncApp,
) {
let mut edits: Vec<_> = Vec::new();
for op in ops {
match op {
CharOperation::Insert { text } => {
let anchor = snapshot.anchor_after(*edit_cursor);
agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx);
edits.push((anchor..anchor, text.as_str().into()));
}
CharOperation::Delete { bytes } => {
let delete_end = *edit_cursor + bytes;
let anchor_range = snapshot.anchor_range_inside(*edit_cursor..delete_end);
agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx);
edits.push((anchor_range, Arc::<str>::from("")));
*edit_cursor = delete_end;
}
CharOperation::Keep { bytes } => {
@ -919,6 +913,9 @@ fn apply_char_operations(
}
}
}
if !edits.is_empty() {
agent_edit_buffer(buffer, edits, action_log, cx);
}
}
fn extract_match(

View file

@ -12,15 +12,12 @@ use language_model::{
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter,
};
use open_ai::{
ReasoningEffort,
responses::{StreamResponseOptions, stream_response_with_options},
};
use open_ai::{ReasoningEffort, responses::stream_response};
use rand::RngCore as _;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::time::{SystemTime, UNIX_EPOCH};
use ui::{ConfiguredApiCard, prelude::*};
use url::form_urlencoded;
use util::ResultExt as _;
@ -38,31 +35,6 @@ const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
const CREDENTIALS_KEY: &str = "https://chatgpt.com/backend-api/codex";
const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000;
const CODEX_RESPONSE_HEADER_TIMEOUT: Duration = Duration::from_secs(10);
fn codex_extra_headers(
account_id: Option<&str>,
session_id: Option<&str>,
) -> Vec<(String, String)> {
let mut extra_headers: Vec<(String, String)> = vec![
("originator".into(), "zed".into()),
("OpenAI-Beta".into(), "responses=experimental".into()),
];
if let Some(id) = account_id {
if !id.is_empty() {
extra_headers.push(("ChatGPT-Account-Id".into(), id.into()));
}
}
if let Some(id) = session_id {
if !id.is_empty() {
extra_headers.push(("session-id".into(), id.into()));
}
}
extra_headers
}
#[derive(Serialize, Deserialize, Clone, Debug)]
struct CodexCredentials {
@ -500,7 +472,6 @@ impl LanguageModel for OpenAiSubscribedLanguageModel {
// The Codex backend rejects `max_output_tokens` (`Unsupported parameter`),
// unlike the public OpenAI Responses API. Pass `None` so the field is
// omitted from the serialized request body entirely.
let session_id = request.thread_id.clone();
let mut responses_request = into_open_ai_response(
request,
self.model.id(),
@ -539,24 +510,26 @@ impl LanguageModel for OpenAiSubscribedLanguageModel {
let future = cx.spawn(async move |cx| {
let creds = get_fresh_credentials(&state, &http_client, cx).await?;
let extra_headers =
codex_extra_headers(creds.account_id.as_deref(), session_id.as_deref());
let mut extra_headers: Vec<(String, String)> = vec![
("originator".into(), "zed".into()),
("OpenAI-Beta".into(), "responses=experimental".into()),
];
if let Some(ref id) = creds.account_id {
if !id.is_empty() {
extra_headers.push(("ChatGPT-Account-Id".into(), id.clone()));
}
}
let access_token = creds.access_token.clone();
let background_executor = cx.background_executor().clone();
request_limiter
.stream(async move {
stream_response_with_options(
stream_response(
http_client.as_ref(),
PROVIDER_NAME.0.as_str(),
CODEX_BASE_URL,
&access_token,
responses_request,
extra_headers,
StreamResponseOptions::response_header_timeout(
CODEX_RESPONSE_HEADER_TIMEOUT,
background_executor.timer(CODEX_RESPONSE_HEADER_TIMEOUT),
),
)
.await
.map_err(LanguageModelCompletionError::from)
@ -1135,7 +1108,6 @@ mod tests {
use super::*;
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use language_model::{LanguageModelRequestMessage, Role};
use parking_lot::Mutex;
use std::future::Future;
use std::pin::Pin;
@ -1185,30 +1157,6 @@ mod tests {
}
}
#[test]
fn test_codex_extra_headers_include_session_id() {
assert_eq!(
codex_extra_headers(Some("account-1"), Some("thread-1")),
vec![
("originator".into(), "zed".into()),
("OpenAI-Beta".into(), "responses=experimental".into()),
("ChatGPT-Account-Id".into(), "account-1".into()),
("session-id".into(), "thread-1".into()),
]
);
}
#[test]
fn test_codex_extra_headers_omit_empty_optional_ids() {
assert_eq!(
codex_extra_headers(Some(""), Some("")),
vec![
("originator".into(), "zed".into()),
("OpenAI-Beta".into(), "responses=experimental".into()),
]
);
}
fn make_expired_credentials() -> CodexCredentials {
CodexCredentials {
access_token: "old_access".to_string(),
@ -1229,13 +1177,6 @@ mod tests {
}
}
fn make_fresh_credentials_with_account() -> CodexCredentials {
CodexCredentials {
account_id: Some("account-1".to_string()),
..make_fresh_credentials()
}
}
fn fake_token_response() -> String {
serde_json::json!({
"access_token": "fresh_access",
@ -1245,127 +1186,6 @@ mod tests {
.to_string()
}
#[gpui::test]
async fn test_stream_completion_sends_codex_session_header(cx: &mut TestAppContext) {
let captured_headers = Arc::new(Mutex::new(None::<http_client::http::HeaderMap>));
let captured_headers_clone = captured_headers.clone();
let http_client = FakeHttpClient::create(move |request| {
*captured_headers_clone.lock() = Some(request.headers().clone());
async move {
let body = r#"data: {"type":"response.completed","response":{"id":"resp_1","status":"completed"}}"#;
Ok(http_client::Response::builder()
.status(200)
.body(http_client::AsyncBody::from(format!("{body}\n\n")))?)
}
});
let state = cx.new(|_cx| State {
credentials: Some(make_fresh_credentials_with_account()),
sign_in_task: None,
refresh_task: None,
load_task: None,
credentials_provider: Arc::new(FakeCredentialsProvider::new()),
auth_generation: 0,
last_auth_error: None,
});
let model = OpenAiSubscribedLanguageModel {
id: LanguageModelId::from(ChatGptModel::Gpt55.id().to_string()),
model: ChatGptModel::Gpt55,
state,
http_client,
request_limiter: RateLimiter::new(4),
};
let request = LanguageModelRequest {
thread_id: Some("thread-1".to_string()),
prompt_id: Some("prompt-1".to_string()),
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec!["Hello".into()],
cache: false,
reasoning_details: None,
}],
..Default::default()
};
let mut stream = model
.stream_completion(request, &cx.to_async())
.await
.expect("stream should start");
stream
.next()
.await
.expect("stream should emit event")
.expect("event should parse");
let captured_headers = captured_headers
.lock()
.clone()
.expect("request headers should be captured");
assert_eq!(
captured_headers
.get("session-id")
.and_then(|value| value.to_str().ok()),
Some("thread-1")
);
assert_eq!(
captured_headers
.get("ChatGPT-Account-Id")
.and_then(|value| value.to_str().ok()),
Some("account-1")
);
}
#[gpui::test]
async fn test_stream_completion_times_out_before_codex_headers(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::create(|_request| {
futures::future::pending::<anyhow::Result<http_client::Response<AsyncBody>>>()
});
let state = cx.new(|_cx| State {
credentials: Some(make_fresh_credentials()),
sign_in_task: None,
refresh_task: None,
load_task: None,
credentials_provider: Arc::new(FakeCredentialsProvider::new()),
auth_generation: 0,
last_auth_error: None,
});
let model = OpenAiSubscribedLanguageModel {
id: LanguageModelId::from(ChatGptModel::Gpt55.id().to_string()),
model: ChatGptModel::Gpt55,
state,
http_client,
request_limiter: RateLimiter::new(4),
};
let request = LanguageModelRequest {
thread_id: Some("thread-1".to_string()),
prompt_id: Some("prompt-1".to_string()),
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec!["Hello".into()],
cache: false,
reasoning_details: None,
}],
..Default::default()
};
let stream_completion = model.stream_completion(request, &cx.to_async());
cx.run_until_parked();
cx.executor().advance_clock(CODEX_RESPONSE_HEADER_TIMEOUT);
let error = match stream_completion.await {
Ok(_) => panic!("stream should time out before headers arrive"),
Err(error) => error,
};
assert!(matches!(
error,
LanguageModelCompletionError::HttpSend { provider, .. }
if provider == PROVIDER_NAME
));
}
#[gpui::test]
async fn test_concurrent_refresh_deduplicates(cx: &mut TestAppContext) {
let refresh_count = Arc::new(AtomicUsize::new(0));

View file

@ -316,12 +316,6 @@ fn map_open_ai_error(error: open_ai::RequestError) -> LanguageModelCompletionErr
retry_after,
)
}
open_ai::RequestError::ResponseHeaderTimeout { timeout, .. } => {
LanguageModelCompletionError::HttpSend {
provider: PROVIDER_NAME,
error: anyhow::anyhow!("response headers timed out after {timeout:?}"),
}
}
open_ai::RequestError::Other(error) => LanguageModelCompletionError::Other(error),
}
}

View file

@ -11,7 +11,7 @@ use http_client::{
pub use language_model_core::ReasoningEffort;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use std::{convert::TryFrom, future::Future};
use strum::EnumIter;
use thiserror::Error;
@ -684,8 +684,6 @@ pub enum RequestError {
body: String,
headers: HeaderMap<HeaderValue>,
},
#[error("response headers from {provider}'s API timed out after {timeout:?}")]
ResponseHeaderTimeout { provider: String, timeout: Duration },
#[error(transparent)]
Other(#[from] anyhow::Error),
}
@ -905,10 +903,6 @@ impl From<RequestError> for language_model_core::LanguageModelCompletionError {
Self::from_http_status(provider.into(), status_code, body, retry_after)
}
RequestError::ResponseHeaderTimeout { provider, timeout } => Self::HttpSend {
provider: provider.into(),
error: anyhow!("response headers timed out after {timeout:?}"),
},
RequestError::Other(e) => Self::Other(e),
}
}

View file

@ -1,266 +1,11 @@
use anyhow::{Result, anyhow};
use futures::{
AsyncBufReadExt, AsyncReadExt, FutureExt, StreamExt, future::BoxFuture, io::BufReader,
stream::BoxStream,
};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{future::Future, time::Duration};
use crate::{ReasoningEffort, RequestError, Role, ServiceTier, ToolChoice};
#[derive(Default)]
pub struct StreamResponseOptions {
response_header_timeout: Option<(Duration, BoxFuture<'static, ()>)>,
}
impl StreamResponseOptions {
pub fn response_header_timeout(
timeout: Duration,
timer: impl Future<Output = ()> + Send + 'static,
) -> Self {
Self {
response_header_timeout: Some((timeout, timer.boxed())),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{FutureExt, StreamExt, future};
use http_client::{
AsyncBody, HttpClient, Request as HttpRequest, Response as HttpResponse, Url,
};
use std::{
io::{Cursor, Read},
pin::Pin,
sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll, Waker},
};
struct TestHttpClient {
handler: Arc<
dyn Fn(
HttpRequest<AsyncBody>,
) -> BoxFuture<'static, anyhow::Result<HttpResponse<AsyncBody>>>
+ Send
+ Sync,
>,
}
impl TestHttpClient {
fn new<F>(handler: F) -> Self
where
F: Fn(
HttpRequest<AsyncBody>,
) -> BoxFuture<'static, anyhow::Result<HttpResponse<AsyncBody>>>
+ Send
+ Sync
+ 'static,
{
Self {
handler: Arc::new(handler),
}
}
}
impl HttpClient for TestHttpClient {
fn user_agent(&self) -> Option<&http_client::http::HeaderValue> {
None
}
fn proxy(&self) -> Option<&Url> {
None
}
fn send(
&self,
request: HttpRequest<AsyncBody>,
) -> BoxFuture<'static, anyhow::Result<HttpResponse<AsyncBody>>> {
(self.handler)(request)
}
}
struct DelayedBody {
state: Arc<DelayedBodyState>,
bytes: Cursor<Vec<u8>>,
}
struct DelayedBodyState {
released: AtomicBool,
waker: Mutex<Option<Waker>>,
}
struct DelayedBodyHandle {
state: Arc<DelayedBodyState>,
}
impl DelayedBody {
fn new(bytes: Vec<u8>) -> (Self, DelayedBodyHandle) {
let state = Arc::new(DelayedBodyState {
released: AtomicBool::new(false),
waker: Mutex::new(None),
});
(
Self {
state: state.clone(),
bytes: Cursor::new(bytes),
},
DelayedBodyHandle { state },
)
}
}
impl DelayedBodyHandle {
fn release(&self) {
self.state.released.store(true, Ordering::SeqCst);
if let Some(waker) = self.state.waker.lock().expect("lock poisoned").take() {
waker.wake();
}
}
}
impl futures::AsyncRead for DelayedBody {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buffer: &mut [u8],
) -> Poll<std::io::Result<usize>> {
if !self.state.released.load(Ordering::SeqCst) {
self.state
.waker
.lock()
.expect("lock poisoned")
.replace(cx.waker().clone());
return Poll::Pending;
}
Poll::Ready(self.bytes.read(buffer))
}
}
fn test_request() -> Request {
Request {
model: "gpt-test".into(),
instructions: None,
input: Vec::new(),
include: Vec::new(),
stream: true,
temperature: None,
top_p: None,
max_output_tokens: None,
parallel_tool_calls: None,
tool_choice: None,
tools: Vec::new(),
prompt_cache_key: None,
reasoning: None,
store: None,
service_tier: None,
}
}
#[test]
fn stream_response_times_out_before_headers() {
futures::executor::block_on(async {
let client = TestHttpClient::new(|_| {
future::pending::<anyhow::Result<HttpResponse<AsyncBody>>>().boxed()
});
let result = stream_response_with_options(
&client,
"Test Provider",
"https://api.test/v1",
"test-key",
test_request(),
Vec::new(),
StreamResponseOptions::response_header_timeout(
Duration::from_secs(10),
future::ready(()),
),
)
.await;
assert!(matches!(
result,
Err(RequestError::ResponseHeaderTimeout {
provider,
timeout
}) if provider == "Test Provider" && timeout == Duration::from_secs(10)
));
});
}
#[test]
fn stream_response_does_not_timeout_after_headers_arrive() {
futures::executor::block_on(async {
let body = r#"data: {"type":"response.completed","response":{"id":"resp_1","status":"completed"}}"#;
let (delayed_body, delayed_body_handle) =
DelayedBody::new(format!("{body}\n\n").into_bytes());
let delayed_body = Mutex::new(Some(delayed_body));
let client = TestHttpClient::new(move |_| {
let delayed_body = delayed_body
.lock()
.expect("lock poisoned")
.take()
.expect("test sends only one request");
async {
Ok(HttpResponse::builder()
.status(200)
.body(AsyncBody::from_reader(delayed_body))?)
}
.boxed()
});
let (timeout_tx, timeout_rx) = futures::channel::oneshot::channel::<()>();
let mut stream = stream_response_with_options(
&client,
"Test Provider",
"https://api.test/v1",
"test-key",
test_request(),
Vec::new(),
StreamResponseOptions::response_header_timeout(
Duration::from_secs(10),
async move {
assert!(
timeout_rx.await.is_ok(),
"timer should be dropped after headers arrive"
);
},
),
)
.await
.expect("headers should arrive before timeout");
assert!(
timeout_tx.send(()).is_err(),
"timeout future should be dropped after headers arrive"
);
assert!(
stream.next().now_or_never().is_none(),
"stream should wait for delayed body bytes"
);
delayed_body_handle.release();
let event = stream
.next()
.await
.expect("stream should produce an event")
.expect("event should parse");
assert!(matches!(event, StreamEvent::Completed { .. }));
});
}
}
#[derive(Serialize, Debug)]
pub struct Request {
pub model: String,
@ -695,27 +440,6 @@ pub async fn stream_response(
api_key: &str,
request: Request,
extra_headers: Vec<(String, String)>,
) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
stream_response_with_options(
client,
provider_name,
api_url,
api_key,
request,
extra_headers,
StreamResponseOptions::default(),
)
.await
}
pub async fn stream_response_with_options(
client: &dyn HttpClient,
provider_name: &str,
api_url: &str,
api_key: &str,
request: Request,
extra_headers: Vec<(String, String)>,
options: StreamResponseOptions,
) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
let uri = format!("{api_url}/responses");
let mut request_builder = HttpRequest::builder()
@ -734,24 +458,7 @@ pub async fn stream_response_with_options(
))
.map_err(|e| RequestError::Other(e.into()))?;
let mut response = if let Some((timeout, timer)) = options.response_header_timeout {
let send_request = client.send(request).fuse();
let timer = timer.fuse();
futures::pin_mut!(send_request);
futures::pin_mut!(timer);
futures::select! {
response = send_request => response?,
() = timer => {
return Err(RequestError::ResponseHeaderTimeout {
provider: provider_name.to_owned(),
timeout,
});
}
}
} else {
client.send(request).await?
};
let mut response = client.send(request).await?;
if response.status().is_success() {
if is_streaming {
let reader = BufReader::new(response.into_body());