Add ChatGPT subscription provider via OAuth 2.0 PKCE (#53166) (cherry-pick to preview) (#56807)
Some checks are pending
run_tests / orchestrate (push) Waiting to run
run_tests / check_style (push) Waiting to run
run_tests / clippy_windows (push) Blocked by required conditions
run_tests / clippy_linux (push) Blocked by required conditions
run_tests / clippy_mac (push) Blocked by required conditions
run_tests / clippy_mac_x86_64 (push) Blocked by required conditions
run_tests / run_tests_windows (push) Blocked by required conditions
run_tests / run_tests_linux (push) Blocked by required conditions
run_tests / run_tests_mac (push) Blocked by required conditions
run_tests / doctests (push) Blocked by required conditions
run_tests / check_workspace_binaries (push) Blocked by required conditions
run_tests / build_visual_tests_binary (push) Blocked by required conditions
run_tests / check_wasm (push) Blocked by required conditions
run_tests / check_dependencies (push) Blocked by required conditions
run_tests / check_docs (push) Blocked by required conditions
run_tests / check_licenses (push) Blocked by required conditions
run_tests / check_scripts (push) Blocked by required conditions
run_tests / check_postgres_and_protobuf_migrations (push) Blocked by required conditions
run_tests / extension_tests (push) Blocked by required conditions
run_tests / tests_pass (push) Blocked by required conditions

Cherry-pick of #53166 to preview

----
Adds a new language model provider that lets users authenticate with
their ChatGPT Plus/Pro subscription and use OpenAI models
(codex-mini-latest, o4-mini, o3) directly in the Zed agent — without
needing a separate API key.

## How it works

1. **OAuth 2.0 + PKCE sign-in**: Uses OpenAI's official Codex CLI client
ID to run an authorization code flow. A local HTTP server on
`127.0.0.1:1455` captures the callback, exchanges the code for tokens,
and stores them in the system keychain.

2. **Token refresh**: Access tokens are automatically refreshed when
they're within 5 minutes of expiry, using the stored refresh token.

3. **Responses API**: Requests go to
`https://chatgpt.com/backend-api/codex/responses` using the existing
`open_ai::responses` client (Responses API format, not Chat Completions
which was deprecated for this endpoint in Feb 2026).

4. **Required headers**: `originator: zed`, `OpenAI-Beta:
responses=experimental`, `ChatGPT-Account-Id` (extracted from JWT),
`store: false` in the body.

## Files changed

- `crates/open_ai/src/responses.rs`: Add `store: Option<bool>` field to
`Request`; add `extra_headers` param to `stream_response` for
per-provider header injection
- `crates/language_models/src/provider/openai_subscribed.rs`: New
provider (sign-in UI, OAuth flow, token storage/refresh, model list)
- `crates/language_models/src/provider/open_ai.rs`,
`open_ai_compatible.rs`, `opencode.rs`: Pass `vec![]` for new
`extra_headers` param
- `crates/language_models/src/language_models.rs`: Register the new
provider
- `crates/language_models/Cargo.toml`: Add `rand` and `sha2` deps for
PKCE

## Open questions / known gaps

- [ ] **Terms of service**: Usage appears to be within OpenAI's ToS
(interactive use via their official CLI client ID), but needs legal
sign-off before shipping
- [ ] **Redirect URI**: Currently `http://localhost:1455/auth/callback`
— may need to match exactly what OpenAI's Codex CLI uses
- [ ] **UI polish**: The sign-in card is functional but minimal; needs
design review
- [ ] **Error messages**: OAuth error responses from the callback URL
aren't surfaced to the user yet
- [ ] **`o3` availability**: o3 may require a higher subscription tier;
consider gating it

## Testing

Sign-in flow was designed to match the Copilot Chat provider pattern.
Manual testing against the live OAuth endpoint is needed.

Release Notes:

- Added ChatGPT subscription provider, allowing users to use their
ChatGPT Plus/Pro subscription with the Zed agent

---------

Co-authored-by: Zed Zippy
<234243425+zed-zippy[bot]@users.noreply.github.com>
Co-authored-by: Richard Feldman <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>

Co-authored-by: morgankrey <morgan@zed.dev>
Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com>
Co-authored-by: Richard Feldman <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
zed-zippy[bot] 2026-05-14 22:03:36 +00:00 committed by GitHub
parent ba3d506bd6
commit 28529b423e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2058 additions and 159 deletions

19
Cargo.lock generated
View file

@ -3597,6 +3597,7 @@ dependencies = [
"http_client",
"log",
"net",
"oauth_callback_server",
"parking_lot",
"pollster 0.4.0",
"postage",
@ -3608,7 +3609,6 @@ dependencies = [
"sha2",
"slotmap",
"tempfile",
"tiny_http",
"url",
"util",
]
@ -9683,20 +9683,26 @@ dependencies = [
"log",
"menu",
"mistral",
"oauth_callback_server",
"ollama",
"open_ai",
"open_router",
"opencode",
"parking_lot",
"pretty_assertions",
"rand 0.9.4",
"release_channel",
"schemars 1.0.4",
"serde",
"serde_json",
"settings",
"sha2",
"smol",
"strum 0.27.2",
"tokio",
"ui",
"ui_input",
"url",
"util",
"x_ai",
]
@ -11509,6 +11515,17 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "oauth_callback_server"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.32",
"log",
"tiny_http",
"url",
]
[[package]]
name = "objc"
version = "0.2.7"

View file

@ -140,6 +140,7 @@ members = [
"crates/net",
"crates/node_runtime",
"crates/notifications",
"crates/oauth_callback_server",
"crates/ollama",
"crates/onboarding",
"crates/opencode",
@ -399,6 +400,7 @@ nc = { path = "crates/nc" }
net = { path = "crates/net" }
node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
oauth_callback_server = { path = "crates/oauth_callback_server" }
ollama = { path = "crates/ollama" }
onboarding = { path = "crates/onboarding" }
opencode = { path = "crates/opencode" }

View file

@ -26,6 +26,7 @@ gpui.workspace = true
http_client = { workspace = true, features = ["test-support"] }
log.workspace = true
net.workspace = true
oauth_callback_server.workspace = true
parking_lot.workspace = true
rand.workspace = true
postage.workspace = true
@ -36,7 +37,6 @@ settings.workspace = true
sha2.workspace = true
slotmap.workspace = true
tempfile.workspace = true
tiny_http.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true

View file

@ -20,18 +20,18 @@ use anyhow::{Context as _, Result, anyhow, bail};
use async_trait::async_trait;
use base64::Engine as _;
use futures::AsyncReadExt as _;
use futures::FutureExt as _;
use futures::channel::mpsc;
use futures::future::BoxFuture;
use http_client::{AsyncBody, HttpClient, Request};
use parking_lot::Mutex as SyncMutex;
use rand::Rng as _;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use url::Url;
use util::ResultExt as _;
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
@ -992,58 +992,14 @@ impl OAuthCallback {
/// Parse the query string from a callback URL like
/// `http://127.0.0.1:<port>/callback?code=...&state=...`.
pub fn parse_query(query: &str) -> Result<Self> {
let mut code: Option<String> = None;
let mut state: Option<String> = None;
let mut error: Option<String> = None;
let mut error_description: Option<String> = None;
for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
match key.as_ref() {
"code" => {
if !value.is_empty() {
code = Some(value.into_owned());
}
}
"state" => {
if !value.is_empty() {
state = Some(value.into_owned());
}
}
"error" => {
if !value.is_empty() {
error = Some(value.into_owned());
}
}
"error_description" => {
if !value.is_empty() {
error_description = Some(value.into_owned());
}
}
_ => {}
}
}
// Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
// checking for missing code/state.
if let Some(error_code) = error {
bail!(
"OAuth authorization failed: {} ({})",
error_code,
error_description.as_deref().unwrap_or("no description")
);
}
let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
Ok(Self { code, state })
let params = oauth_callback_server::OAuthCallbackParams::parse_query(query)?;
Ok(Self {
code: params.code,
state: params.state,
})
}
}
/// How long to wait for the browser to complete the OAuth flow before giving
/// up and releasing the loopback port.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Start a loopback HTTP server to receive the OAuth authorization callback.
///
/// Binds to an ephemeral loopback port for each flow.
@ -1056,104 +1012,24 @@ const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// contains `code` and `state` query parameters, responds with a minimal
/// HTML page telling the user they can close the tab, and shuts down.
///
/// The callback server shuts down when the returned oneshot receiver is dropped
/// (e.g. because the authentication task was cancelled), or after a timeout
/// ([CALLBACK_TIMEOUT]).
pub async fn start_callback_server() -> Result<(
String,
futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
)> {
let server = tiny_http::Server::http("127.0.0.1:0")
.map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
let port = server
.server_addr()
.to_ip()
.context("server not bound to a TCP address")?
.port();
let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
let (tx, rx) = futures::channel::oneshot::channel();
// `tiny_http` is blocking, so we run it on a background thread.
// The `recv_timeout` loop lets us check for cancellation (the receiver
// being dropped) and enforce an overall timeout.
std::thread::spawn(move || {
let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
loop {
if tx.is_canceled() {
return;
}
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
if remaining.is_zero() {
return;
}
let timeout = remaining.min(Duration::from_millis(500));
let Some(request) = (match server.recv_timeout(timeout) {
Ok(req) => req,
Err(_) => {
let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
return;
}
}) else {
// Timeout with no request — loop back and check cancellation.
continue;
};
let result = handle_callback_request(&request);
let (status_code, body) = match &result {
Ok(_) => (
200,
"<html><body><h1>Authorization successful</h1>\
<p>You can close this tab and return to Zed.</p></body></html>",
),
Err(err) => {
log::error!("OAuth callback error: {}", err);
(
400,
"<html><body><h1>Authorization failed</h1>\
<p>Something went wrong. Please try again from Zed.</p></body></html>",
)
}
};
let response = tiny_http::Response::from_string(body)
.with_status_code(status_code)
.with_header(
tiny_http::Header::from_str("Content-Type: text/html")
.expect("failed to construct response header"),
)
.with_header(
tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
.expect("failed to construct response header"),
);
request.respond(response).log_err();
let _ = tx.send(result);
return;
/// The callback server shuts down when the returned future is dropped (e.g.
/// because the authentication task was cancelled), or after a timeout.
pub fn start_callback_server() -> Result<(String, BoxFuture<'static, Result<OAuthCallback>>)> {
let (redirect_uri, rx) = oauth_callback_server::start_oauth_callback_server()?;
let future = async move {
match rx.await {
Ok(Ok(params)) => Ok(OAuthCallback {
code: params.code,
state: params.state,
}),
Ok(Err(e)) => Err(e),
Err(_) => Err(anyhow!(
"OAuth callback server was shut down before receiving a response"
)),
}
});
Ok((redirect_uri, rx))
}
/// Extract the `code` and `state` query parameters from an OAuth callback
/// request to `/callback`.
fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
let url = Url::parse(&format!("http://localhost{}", request.url()))
.context("malformed callback request URL")?;
if url.path() != "/callback" {
bail!("unexpected path in OAuth callback: {}", url.path());
}
let query = url
.query()
.ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
OAuthCallback::parse_query(query)
.boxed();
Ok((redirect_uri, future))
}
// -- JSON fetch helper -------------------------------------------------------

View file

@ -47,19 +47,24 @@ lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
oauth_callback_server.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
opencode = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
rand.workspace = true
release_channel.workspace = true
schemars.workspace = true
sha2.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strum.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
ui_input.workspace = true
url.workspace = true
util.workspace = true
x_ai = { workspace = true, features = ["schemars"] }
@ -71,5 +76,6 @@ feature_flags.workspace = true
gpui = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
parking_lot.workspace = true
pretty_assertions.workspace = true
settings = { workspace = true, features = ["test-support"] }

View file

@ -27,6 +27,7 @@ use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
use crate::provider::openai_subscribed::OpenAiSubscribedProvider;
use crate::provider::opencode::OpenCodeLanguageModelProvider;
use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
@ -324,10 +325,18 @@ fn register_language_model_providers(
registry.register_provider(
Arc::new(OpenCodeLanguageModelProvider::new(
client.http_client(),
credentials_provider,
credentials_provider.clone(),
cx,
)),
cx,
);
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
registry.register_provider(
Arc::new(OpenAiSubscribedProvider::new(
client.http_client(),
credentials_provider,
cx,
)),
cx,
);
}

View file

@ -10,6 +10,7 @@ pub mod ollama;
pub mod open_ai;
pub mod open_ai_compatible;
pub mod open_router;
pub mod openai_subscribed;
pub mod opencode;
pub mod vercel_ai_gateway;

View file

@ -383,6 +383,7 @@ impl OpenAiLanguageModel {
&api_url,
&api_key,
request,
vec![],
);
let response = request.await?;
Ok(response)

View file

@ -289,6 +289,7 @@ impl OpenAiCompatibleLanguageModel {
&api_url,
&api_key,
request,
vec![],
);
let response = request.await?;
Ok(response)

File diff suppressed because it is too large Load diff

View file

@ -479,6 +479,7 @@ impl OpenCodeLanguageModel {
&api_url,
&api_key,
request,
vec![],
);
let response = request.await?;
Ok(response)

View file

@ -0,0 +1,23 @@
[package]
name = "oauth_callback_server"
version = "0.1.0"
edition.workspace = true
publish = false
license = "Apache-2.0"
description = "Loopback OAuth 2.0 callback server and shared HTML response page for Zed sign-in flows"
[lints]
workspace = true
[lib]
path = "src/oauth_callback_server.rs"
doctest = false
[dependencies]
anyhow.workspace = true
[target.'cfg(not(target_family = "wasm"))'.dependencies]
futures.workspace = true
log.workspace = true
tiny_http.workspace = true
url.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-APACHE

View file

@ -0,0 +1,493 @@
//! Loopback OAuth 2.0 callback server and shared HTML response page.
//!
//! Used by Zed's OAuth-based sign-in flows (e.g. MCP servers, ChatGPT
//! Subscription) to receive the authorization code redirect from the user's
//! browser. The HTML response page rendered to the browser is kept alongside
//! the server so all OAuth callback presentation lives in one place.
/// Generate a styled HTML page for OAuth callback responses.
///
/// Returns a complete HTML document (no HTTP headers) with a centered card
/// layout styled to match Zed's dark theme. The `title` is rendered as a
/// heading and `message` as body text below it.
///
/// When `is_error` is true, a red X icon is shown instead of the green
/// checkmark.
pub fn oauth_callback_page(title: &str, message: &str, is_error: bool) -> String {
let title = html_escape(title);
let message = html_escape(message);
let (icon_bg, icon_svg) = if is_error {
(
"#f38ba8",
r#"<svg viewBox="0 0 24 24"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>"#,
)
} else {
(
"#a6e3a1",
r#"<svg viewBox="0 0 24 24"><polyline points="20 6 9 17 4 12"/></svg>"#,
)
};
format!(
r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>{title} Zed</title>
<style>
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: #1e1e2e;
color: #cdd6f4;
display: flex;
align-items: center;
justify-content: center;
min-height: 100vh;
padding: 1rem;
}}
.card {{
background: #313244;
border-radius: 12px;
padding: 2.5rem;
max-width: 420px;
width: 100%;
text-align: center;
box-shadow: 0 4px 24px rgba(0, 0, 0, 0.3);
}}
.icon {{
width: 48px;
height: 48px;
margin: 0 auto 1.5rem;
background: {icon_bg};
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
}}
.icon svg {{
width: 24px;
height: 24px;
stroke: #1e1e2e;
stroke-width: 3;
fill: none;
}}
h1 {{
font-size: 1.25rem;
font-weight: 600;
margin-bottom: 0.75rem;
color: #cdd6f4;
}}
p {{
font-size: 0.925rem;
line-height: 1.5;
color: #a6adc8;
}}
.brand {{
margin-top: 1.5rem;
font-size: 0.8rem;
color: #585b70;
letter-spacing: 0.05em;
}}
</style>
</head>
<body>
<div class="card">
<div class="icon">
{icon_svg}
</div>
<h1>{title}</h1>
<p>{message}</p>
<div class="brand">Zed</div>
</div>
</body>
</html>"#,
title = title,
message = message,
icon_bg = icon_bg,
icon_svg = icon_svg,
)
}
fn html_escape(input: &str) -> String {
let mut output = String::with_capacity(input.len());
for ch in input.chars() {
match ch {
'&' => output.push_str("&amp;"),
'<' => output.push_str("&lt;"),
'>' => output.push_str("&gt;"),
'"' => output.push_str("&quot;"),
'\'' => output.push_str("&#x27;"),
_ => output.push(ch),
}
}
output
}
#[cfg(not(target_family = "wasm"))]
mod server {
use super::oauth_callback_page;
use anyhow::{Context as _, Result, anyhow};
use std::str::FromStr;
use std::time::Duration;
use url::Url;
/// Parsed OAuth callback parameters from the authorization server redirect.
pub struct OAuthCallbackParams {
pub code: String,
pub state: String,
}
/// Configuration for the loopback OAuth callback server.
///
/// OAuth servers compare `redirect_uri` against a per-client allow-list using
/// exact string matching (RFC 6749 §3.1.2), so the `host`, `preferred_port`,
/// and `path` here must match what's registered for the OAuth client_id.
#[derive(Clone, Copy)]
pub struct OAuthCallbackServerConfig {
/// Host portion of the redirect URI (typically `127.0.0.1` or `localhost`).
pub host: &'static str,
/// Preferred port. Use `0` for an OS-assigned ephemeral port.
pub preferred_port: u16,
/// Optional fallback port if `preferred_port` is unavailable. Only used
/// when `preferred_port` is non-zero.
pub fallback_port: Option<u16>,
/// Callback path on the redirect URI (e.g. `/callback`, `/auth/callback`).
pub path: &'static str,
}
impl Default for OAuthCallbackServerConfig {
fn default() -> Self {
Self {
host: "127.0.0.1",
preferred_port: 0,
fallback_port: None,
path: "/callback",
}
}
}
impl OAuthCallbackParams {
/// Parse the query string from a callback URL like
/// `http://127.0.0.1:<port>/callback?code=...&state=...`.
pub fn parse_query(query: &str) -> Result<Self> {
let mut code: Option<String> = None;
let mut state: Option<String> = None;
let mut error: Option<String> = None;
let mut error_description: Option<String> = None;
for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
match key.as_ref() {
"code" => {
if !value.is_empty() {
code = Some(value.into_owned());
}
}
"state" => {
if !value.is_empty() {
state = Some(value.into_owned());
}
}
"error" => {
if !value.is_empty() {
error = Some(value.into_owned());
}
}
"error_description" => {
if !value.is_empty() {
error_description = Some(value.into_owned());
}
}
_ => {}
}
}
if let Some(error_code) = error {
anyhow::bail!(
"OAuth authorization failed: {} ({})",
error_code,
error_description.as_deref().unwrap_or("no description")
);
}
let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
let state =
state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
Ok(Self { code, state })
}
}
/// How long to wait for the browser to complete the OAuth flow before giving
/// up and releasing the loopback port.
const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Start a loopback HTTP server to receive the OAuth authorization callback.
///
/// Binds to an ephemeral loopback port. Returns `(redirect_uri, callback_future)`.
/// The caller should use the redirect URI in the authorization request, open
/// the browser, then await the future to receive the callback.
pub fn start_oauth_callback_server() -> Result<(
String,
futures::channel::oneshot::Receiver<Result<OAuthCallbackParams>>,
)> {
start_oauth_callback_server_with_config(OAuthCallbackServerConfig::default())
}
/// Start a loopback HTTP server with custom host/port/path.
///
/// Use this when the OAuth client requires a specific redirect URI that the
/// default ephemeral-port `http://127.0.0.1:<port>/callback` doesn't match.
pub fn start_oauth_callback_server_with_config(
config: OAuthCallbackServerConfig,
) -> Result<(
String,
futures::channel::oneshot::Receiver<Result<OAuthCallbackParams>>,
)> {
let server = bind_callback_server(&config)?;
let port = server
.server_addr()
.to_ip()
.ok_or_else(|| anyhow!("server not bound to a TCP address"))?
.port();
let redirect_uri = format!("http://{}:{}{}", config.host, port, config.path);
let expected_path = config.path;
let (tx, rx) = futures::channel::oneshot::channel();
std::thread::spawn(move || {
let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT;
loop {
if tx.is_canceled() {
return;
}
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
if remaining.is_zero() {
return;
}
let timeout = remaining.min(Duration::from_millis(500));
let Some(request) = (match server.recv_timeout(timeout) {
Ok(req) => req,
Err(_) => {
let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
return;
}
}) else {
continue;
};
let raw_url = request.url().to_string();
let raw_path = raw_url.split('?').next().unwrap_or(&raw_url);
if raw_path == CANCEL_PATH {
let response = tiny_http::Response::from_string("Cancelled")
.with_status_code(200)
.with_header(
tiny_http::Header::from_str("Content-Type: text/plain")
.expect("failed to construct response header"),
)
.with_header(
tiny_http::Header::from_str("Connection: close")
.expect("failed to construct response header"),
);
if let Err(err) = request.respond(response) {
log::error!("Failed to send OAuth cancel response: {}", err);
}
let _ = tx.send(Err(anyhow!(
"OAuth callback server was cancelled by another sign-in attempt"
)));
return;
}
let result = handle_oauth_callback_request(&request, expected_path);
let (status_code, body) = match &result {
Ok(_) => (
200,
oauth_callback_page(
"Authorization Successful",
"You can close this tab and return to Zed.",
false,
),
),
Err(err) => {
log::error!("OAuth callback error: {}", err);
(
400,
oauth_callback_page(
"Authorization Failed",
"Something went wrong. Please try again from Zed.",
true,
),
)
}
};
let response = tiny_http::Response::from_string(body)
.with_status_code(status_code)
.with_header(
tiny_http::Header::from_str("Content-Type: text/html")
.expect("failed to construct response header"),
)
.with_header(
tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
.expect("failed to construct response header"),
);
if let Err(err) = request.respond(response) {
log::error!("Failed to send OAuth callback response: {}", err);
}
let _ = tx.send(result);
return;
}
});
Ok((redirect_uri, rx))
}
fn handle_oauth_callback_request(
request: &tiny_http::Request,
expected_path: &str,
) -> Result<OAuthCallbackParams> {
let url = Url::parse(&format!("http://localhost{}", request.url()))
.context("malformed callback request URL")?;
if url.path() != expected_path {
anyhow::bail!("unexpected path in OAuth callback: {}", url.path());
}
let query = url
.query()
.ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
OAuthCallbackParams::parse_query(query)
}
/// Callback path reserved for evicting a previously-running OAuth callback
/// server bound to the same port. Always handled, regardless of `config.path`.
const CANCEL_PATH: &str = "/cancel";
const BIND_MAX_ATTEMPTS: u32 = 10;
const BIND_RETRY_DELAY: Duration = Duration::from_millis(200);
const CANCEL_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
fn bind_callback_server(config: &OAuthCallbackServerConfig) -> Result<tiny_http::Server> {
// Ephemeral ports always succeed; skip the cancel-retry dance entirely.
if config.preferred_port == 0 {
let addr = format!("{}:0", config.host);
return tiny_http::Server::http(&addr).map_err(|err| {
anyhow!(err).context(format!(
"Failed to bind loopback listener for OAuth callback on {addr}"
))
});
}
match try_bind_with_cancel(config.host, config.preferred_port) {
Ok(server) => Ok(server),
Err(primary_err) => {
let Some(fallback_port) = config.fallback_port else {
return Err(primary_err.context(format!(
"Failed to bind loopback listener for OAuth callback on {}:{}",
config.host, config.preferred_port,
)));
};
log::warn!(
"OAuth callback port {}:{} unavailable; falling back to port {}",
config.host,
config.preferred_port,
fallback_port,
);
try_bind_with_cancel(config.host, fallback_port).map_err(|fallback_err| {
fallback_err.context(format!(
"Failed to bind loopback listener for OAuth callback on {}:{} or {}:{}",
config.host, config.preferred_port, config.host, fallback_port,
))
})
}
}
}
/// Attempts to bind to a fixed `host:port`. On `AddrInUse`, sends a single
/// `GET /cancel` to the existing listener (to evict a previous OAuth flow
/// from this or a compatible client) and retries.
fn try_bind_with_cancel(host: &'static str, port: u16) -> Result<tiny_http::Server> {
let addr = format!("{host}:{port}");
let mut cancel_attempted = false;
let mut last_err: Option<anyhow::Error> = None;
for _ in 0..BIND_MAX_ATTEMPTS {
match tiny_http::Server::http(&addr) {
Ok(server) => return Ok(server),
Err(err) => {
let is_addr_in_use = err
.downcast_ref::<std::io::Error>()
.map(|io_err| io_err.kind() == std::io::ErrorKind::AddrInUse)
.unwrap_or(false);
if !is_addr_in_use {
return Err(anyhow!(err).context(format!(
"Failed to bind loopback listener for OAuth callback on {addr}"
)));
}
if !cancel_attempted {
cancel_attempted = true;
if let Err(cancel_err) = send_cancel_request(host, port) {
log::warn!(
"Failed to cancel previous OAuth callback server on {addr}: {cancel_err}"
);
}
}
last_err = Some(anyhow!(err));
std::thread::sleep(BIND_RETRY_DELAY);
}
}
}
Err(last_err
.unwrap_or_else(|| anyhow!("unknown bind error"))
.context(format!(
"OAuth callback port {addr} remained in use after {BIND_MAX_ATTEMPTS} attempts"
)))
}
/// Sends `GET /cancel` to a listener on `host:port`, asking it to shut down.
///
/// Best-effort: errors here are surfaced to the caller for logging but do
/// not block the subsequent rebind attempt.
fn send_cancel_request(host: &str, port: u16) -> std::io::Result<()> {
use std::io::{Read as _, Write as _};
use std::net::{TcpStream, ToSocketAddrs as _};
let addr = format!("{host}:{port}")
.to_socket_addrs()?
.next()
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("could not resolve {host}:{port}"),
)
})?;
let mut stream = TcpStream::connect_timeout(&addr, CANCEL_REQUEST_TIMEOUT)?;
stream.set_read_timeout(Some(CANCEL_REQUEST_TIMEOUT))?;
stream.set_write_timeout(Some(CANCEL_REQUEST_TIMEOUT))?;
stream.write_all(b"GET /cancel HTTP/1.1\r\n")?;
stream.write_all(format!("Host: {host}:{port}\r\n").as_bytes())?;
stream.write_all(b"Connection: close\r\n\r\n")?;
// Drain the response so the server can close cleanly. We don't care
// about the body; errors here are harmless.
let mut buf = [0u8; 64];
let _ = stream.read(&mut buf);
Ok(())
}
}
#[cfg(not(target_family = "wasm"))]
pub use server::{
OAuthCallbackParams, OAuthCallbackServerConfig, start_oauth_callback_server,
start_oauth_callback_server_with_config,
};

View file

@ -259,8 +259,9 @@ pub fn into_open_ai_response(
ResponseRequest {
model: model_id.into(),
instructions: None,
input: input_items,
store: false,
store: Some(false),
include,
stream,
temperature,

View file

@ -9,9 +9,10 @@ use crate::{ReasoningEffort, RequestError, Role, ToolChoice};
#[derive(Serialize, Debug)]
pub struct Request {
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub input: Vec<ResponseInputItem>,
pub store: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub include: Vec<ResponseIncludable>,
#[serde(default)]
@ -32,6 +33,8 @@ pub struct Request {
pub prompt_cache_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<ReasoningConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub store: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
@ -395,13 +398,17 @@ pub async fn stream_response(
api_url: &str,
api_key: &str,
request: Request,
extra_headers: Vec<(String, String)>,
) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
let uri = format!("{api_url}/responses");
let request_builder = HttpRequest::builder()
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key.trim()));
for (name, value) in &extra_headers {
request_builder = request_builder.header(name.as_str(), value.as_str());
}
let is_streaming = request.stream;
let request = request_builder

View file

@ -1063,9 +1063,8 @@ impl ContextServerStore {
// Start a loopback HTTP server on an ephemeral port. The redirect URI
// includes this port so the browser sends the callback directly to our
// process.
let (redirect_uri, callback_rx) = oauth::start_callback_server()
.await
.context("Failed to start OAuth callback server")?;
let (redirect_uri, callback_rx) =
oauth::start_callback_server().context("Failed to start OAuth callback server")?;
let http_client = cx.update(|cx| cx.http_client());
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
@ -1093,9 +1092,6 @@ impl ContextServerStore {
let callback = callback_rx
.await
.map_err(|_| {
anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
})?
.context("OAuth callback server received an invalid request")?;
if callback.state != state_param {