mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
Cherry-pick of #53166 to stable ---- Adds a new language model provider that lets users authenticate with their ChatGPT Plus/Pro subscription and use OpenAI models (codex-mini-latest, o4-mini, o3) directly in the Zed agent — without needing a separate API key. ## How it works 1. **OAuth 2.0 + PKCE sign-in**: Uses OpenAI's official Codex CLI client ID to run an authorization code flow. A local HTTP server on `127.0.0.1:1455` captures the callback, exchanges the code for tokens, and stores them in the system keychain. 2. **Token refresh**: Access tokens are automatically refreshed when they're within 5 minutes of expiry, using the stored refresh token. 3. **Responses API**: Requests go to `https://chatgpt.com/backend-api/codex/responses` using the existing `open_ai::responses` client (Responses API format, not Chat Completions which was deprecated for this endpoint in Feb 2026). 4. **Required headers**: `originator: zed`, `OpenAI-Beta: responses=experimental`, `ChatGPT-Account-Id` (extracted from JWT), `store: false` in the body. ## Files changed - `crates/open_ai/src/responses.rs`: Add `store: Option<bool>` field to `Request`; add `extra_headers` param to `stream_response` for per-provider header injection - `crates/language_models/src/provider/openai_subscribed.rs`: New provider (sign-in UI, OAuth flow, token storage/refresh, model list) - `crates/language_models/src/provider/open_ai.rs`, `open_ai_compatible.rs`, `opencode.rs`: Pass `vec![]` for new `extra_headers` param - `crates/language_models/src/language_models.rs`: Register the new provider - `crates/language_models/Cargo.toml`: Add `rand` and `sha2` deps for PKCE ## Open questions / known gaps - [ ] **Terms of service**: Usage appears to be within OpenAI's ToS (interactive use via their official CLI client ID), but needs legal sign-off before shipping - [ ] **Redirect URI**: Currently `http://localhost:1455/auth/callback` — may need to match exactly what OpenAI's Codex CLI uses - [ ] **UI polish**: The sign-in card is functional but minimal; needs design review - [ ] **Error messages**: OAuth error responses from the callback URL aren't surfaced to the user yet - [ ] **`o3` availability**: o3 may require a higher subscription tier; consider gating it ## Testing Sign-in flow was designed to match the Copilot Chat provider pattern. Manual testing against the live OAuth endpoint is needed. Release Notes: - Added ChatGPT subscription provider, allowing users to use their ChatGPT Plus/Pro subscription with the Zed agent --------- Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com> Co-authored-by: Richard Feldman <richard@zed.dev> Co-authored-by: Richard Feldman <oss@rtfeldman.com> Co-authored-by: Agus Zubiaga <agus@zed.dev> Co-authored-by: morgankrey <morgan@zed.dev> Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com> Co-authored-by: Richard Feldman <richard@zed.dev> Co-authored-by: Richard Feldman <oss@rtfeldman.com> Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
parent
63a20d1170
commit
902b2e3d3e
17 changed files with 2058 additions and 159 deletions
19
Cargo.lock
generated
19
Cargo.lock
generated
|
|
@ -3580,6 +3580,7 @@ dependencies = [
|
|||
"http_client",
|
||||
"log",
|
||||
"net",
|
||||
"oauth_callback_server",
|
||||
"parking_lot",
|
||||
"pollster 0.4.0",
|
||||
"postage",
|
||||
|
|
@ -3591,7 +3592,6 @@ dependencies = [
|
|||
"sha2",
|
||||
"slotmap",
|
||||
"tempfile",
|
||||
"tiny_http",
|
||||
"url",
|
||||
"util",
|
||||
]
|
||||
|
|
@ -9675,20 +9675,26 @@ dependencies = [
|
|||
"log",
|
||||
"menu",
|
||||
"mistral",
|
||||
"oauth_callback_server",
|
||||
"ollama",
|
||||
"open_ai",
|
||||
"open_router",
|
||||
"opencode",
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"rand 0.9.4",
|
||||
"release_channel",
|
||||
"schemars 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"sha2",
|
||||
"smol",
|
||||
"strum 0.27.2",
|
||||
"tokio",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"url",
|
||||
"util",
|
||||
"x_ai",
|
||||
]
|
||||
|
|
@ -11500,6 +11506,17 @@ dependencies = [
|
|||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oauth_callback_server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.32",
|
||||
"log",
|
||||
"tiny_http",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc"
|
||||
version = "0.2.7"
|
||||
|
|
|
|||
|
|
@ -139,6 +139,7 @@ members = [
|
|||
"crates/net",
|
||||
"crates/node_runtime",
|
||||
"crates/notifications",
|
||||
"crates/oauth_callback_server",
|
||||
"crates/ollama",
|
||||
"crates/onboarding",
|
||||
"crates/opencode",
|
||||
|
|
@ -397,6 +398,7 @@ nc = { path = "crates/nc" }
|
|||
net = { path = "crates/net" }
|
||||
node_runtime = { path = "crates/node_runtime" }
|
||||
notifications = { path = "crates/notifications" }
|
||||
oauth_callback_server = { path = "crates/oauth_callback_server" }
|
||||
ollama = { path = "crates/ollama" }
|
||||
onboarding = { path = "crates/onboarding" }
|
||||
opencode = { path = "crates/opencode" }
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ gpui.workspace = true
|
|||
http_client = { workspace = true, features = ["test-support"] }
|
||||
log.workspace = true
|
||||
net.workspace = true
|
||||
oauth_callback_server.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rand.workspace = true
|
||||
postage.workspace = true
|
||||
|
|
@ -36,7 +37,6 @@ settings.workspace = true
|
|||
sha2.workspace = true
|
||||
slotmap.workspace = true
|
||||
tempfile.workspace = true
|
||||
tiny_http.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
|
||||
|
|
|
|||
|
|
@ -20,18 +20,18 @@ use anyhow::{Context as _, Result, anyhow, bail};
|
|||
use async_trait::async_trait;
|
||||
use base64::Engine as _;
|
||||
use futures::AsyncReadExt as _;
|
||||
use futures::FutureExt as _;
|
||||
use futures::channel::mpsc;
|
||||
use futures::future::BoxFuture;
|
||||
use http_client::{AsyncBody, HttpClient, Request};
|
||||
use parking_lot::Mutex as SyncMutex;
|
||||
use rand::Rng as _;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use url::Url;
|
||||
use util::ResultExt as _;
|
||||
|
||||
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
|
||||
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
|
||||
|
|
@ -992,58 +992,14 @@ impl OAuthCallback {
|
|||
/// Parse the query string from a callback URL like
|
||||
/// `http://127.0.0.1:<port>/callback?code=...&state=...`.
|
||||
pub fn parse_query(query: &str) -> Result<Self> {
|
||||
let mut code: Option<String> = None;
|
||||
let mut state: Option<String> = None;
|
||||
let mut error: Option<String> = None;
|
||||
let mut error_description: Option<String> = None;
|
||||
|
||||
for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
|
||||
match key.as_ref() {
|
||||
"code" => {
|
||||
if !value.is_empty() {
|
||||
code = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
"state" => {
|
||||
if !value.is_empty() {
|
||||
state = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
"error" => {
|
||||
if !value.is_empty() {
|
||||
error = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
"error_description" => {
|
||||
if !value.is_empty() {
|
||||
error_description = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
|
||||
// checking for missing code/state.
|
||||
if let Some(error_code) = error {
|
||||
bail!(
|
||||
"OAuth authorization failed: {} ({})",
|
||||
error_code,
|
||||
error_description.as_deref().unwrap_or("no description")
|
||||
);
|
||||
}
|
||||
|
||||
let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
|
||||
let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
|
||||
|
||||
Ok(Self { code, state })
|
||||
let params = oauth_callback_server::OAuthCallbackParams::parse_query(query)?;
|
||||
Ok(Self {
|
||||
code: params.code,
|
||||
state: params.state,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// How long to wait for the browser to complete the OAuth flow before giving
|
||||
/// up and releasing the loopback port.
|
||||
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
|
||||
|
||||
/// Start a loopback HTTP server to receive the OAuth authorization callback.
|
||||
///
|
||||
/// Binds to an ephemeral loopback port for each flow.
|
||||
|
|
@ -1056,104 +1012,24 @@ const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
|
|||
/// contains `code` and `state` query parameters, responds with a minimal
|
||||
/// HTML page telling the user they can close the tab, and shuts down.
|
||||
///
|
||||
/// The callback server shuts down when the returned oneshot receiver is dropped
|
||||
/// (e.g. because the authentication task was cancelled), or after a timeout
|
||||
/// ([CALLBACK_TIMEOUT]).
|
||||
pub async fn start_callback_server() -> Result<(
|
||||
String,
|
||||
futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
|
||||
)> {
|
||||
let server = tiny_http::Server::http("127.0.0.1:0")
|
||||
.map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
|
||||
let port = server
|
||||
.server_addr()
|
||||
.to_ip()
|
||||
.context("server not bound to a TCP address")?
|
||||
.port();
|
||||
|
||||
let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
|
||||
|
||||
let (tx, rx) = futures::channel::oneshot::channel();
|
||||
|
||||
// `tiny_http` is blocking, so we run it on a background thread.
|
||||
// The `recv_timeout` loop lets us check for cancellation (the receiver
|
||||
// being dropped) and enforce an overall timeout.
|
||||
std::thread::spawn(move || {
|
||||
let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
|
||||
|
||||
loop {
|
||||
if tx.is_canceled() {
|
||||
return;
|
||||
}
|
||||
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
|
||||
if remaining.is_zero() {
|
||||
return;
|
||||
}
|
||||
|
||||
let timeout = remaining.min(Duration::from_millis(500));
|
||||
let Some(request) = (match server.recv_timeout(timeout) {
|
||||
Ok(req) => req,
|
||||
Err(_) => {
|
||||
let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
|
||||
return;
|
||||
}
|
||||
}) else {
|
||||
// Timeout with no request — loop back and check cancellation.
|
||||
continue;
|
||||
};
|
||||
|
||||
let result = handle_callback_request(&request);
|
||||
|
||||
let (status_code, body) = match &result {
|
||||
Ok(_) => (
|
||||
200,
|
||||
"<html><body><h1>Authorization successful</h1>\
|
||||
<p>You can close this tab and return to Zed.</p></body></html>",
|
||||
),
|
||||
Err(err) => {
|
||||
log::error!("OAuth callback error: {}", err);
|
||||
(
|
||||
400,
|
||||
"<html><body><h1>Authorization failed</h1>\
|
||||
<p>Something went wrong. Please try again from Zed.</p></body></html>",
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let response = tiny_http::Response::from_string(body)
|
||||
.with_status_code(status_code)
|
||||
.with_header(
|
||||
tiny_http::Header::from_str("Content-Type: text/html")
|
||||
.expect("failed to construct response header"),
|
||||
)
|
||||
.with_header(
|
||||
tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
|
||||
.expect("failed to construct response header"),
|
||||
);
|
||||
request.respond(response).log_err();
|
||||
|
||||
let _ = tx.send(result);
|
||||
return;
|
||||
/// The callback server shuts down when the returned future is dropped (e.g.
|
||||
/// because the authentication task was cancelled), or after a timeout.
|
||||
pub fn start_callback_server() -> Result<(String, BoxFuture<'static, Result<OAuthCallback>>)> {
|
||||
let (redirect_uri, rx) = oauth_callback_server::start_oauth_callback_server()?;
|
||||
let future = async move {
|
||||
match rx.await {
|
||||
Ok(Ok(params)) => Ok(OAuthCallback {
|
||||
code: params.code,
|
||||
state: params.state,
|
||||
}),
|
||||
Ok(Err(e)) => Err(e),
|
||||
Err(_) => Err(anyhow!(
|
||||
"OAuth callback server was shut down before receiving a response"
|
||||
)),
|
||||
}
|
||||
});
|
||||
|
||||
Ok((redirect_uri, rx))
|
||||
}
|
||||
|
||||
/// Extract the `code` and `state` query parameters from an OAuth callback
|
||||
/// request to `/callback`.
|
||||
fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
|
||||
let url = Url::parse(&format!("http://localhost{}", request.url()))
|
||||
.context("malformed callback request URL")?;
|
||||
|
||||
if url.path() != "/callback" {
|
||||
bail!("unexpected path in OAuth callback: {}", url.path());
|
||||
}
|
||||
|
||||
let query = url
|
||||
.query()
|
||||
.ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
|
||||
OAuthCallback::parse_query(query)
|
||||
.boxed();
|
||||
Ok((redirect_uri, future))
|
||||
}
|
||||
|
||||
// -- JSON fetch helper -------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -47,19 +47,24 @@ lmstudio = { workspace = true, features = ["schemars"] }
|
|||
log.workspace = true
|
||||
menu.workspace = true
|
||||
mistral = { workspace = true, features = ["schemars"] }
|
||||
oauth_callback_server.workspace = true
|
||||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
opencode = { workspace = true, features = ["schemars"] }
|
||||
open_router = { workspace = true, features = ["schemars"] }
|
||||
rand.workspace = true
|
||||
release_channel.workspace = true
|
||||
schemars.workspace = true
|
||||
sha2.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strum.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
x_ai = { workspace = true, features = ["schemars"] }
|
||||
|
||||
|
|
@ -71,5 +76,6 @@ feature_flags.workspace = true
|
|||
gpui = { workspace = true, features = ["test-support"] }
|
||||
http_client = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
parking_lot.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ use crate::provider::ollama::OllamaLanguageModelProvider;
|
|||
use crate::provider::open_ai::OpenAiLanguageModelProvider;
|
||||
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
|
||||
use crate::provider::open_router::OpenRouterLanguageModelProvider;
|
||||
use crate::provider::openai_subscribed::OpenAiSubscribedProvider;
|
||||
use crate::provider::opencode::OpenCodeLanguageModelProvider;
|
||||
use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
|
||||
use crate::provider::x_ai::XAiLanguageModelProvider;
|
||||
|
|
@ -324,10 +325,18 @@ fn register_language_model_providers(
|
|||
registry.register_provider(
|
||||
Arc::new(OpenCodeLanguageModelProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider,
|
||||
credentials_provider.clone(),
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
|
||||
registry.register_provider(
|
||||
Arc::new(OpenAiSubscribedProvider::new(
|
||||
client.http_client(),
|
||||
credentials_provider,
|
||||
cx,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ pub mod ollama;
|
|||
pub mod open_ai;
|
||||
pub mod open_ai_compatible;
|
||||
pub mod open_router;
|
||||
pub mod openai_subscribed;
|
||||
pub mod opencode;
|
||||
|
||||
pub mod vercel_ai_gateway;
|
||||
|
|
|
|||
|
|
@ -383,6 +383,7 @@ impl OpenAiLanguageModel {
|
|||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
vec![],
|
||||
);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
|
|
|
|||
|
|
@ -289,6 +289,7 @@ impl OpenAiCompatibleLanguageModel {
|
|||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
vec![],
|
||||
);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
|
|
|
|||
1464
crates/language_models/src/provider/openai_subscribed.rs
Normal file
1464
crates/language_models/src/provider/openai_subscribed.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -479,6 +479,7 @@ impl OpenCodeLanguageModel {
|
|||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
vec![],
|
||||
);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
|
|
|
|||
23
crates/oauth_callback_server/Cargo.toml
Normal file
23
crates/oauth_callback_server/Cargo.toml
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
[package]
|
||||
name = "oauth_callback_server"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish = false
|
||||
license = "Apache-2.0"
|
||||
description = "Loopback OAuth 2.0 callback server and shared HTML response page for Zed sign-in flows"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/oauth_callback_server.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
|
||||
[target.'cfg(not(target_family = "wasm"))'.dependencies]
|
||||
futures.workspace = true
|
||||
log.workspace = true
|
||||
tiny_http.workspace = true
|
||||
url.workspace = true
|
||||
1
crates/oauth_callback_server/LICENSE-APACHE
Symbolic link
1
crates/oauth_callback_server/LICENSE-APACHE
Symbolic link
|
|
@ -0,0 +1 @@
|
|||
../../LICENSE-APACHE
|
||||
493
crates/oauth_callback_server/src/oauth_callback_server.rs
Normal file
493
crates/oauth_callback_server/src/oauth_callback_server.rs
Normal file
|
|
@ -0,0 +1,493 @@
|
|||
//! Loopback OAuth 2.0 callback server and shared HTML response page.
|
||||
//!
|
||||
//! Used by Zed's OAuth-based sign-in flows (e.g. MCP servers, ChatGPT
|
||||
//! Subscription) to receive the authorization code redirect from the user's
|
||||
//! browser. The HTML response page rendered to the browser is kept alongside
|
||||
//! the server so all OAuth callback presentation lives in one place.
|
||||
|
||||
/// Generate a styled HTML page for OAuth callback responses.
|
||||
///
|
||||
/// Returns a complete HTML document (no HTTP headers) with a centered card
|
||||
/// layout styled to match Zed's dark theme. The `title` is rendered as a
|
||||
/// heading and `message` as body text below it.
|
||||
///
|
||||
/// When `is_error` is true, a red X icon is shown instead of the green
|
||||
/// checkmark.
|
||||
pub fn oauth_callback_page(title: &str, message: &str, is_error: bool) -> String {
|
||||
let title = html_escape(title);
|
||||
let message = html_escape(message);
|
||||
let (icon_bg, icon_svg) = if is_error {
|
||||
(
|
||||
"#f38ba8",
|
||||
r#"<svg viewBox="0 0 24 24"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>"#,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
"#a6e3a1",
|
||||
r#"<svg viewBox="0 0 24 24"><polyline points="20 6 9 17 4 12"/></svg>"#,
|
||||
)
|
||||
};
|
||||
format!(
|
||||
r#"<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>{title} — Zed</title>
|
||||
<style>
|
||||
* {{ margin: 0; padding: 0; box-sizing: border-box; }}
|
||||
body {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: #1e1e2e;
|
||||
color: #cdd6f4;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-height: 100vh;
|
||||
padding: 1rem;
|
||||
}}
|
||||
.card {{
|
||||
background: #313244;
|
||||
border-radius: 12px;
|
||||
padding: 2.5rem;
|
||||
max-width: 420px;
|
||||
width: 100%;
|
||||
text-align: center;
|
||||
box-shadow: 0 4px 24px rgba(0, 0, 0, 0.3);
|
||||
}}
|
||||
.icon {{
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
margin: 0 auto 1.5rem;
|
||||
background: {icon_bg};
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}}
|
||||
.icon svg {{
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
stroke: #1e1e2e;
|
||||
stroke-width: 3;
|
||||
fill: none;
|
||||
}}
|
||||
h1 {{
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
margin-bottom: 0.75rem;
|
||||
color: #cdd6f4;
|
||||
}}
|
||||
p {{
|
||||
font-size: 0.925rem;
|
||||
line-height: 1.5;
|
||||
color: #a6adc8;
|
||||
}}
|
||||
.brand {{
|
||||
margin-top: 1.5rem;
|
||||
font-size: 0.8rem;
|
||||
color: #585b70;
|
||||
letter-spacing: 0.05em;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<div class="icon">
|
||||
{icon_svg}
|
||||
</div>
|
||||
<h1>{title}</h1>
|
||||
<p>{message}</p>
|
||||
<div class="brand">Zed</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>"#,
|
||||
title = title,
|
||||
message = message,
|
||||
icon_bg = icon_bg,
|
||||
icon_svg = icon_svg,
|
||||
)
|
||||
}
|
||||
|
||||
fn html_escape(input: &str) -> String {
|
||||
let mut output = String::with_capacity(input.len());
|
||||
for ch in input.chars() {
|
||||
match ch {
|
||||
'&' => output.push_str("&"),
|
||||
'<' => output.push_str("<"),
|
||||
'>' => output.push_str(">"),
|
||||
'"' => output.push_str("""),
|
||||
'\'' => output.push_str("'"),
|
||||
_ => output.push(ch),
|
||||
}
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
mod server {
|
||||
use super::oauth_callback_page;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use url::Url;
|
||||
|
||||
/// Parsed OAuth callback parameters from the authorization server redirect.
|
||||
pub struct OAuthCallbackParams {
|
||||
pub code: String,
|
||||
pub state: String,
|
||||
}
|
||||
|
||||
/// Configuration for the loopback OAuth callback server.
|
||||
///
|
||||
/// OAuth servers compare `redirect_uri` against a per-client allow-list using
|
||||
/// exact string matching (RFC 6749 §3.1.2), so the `host`, `preferred_port`,
|
||||
/// and `path` here must match what's registered for the OAuth client_id.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct OAuthCallbackServerConfig {
|
||||
/// Host portion of the redirect URI (typically `127.0.0.1` or `localhost`).
|
||||
pub host: &'static str,
|
||||
/// Preferred port. Use `0` for an OS-assigned ephemeral port.
|
||||
pub preferred_port: u16,
|
||||
/// Optional fallback port if `preferred_port` is unavailable. Only used
|
||||
/// when `preferred_port` is non-zero.
|
||||
pub fallback_port: Option<u16>,
|
||||
/// Callback path on the redirect URI (e.g. `/callback`, `/auth/callback`).
|
||||
pub path: &'static str,
|
||||
}
|
||||
|
||||
impl Default for OAuthCallbackServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: "127.0.0.1",
|
||||
preferred_port: 0,
|
||||
fallback_port: None,
|
||||
path: "/callback",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthCallbackParams {
|
||||
/// Parse the query string from a callback URL like
|
||||
/// `http://127.0.0.1:<port>/callback?code=...&state=...`.
|
||||
pub fn parse_query(query: &str) -> Result<Self> {
|
||||
let mut code: Option<String> = None;
|
||||
let mut state: Option<String> = None;
|
||||
let mut error: Option<String> = None;
|
||||
let mut error_description: Option<String> = None;
|
||||
|
||||
for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
|
||||
match key.as_ref() {
|
||||
"code" => {
|
||||
if !value.is_empty() {
|
||||
code = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
"state" => {
|
||||
if !value.is_empty() {
|
||||
state = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
"error" => {
|
||||
if !value.is_empty() {
|
||||
error = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
"error_description" => {
|
||||
if !value.is_empty() {
|
||||
error_description = Some(value.into_owned());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(error_code) = error {
|
||||
anyhow::bail!(
|
||||
"OAuth authorization failed: {} ({})",
|
||||
error_code,
|
||||
error_description.as_deref().unwrap_or("no description")
|
||||
);
|
||||
}
|
||||
|
||||
let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
|
||||
let state =
|
||||
state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
|
||||
|
||||
Ok(Self { code, state })
|
||||
}
|
||||
}
|
||||
|
||||
/// How long to wait for the browser to complete the OAuth flow before giving
|
||||
/// up and releasing the loopback port.
|
||||
const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
|
||||
|
||||
/// Start a loopback HTTP server to receive the OAuth authorization callback.
|
||||
///
|
||||
/// Binds to an ephemeral loopback port. Returns `(redirect_uri, callback_future)`.
|
||||
/// The caller should use the redirect URI in the authorization request, open
|
||||
/// the browser, then await the future to receive the callback.
|
||||
pub fn start_oauth_callback_server() -> Result<(
|
||||
String,
|
||||
futures::channel::oneshot::Receiver<Result<OAuthCallbackParams>>,
|
||||
)> {
|
||||
start_oauth_callback_server_with_config(OAuthCallbackServerConfig::default())
|
||||
}
|
||||
|
||||
/// Start a loopback HTTP server with custom host/port/path.
|
||||
///
|
||||
/// Use this when the OAuth client requires a specific redirect URI that the
|
||||
/// default ephemeral-port `http://127.0.0.1:<port>/callback` doesn't match.
|
||||
pub fn start_oauth_callback_server_with_config(
|
||||
config: OAuthCallbackServerConfig,
|
||||
) -> Result<(
|
||||
String,
|
||||
futures::channel::oneshot::Receiver<Result<OAuthCallbackParams>>,
|
||||
)> {
|
||||
let server = bind_callback_server(&config)?;
|
||||
let port = server
|
||||
.server_addr()
|
||||
.to_ip()
|
||||
.ok_or_else(|| anyhow!("server not bound to a TCP address"))?
|
||||
.port();
|
||||
|
||||
let redirect_uri = format!("http://{}:{}{}", config.host, port, config.path);
|
||||
let expected_path = config.path;
|
||||
|
||||
let (tx, rx) = futures::channel::oneshot::channel();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT;
|
||||
|
||||
loop {
|
||||
if tx.is_canceled() {
|
||||
return;
|
||||
}
|
||||
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
|
||||
if remaining.is_zero() {
|
||||
return;
|
||||
}
|
||||
|
||||
let timeout = remaining.min(Duration::from_millis(500));
|
||||
let Some(request) = (match server.recv_timeout(timeout) {
|
||||
Ok(req) => req,
|
||||
Err(_) => {
|
||||
let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
|
||||
return;
|
||||
}
|
||||
}) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let raw_url = request.url().to_string();
|
||||
let raw_path = raw_url.split('?').next().unwrap_or(&raw_url);
|
||||
if raw_path == CANCEL_PATH {
|
||||
let response = tiny_http::Response::from_string("Cancelled")
|
||||
.with_status_code(200)
|
||||
.with_header(
|
||||
tiny_http::Header::from_str("Content-Type: text/plain")
|
||||
.expect("failed to construct response header"),
|
||||
)
|
||||
.with_header(
|
||||
tiny_http::Header::from_str("Connection: close")
|
||||
.expect("failed to construct response header"),
|
||||
);
|
||||
if let Err(err) = request.respond(response) {
|
||||
log::error!("Failed to send OAuth cancel response: {}", err);
|
||||
}
|
||||
let _ = tx.send(Err(anyhow!(
|
||||
"OAuth callback server was cancelled by another sign-in attempt"
|
||||
)));
|
||||
return;
|
||||
}
|
||||
|
||||
let result = handle_oauth_callback_request(&request, expected_path);
|
||||
|
||||
let (status_code, body) = match &result {
|
||||
Ok(_) => (
|
||||
200,
|
||||
oauth_callback_page(
|
||||
"Authorization Successful",
|
||||
"You can close this tab and return to Zed.",
|
||||
false,
|
||||
),
|
||||
),
|
||||
Err(err) => {
|
||||
log::error!("OAuth callback error: {}", err);
|
||||
(
|
||||
400,
|
||||
oauth_callback_page(
|
||||
"Authorization Failed",
|
||||
"Something went wrong. Please try again from Zed.",
|
||||
true,
|
||||
),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let response = tiny_http::Response::from_string(body)
|
||||
.with_status_code(status_code)
|
||||
.with_header(
|
||||
tiny_http::Header::from_str("Content-Type: text/html")
|
||||
.expect("failed to construct response header"),
|
||||
)
|
||||
.with_header(
|
||||
tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
|
||||
.expect("failed to construct response header"),
|
||||
);
|
||||
if let Err(err) = request.respond(response) {
|
||||
log::error!("Failed to send OAuth callback response: {}", err);
|
||||
}
|
||||
|
||||
let _ = tx.send(result);
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
Ok((redirect_uri, rx))
|
||||
}
|
||||
|
||||
fn handle_oauth_callback_request(
|
||||
request: &tiny_http::Request,
|
||||
expected_path: &str,
|
||||
) -> Result<OAuthCallbackParams> {
|
||||
let url = Url::parse(&format!("http://localhost{}", request.url()))
|
||||
.context("malformed callback request URL")?;
|
||||
|
||||
if url.path() != expected_path {
|
||||
anyhow::bail!("unexpected path in OAuth callback: {}", url.path());
|
||||
}
|
||||
|
||||
let query = url
|
||||
.query()
|
||||
.ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
|
||||
OAuthCallbackParams::parse_query(query)
|
||||
}
|
||||
|
||||
/// Callback path reserved for evicting a previously-running OAuth callback
|
||||
/// server bound to the same port. Always handled, regardless of `config.path`.
|
||||
const CANCEL_PATH: &str = "/cancel";
|
||||
|
||||
const BIND_MAX_ATTEMPTS: u32 = 10;
|
||||
const BIND_RETRY_DELAY: Duration = Duration::from_millis(200);
|
||||
const CANCEL_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
fn bind_callback_server(config: &OAuthCallbackServerConfig) -> Result<tiny_http::Server> {
|
||||
// Ephemeral ports always succeed; skip the cancel-retry dance entirely.
|
||||
if config.preferred_port == 0 {
|
||||
let addr = format!("{}:0", config.host);
|
||||
return tiny_http::Server::http(&addr).map_err(|err| {
|
||||
anyhow!(err).context(format!(
|
||||
"Failed to bind loopback listener for OAuth callback on {addr}"
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
match try_bind_with_cancel(config.host, config.preferred_port) {
|
||||
Ok(server) => Ok(server),
|
||||
Err(primary_err) => {
|
||||
let Some(fallback_port) = config.fallback_port else {
|
||||
return Err(primary_err.context(format!(
|
||||
"Failed to bind loopback listener for OAuth callback on {}:{}",
|
||||
config.host, config.preferred_port,
|
||||
)));
|
||||
};
|
||||
log::warn!(
|
||||
"OAuth callback port {}:{} unavailable; falling back to port {}",
|
||||
config.host,
|
||||
config.preferred_port,
|
||||
fallback_port,
|
||||
);
|
||||
try_bind_with_cancel(config.host, fallback_port).map_err(|fallback_err| {
|
||||
fallback_err.context(format!(
|
||||
"Failed to bind loopback listener for OAuth callback on {}:{} or {}:{}",
|
||||
config.host, config.preferred_port, config.host, fallback_port,
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to bind to a fixed `host:port`. On `AddrInUse`, sends a single
|
||||
/// `GET /cancel` to the existing listener (to evict a previous OAuth flow
|
||||
/// from this or a compatible client) and retries.
|
||||
fn try_bind_with_cancel(host: &'static str, port: u16) -> Result<tiny_http::Server> {
|
||||
let addr = format!("{host}:{port}");
|
||||
let mut cancel_attempted = false;
|
||||
let mut last_err: Option<anyhow::Error> = None;
|
||||
|
||||
for _ in 0..BIND_MAX_ATTEMPTS {
|
||||
match tiny_http::Server::http(&addr) {
|
||||
Ok(server) => return Ok(server),
|
||||
Err(err) => {
|
||||
let is_addr_in_use = err
|
||||
.downcast_ref::<std::io::Error>()
|
||||
.map(|io_err| io_err.kind() == std::io::ErrorKind::AddrInUse)
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_addr_in_use {
|
||||
return Err(anyhow!(err).context(format!(
|
||||
"Failed to bind loopback listener for OAuth callback on {addr}"
|
||||
)));
|
||||
}
|
||||
|
||||
if !cancel_attempted {
|
||||
cancel_attempted = true;
|
||||
if let Err(cancel_err) = send_cancel_request(host, port) {
|
||||
log::warn!(
|
||||
"Failed to cancel previous OAuth callback server on {addr}: {cancel_err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
last_err = Some(anyhow!(err));
|
||||
std::thread::sleep(BIND_RETRY_DELAY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err
|
||||
.unwrap_or_else(|| anyhow!("unknown bind error"))
|
||||
.context(format!(
|
||||
"OAuth callback port {addr} remained in use after {BIND_MAX_ATTEMPTS} attempts"
|
||||
)))
|
||||
}
|
||||
|
||||
/// Sends `GET /cancel` to a listener on `host:port`, asking it to shut down.
|
||||
///
|
||||
/// Best-effort: errors here are surfaced to the caller for logging but do
|
||||
/// not block the subsequent rebind attempt.
|
||||
fn send_cancel_request(host: &str, port: u16) -> std::io::Result<()> {
|
||||
use std::io::{Read as _, Write as _};
|
||||
use std::net::{TcpStream, ToSocketAddrs as _};
|
||||
|
||||
let addr = format!("{host}:{port}")
|
||||
.to_socket_addrs()?
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("could not resolve {host}:{port}"),
|
||||
)
|
||||
})?;
|
||||
let mut stream = TcpStream::connect_timeout(&addr, CANCEL_REQUEST_TIMEOUT)?;
|
||||
stream.set_read_timeout(Some(CANCEL_REQUEST_TIMEOUT))?;
|
||||
stream.set_write_timeout(Some(CANCEL_REQUEST_TIMEOUT))?;
|
||||
|
||||
stream.write_all(b"GET /cancel HTTP/1.1\r\n")?;
|
||||
stream.write_all(format!("Host: {host}:{port}\r\n").as_bytes())?;
|
||||
stream.write_all(b"Connection: close\r\n\r\n")?;
|
||||
|
||||
// Drain the response so the server can close cleanly. We don't care
|
||||
// about the body; errors here are harmless.
|
||||
let mut buf = [0u8; 64];
|
||||
let _ = stream.read(&mut buf);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
pub use server::{
|
||||
OAuthCallbackParams, OAuthCallbackServerConfig, start_oauth_callback_server,
|
||||
start_oauth_callback_server_with_config,
|
||||
};
|
||||
|
|
@ -259,8 +259,9 @@ pub fn into_open_ai_response(
|
|||
|
||||
ResponseRequest {
|
||||
model: model_id.into(),
|
||||
instructions: None,
|
||||
input: input_items,
|
||||
store: false,
|
||||
store: Some(false),
|
||||
include,
|
||||
stream,
|
||||
temperature,
|
||||
|
|
|
|||
|
|
@ -9,9 +9,10 @@ use crate::{ReasoningEffort, RequestError, Role, ToolChoice};
|
|||
#[derive(Serialize, Debug)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub instructions: Option<String>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub input: Vec<ResponseInputItem>,
|
||||
pub store: bool,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub include: Vec<ResponseIncludable>,
|
||||
#[serde(default)]
|
||||
|
|
@ -32,6 +33,8 @@ pub struct Request {
|
|||
pub prompt_cache_key: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning: Option<ReasoningConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub store: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
|
|
@ -395,13 +398,17 @@ pub async fn stream_response(
|
|||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
extra_headers: Vec<(String, String)>,
|
||||
) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
|
||||
let uri = format!("{api_url}/responses");
|
||||
let request_builder = HttpRequest::builder()
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key.trim()));
|
||||
for (name, value) in &extra_headers {
|
||||
request_builder = request_builder.header(name.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let is_streaming = request.stream;
|
||||
let request = request_builder
|
||||
|
|
|
|||
|
|
@ -1063,9 +1063,8 @@ impl ContextServerStore {
|
|||
// Start a loopback HTTP server on an ephemeral port. The redirect URI
|
||||
// includes this port so the browser sends the callback directly to our
|
||||
// process.
|
||||
let (redirect_uri, callback_rx) = oauth::start_callback_server()
|
||||
.await
|
||||
.context("Failed to start OAuth callback server")?;
|
||||
let (redirect_uri, callback_rx) =
|
||||
oauth::start_callback_server().context("Failed to start OAuth callback server")?;
|
||||
|
||||
let http_client = cx.update(|cx| cx.http_client());
|
||||
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
|
||||
|
|
@ -1093,9 +1092,6 @@ impl ContextServerStore {
|
|||
|
||||
let callback = callback_rx
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
|
||||
})?
|
||||
.context("OAuth callback server received an invalid request")?;
|
||||
|
||||
if callback.state != state_param {
|
||||
|
|
|
|||
Loading…
Reference in a new issue