Implement MCP OAuth client preregistration (#52900)

In the interactive MCP OAuth flow, the MCP client registers itself with
the authorization in one of three ways:

- Client ID Metadata Document aka CIMD (recommended default). This is
already implemented: https://zed.dev/oauth/client-metadata.json.
- Dynamic Client Registration (DCR). This is the traditional method.
Also already implemented in Zed.
- Pre-registration: the client is registered out of band, typically in
the IdP or SaaS provider's UI. You get a client id and maybe a client
secret, that have to be provided by the MCP client when it wants to
exchange an access token. This is what this pull request is about.

This PR has two main parts:

- Allow users to configure a client id and optional client secret for an
MCP server in their configuration, under a new `oauth` key, and take it
into account
- Make the MCP server state and the configuration modal aware of the
intermediate states (client secret missing) and error cases stemming
from client pre-registration.

The client secret can be stored either in the system keychain or in
plain text in the MCP server configuration. The UI tries to steer user
towards the more secure option: the keychain.

<img width="715" height="201" alt="Screenshot 2026-04-10 at 16 48 06"
src="https://github.com/user-attachments/assets/5e64103e-6746-4ef0-8bd9-533d492b6912"
/>

<img width="884" height="544" alt="Screenshot 2026-04-10 at 16 47 07"
src="https://github.com/user-attachments/assets/0e35bb3c-cbc4-4e8c-a713-66323597b2e2"
/>


<img width="785" height="558" alt="Screenshot 2026-04-10 at 16 47 23"
src="https://github.com/user-attachments/assets/03339187-1508-461a-87ae-a7c2647df9a5"
/>



Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes
https://github.com/issues/assigned?issue=zed-industries%7Czed%7C52198

**Note for the reviewer: I know how busy the AI team is at the moment so
please treat this as low priority, we don't have signal that this is a
highly desired feature. It's a rather large PR, so I'm happy to pair
review / walk through it.**

Release Notes:

- Added support for OAuth client pre-registration (client id, client
secret) to the built-in MCP client.
This commit is contained in:
Tom Houlé 2026-05-19 19:45:07 +02:00 committed by GitHub
parent c0596fade7
commit 0d832bc6d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 797 additions and 65 deletions

View file

@ -261,7 +261,8 @@ impl ContextServerRegistry {
}
ContextServerStatus::Stopped
| ContextServerStatus::Error(_)
| ContextServerStatus::AuthRequired => {
| ContextServerStatus::AuthRequired
| ContextServerStatus::ClientSecretRequired { .. } => {
if let Some(registered_server) = self.registered_servers.remove(server_id) {
if !registered_server.tools.is_empty() {
cx.emit(ContextServerRegistryEvent::ToolsChanged);

View file

@ -3844,6 +3844,7 @@ fn mcp_servers_for_project(project: &Entity<Project>, cx: &App) -> Vec<acp::McpS
url,
headers,
timeout: _,
oauth: _,
} => Some(acp::McpServer::Http(
acp::McpServerHttp::new(id.0.to_string(), url.to_string()).headers(
headers

View file

@ -664,8 +664,14 @@ impl AgentConfiguration {
None
};
let auth_required = matches!(server_status, ContextServerStatus::AuthRequired);
let client_secret_required = matches!(
server_status,
ContextServerStatus::ClientSecretRequired { .. }
);
let authenticating = matches!(server_status, ContextServerStatus::Authenticating);
let context_server_store = self.context_server_store.clone();
let workspace = self.workspace.clone();
let language_registry = self.language_registry.clone();
let tool_count = self
.context_server_registry
@ -685,6 +691,9 @@ impl AgentConfiguration {
ContextServerStatus::Error(_) => AiSettingItemStatus::Error,
ContextServerStatus::Stopped => AiSettingItemStatus::Stopped,
ContextServerStatus::AuthRequired => AiSettingItemStatus::AuthRequired,
ContextServerStatus::ClientSecretRequired { .. } => {
AiSettingItemStatus::ClientSecretRequired
}
ContextServerStatus::Authenticating => AiSettingItemStatus::Authenticating,
};
@ -886,7 +895,7 @@ impl AgentConfiguration {
),
)
.child(
Button::new("error-logout-server", "Authenticate")
Button::new("authenticate-server", "Authenticate")
.style(ButtonStyle::Outlined)
.label_size(LabelSize::Small)
.on_click({
@ -900,6 +909,46 @@ impl AgentConfiguration {
)
.into_any_element(),
)
} else if client_secret_required {
Some(
feedback_base_container()
.child(
h_flex()
.pr_4()
.min_w_0()
.w_full()
.gap_2()
.child(
Icon::new(IconName::Info)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
Label::new("Enter a client secret to connect this server")
.color(Color::Muted)
.size(LabelSize::Small),
),
)
.child(
Button::new("enter-client-secret", "Enter Client Secret")
.style(ButtonStyle::Outlined)
.label_size(LabelSize::Small)
.on_click({
let context_server_id = context_server_id.clone();
move |_event, window, cx| {
ConfigureContextServerModal::show_modal_for_existing_server(
context_server_id.clone(),
language_registry.clone(),
workspace.clone(),
window,
cx,
)
.detach();
}
}),
)
.into_any_element(),
)
} else if authenticating {
Some(
h_flex()

View file

@ -17,7 +17,7 @@ use project::{
ContextServerStatus, ContextServerStore, ServerStatusChangedEvent,
registry::ContextServerDescriptorRegistry,
},
project_settings::{ContextServerSettings, ProjectSettings},
project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings},
worktree_store::WorktreeStore,
};
use serde::Deserialize;
@ -43,7 +43,9 @@ enum ConfigurationTarget {
id: ContextServerId,
url: String,
headers: HashMap<String, String>,
oauth: Option<OAuthClientSettings>,
},
Extension {
id: ContextServerId,
repository_url: Option<SharedString>,
@ -121,15 +123,17 @@ impl ConfigurationSource {
id,
url,
headers: auth,
oauth,
} => ConfigurationSource::Existing {
editor: create_editor(
context_server_http_input(Some((id, url, auth))),
context_server_http_input(Some((id, url, auth, oauth))),
jsonc_language,
window,
cx,
),
is_http: true,
},
ConfigurationTarget::Extension {
id,
repository_url,
@ -168,7 +172,7 @@ impl ConfigurationSource {
ConfigurationSource::New { editor, is_http }
| ConfigurationSource::Existing { editor, is_http } => {
if *is_http {
parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth)| {
parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth, oauth)| {
(
id,
ContextServerSettings::Http {
@ -176,6 +180,7 @@ impl ConfigurationSource {
url,
headers: auth,
timeout: None,
oauth,
},
)
})
@ -256,11 +261,16 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)
}
fn context_server_http_input(
existing: Option<(ContextServerId, String, HashMap<String, String>)>,
existing: Option<(
ContextServerId,
String,
HashMap<String, String>,
Option<OAuthClientSettings>,
)>,
) -> String {
let (name, url, headers) = match existing {
Some((id, url, headers)) => {
let header = if headers.is_empty() {
let (name, url, headers, oauth) = match existing {
Some((id, url, headers, oauth)) => {
let headers = if headers.is_empty() {
r#"// "Authorization": "Bearer <token>"#.to_string()
} else {
let json = serde_json::to_string_pretty(&headers).unwrap();
@ -274,15 +284,48 @@ fn context_server_http_input(
.map(|line| format!(" {}", line))
.collect::<String>()
};
(id.0.to_string(), url, header)
(id.0.to_string(), url, headers, oauth)
}
None => (
"some-remote-server".to_string(),
"https://example.com/mcp".to_string(),
r#"// "Authorization": "Bearer <token>"#.to_string(),
None,
),
};
let oauth = oauth.map_or_else(
|| {
r#"
/// Uncomment to use a pre-registered OAuth client. You can include the client secret here as well, otherwise it will be prompted interactively and saved in the system keychain.
// "oauth": {
// "client_id": "your-client-id",
// },"#
.to_string()
},
|oauth| {
let mut lines = vec![
String::from("\n \"oauth\": {"),
format!(" \"client_id\": {},", serde_json::to_string(&oauth.client_id).unwrap()),
];
if let Some(client_secret) = oauth.client_secret {
lines.push(format!(
" \"client_secret\": {}",
serde_json::to_string(&client_secret).unwrap()
));
} else {
lines.push(String::from(
" /// Optional client secret for confidential clients\n // \"client_secret\": \"your-client-secret\"",
));
}
lines.push(String::from(" },"));
lines.join("\n")
},
);
format!(
r#"{{
/// Configure an MCP server that you connect to over HTTP
@ -290,7 +333,7 @@ fn context_server_http_input(
/// The name of your remote MCP server
"{name}": {{
/// The URL of the remote MCP server
"url": "{url}",
"url": "{url}",{oauth}
"headers": {{
/// Any headers to send along
{headers}
@ -300,12 +343,21 @@ fn context_server_http_input(
)
}
fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap<String, String>)> {
fn parse_http_input(
text: &str,
) -> Result<(
ContextServerId,
String,
HashMap<String, String>,
Option<OAuthClientSettings>,
)> {
#[derive(Deserialize)]
struct Temp {
url: String,
#[serde(default)]
headers: HashMap<String, String>,
#[serde(default)]
oauth: Option<OAuthClientSettings>,
}
let value: HashMap<String, Temp> = serde_json_lenient::from_str(text)?;
if value.len() != 1 {
@ -314,7 +366,12 @@ fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap<Stri
let (key, value) = value.into_iter().next().unwrap();
Ok((ContextServerId(key.into()), value.url, value.headers))
Ok((
ContextServerId(key.into()),
value.url,
value.headers,
value.oauth,
))
}
fn resolve_context_server_extension(
@ -349,8 +406,16 @@ fn resolve_context_server_extension(
enum State {
Idle,
Waiting,
AuthRequired { server_id: ContextServerId },
Authenticating { _server_id: ContextServerId },
AuthRequired {
server_id: ContextServerId,
},
ClientSecretRequired {
server_id: ContextServerId,
error: Option<SharedString>,
},
Authenticating {
server_id: ContextServerId,
},
Error(SharedString),
}
@ -361,10 +426,47 @@ pub struct ConfigureContextServerModal {
state: State,
original_server_id: Option<ContextServerId>,
scroll_handle: ScrollHandle,
secret_editor: Entity<Editor>,
_auth_subscription: Option<Subscription>,
}
impl ConfigureContextServerModal {
fn initial_state(
context_server_store: &Entity<ContextServerStore>,
target: &ConfigurationTarget,
cx: &App,
) -> State {
let Some(server_id) = (match target {
ConfigurationTarget::Existing { id, .. }
| ConfigurationTarget::ExistingHttp { id, .. }
| ConfigurationTarget::Extension { id, .. } => Some(id),
ConfigurationTarget::New => None,
}) else {
return State::Idle;
};
match context_server_store.read(cx).status_for_server(server_id) {
Some(ContextServerStatus::AuthRequired) => State::AuthRequired {
server_id: server_id.clone(),
},
Some(ContextServerStatus::ClientSecretRequired { error }) => {
State::ClientSecretRequired {
server_id: server_id.clone(),
error: error.map(SharedString::from),
}
}
Some(ContextServerStatus::Authenticating) => State::Authenticating {
server_id: server_id.clone(),
},
Some(ContextServerStatus::Error(error)) => State::Error(error.into()),
Some(ContextServerStatus::Starting)
| Some(ContextServerStatus::Running)
| Some(ContextServerStatus::Stopped)
| None => State::Idle,
}
}
pub fn register(
workspace: &mut Workspace,
language_registry: Arc<LanguageRegistry>,
@ -426,12 +528,14 @@ impl ConfigureContextServerModal {
url,
headers,
timeout: _,
..
oauth,
} => Some(ConfigurationTarget::ExistingHttp {
id: server_id,
url,
headers,
oauth,
}),
ContextServerSettings::Extension { .. } => {
match workspace
.update(cx, |workspace, cx| {
@ -468,9 +572,10 @@ impl ConfigureContextServerModal {
let workspace_handle = cx.weak_entity();
let context_server_store = workspace.project().read(cx).context_server_store();
workspace.toggle_modal(window, cx, |window, cx| Self {
context_server_store,
context_server_store: context_server_store.clone(),
workspace: workspace_handle,
state: State::Idle,
state: Self::initial_state(&context_server_store, &target, cx),
original_server_id: match &target {
ConfigurationTarget::Existing { id, .. } => Some(id.clone()),
ConfigurationTarget::ExistingHttp { id, .. } => Some(id.clone()),
@ -485,6 +590,16 @@ impl ConfigureContextServerModal {
cx,
),
scroll_handle: ScrollHandle::new(),
secret_editor: cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_placeholder_text(
"Enter client secret (leave empty for public clients)",
window,
cx,
);
editor.set_masked(true, cx);
editor
}),
_auth_subscription: None,
})
})
@ -497,13 +612,12 @@ impl ConfigureContextServerModal {
}
fn confirm(&mut self, _: &menu::Confirm, cx: &mut Context<Self>) {
if matches!(
self.state,
State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
) {
if matches!(self.state, State::Waiting | State::Authenticating { .. }) {
return;
}
self._auth_subscription = None;
self.state = State::Idle;
let Some(workspace) = self.workspace.upgrade() else {
return;
@ -519,7 +633,7 @@ impl ConfigureContextServerModal {
self.state = State::Waiting;
let existing_server = self.context_server_store.read(cx).get_running_server(&id);
let existing_server = self.context_server_store.read(cx).get_server(&id);
if existing_server.is_some() {
self.context_server_store.update(cx, |store, cx| {
store.stop_server(&id, cx).log_err();
@ -542,6 +656,13 @@ impl ConfigureContextServerModal {
this.state = State::AuthRequired { server_id: id };
cx.notify();
}
Ok(ContextServerStatus::ClientSecretRequired { error }) => {
this.state = State::ClientSecretRequired {
server_id: id,
error: error.map(SharedString::from),
};
cx.notify();
}
Err(err) => {
this.set_error(err, cx);
}
@ -581,13 +702,33 @@ impl ConfigureContextServerModal {
cx.emit(DismissEvent);
}
fn cancel_authentication(&mut self, server_id: &ContextServerId, cx: &mut Context<Self>) {
self._auth_subscription = None;
self.context_server_store.update(cx, |store, cx| {
store.stop_server(server_id, cx).log_err();
});
self.state = State::Idle;
cx.notify();
}
fn authenticate(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
self.context_server_store.update(cx, |store, cx| {
store.authenticate_server(&server_id, cx).log_err();
});
self.await_auth_outcome(server_id, cx);
}
fn submit_client_secret(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
let secret = self.secret_editor.read(cx).text(cx);
self.context_server_store.update(cx, |store, cx| {
store.submit_client_secret(&server_id, secret, cx).log_err();
});
self.await_auth_outcome(server_id, cx);
}
fn await_auth_outcome(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
self.state = State::Authenticating {
_server_id: server_id.clone(),
server_id: server_id.clone(),
};
self._auth_subscription = Some(cx.subscribe(
@ -610,6 +751,14 @@ impl ConfigureContextServerModal {
};
cx.notify();
}
ContextServerStatus::ClientSecretRequired { error } => {
this._auth_subscription = None;
this.state = State::ClientSecretRequired {
server_id: event.server_id.clone(),
error: error.clone().map(SharedString::from),
};
cx.notify();
}
ContextServerStatus::Error(error) => {
this._auth_subscription = None;
this.set_error(error.clone(), cx);
@ -814,10 +963,7 @@ impl ConfigureContextServerModal {
fn render_modal_footer(&self, cx: &mut Context<Self>) -> ModalFooter {
let focus_handle = self.focus_handle(cx);
let is_busy = matches!(
self.state,
State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
);
let is_busy = matches!(self.state, State::Waiting | State::Authenticating { .. });
ModalFooter::new()
.start_slot::<Button>(
@ -944,6 +1090,112 @@ impl ConfigureContextServerModal {
)
}
fn render_client_secret_required(
&self,
server_id: &ContextServerId,
error: Option<SharedString>,
cx: &mut Context<Self>,
) -> Div {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_size: settings.buffer_font_size(cx).into(),
font_weight: settings.buffer_font.weight,
line_height: relative(settings.buffer_line_height.value()),
..Default::default()
};
v_flex()
.w_full()
.gap_2()
.when_some(error, |this, error| {
this.child(Self::render_modal_error(error))
})
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::Info)
.size(IconSize::Small)
.color(Color::Muted),
)
.child(
Label::new(
"Enter your OAuth client secret, or leave empty for public clients",
)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(
h_flex()
.w_full()
.gap_2()
.capture_action({
let server_id = server_id.clone();
cx.listener(move |this, _: &editor::actions::Newline, _window, cx| {
this.submit_client_secret(server_id.clone(), cx);
})
})
.child(div().flex_1().child(EditorElement::new(
&self.secret_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
)))
.child(
Button::new("submit-client-secret", "Submit")
.style(ButtonStyle::Outlined)
.label_size(LabelSize::Small)
.on_click({
let server_id = server_id.clone();
cx.listener(move |this, _event, _window, cx| {
this.submit_client_secret(server_id.clone(), cx);
})
}),
),
)
}
fn render_authenticating(&self, server_id: &ContextServerId, cx: &mut Context<Self>) -> Div {
h_flex()
.h_8()
.gap_2()
.justify_center()
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::LoadCircle)
.size(IconSize::XSmall)
.color(Color::Muted)
.with_rotate_animation(3),
)
.child(
Label::new("Authenticating…")
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(
Button::new("cancel-authentication", "Cancel")
.style(ButtonStyle::Outlined)
.label_size(LabelSize::Small)
.on_click({
let server_id = server_id.clone();
cx.listener(move |this, _event, _window, cx| {
this.cancel_authentication(&server_id, cx);
})
}),
)
}
fn render_modal_error(error: SharedString) -> Div {
h_flex()
.h_8()
@ -1003,8 +1255,15 @@ impl Render for ConfigureContextServerModal {
State::AuthRequired { server_id } => {
self.render_auth_required(&server_id.clone(), cx)
}
State::Authenticating { .. } => {
self.render_loading("Authenticating…")
State::ClientSecretRequired { server_id, error } => {
self.render_client_secret_required(
&server_id.clone(),
error.clone(),
cx,
)
}
State::Authenticating { server_id } => {
self.render_authenticating(&server_id.clone(), cx)
}
State::Error(error) => {
Self::render_modal_error(error.clone())
@ -1040,7 +1299,9 @@ fn wait_for_context_server(
}
match status {
ContextServerStatus::Running | ContextServerStatus::AuthRequired => {
ContextServerStatus::Running
| ContextServerStatus::AuthRequired
| ContextServerStatus::ClientSecretRequired { .. } => {
if let Some(tx) = tx.lock().take() {
let _ = tx.send(Ok(status.clone()));
}
@ -1104,3 +1365,52 @@ pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle
..Default::default()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_http_input_reads_oauth_settings() {
let (id, url, headers, oauth) = parse_http_input(
r#"{
"figma": {
"url": "https://mcp.figma.com/mcp",
"oauth": {
"client_id": "client-id",
"client_secret": "client-secret"
},
"headers": {
"X-Test": "test"
}
}
}"#,
)
.unwrap();
assert_eq!(id, ContextServerId("figma".into()));
assert_eq!(url, "https://mcp.figma.com/mcp");
assert_eq!(headers.get("X-Test"), Some(&String::from("test")));
let oauth = oauth.expect("oauth should be present");
assert_eq!(oauth.client_id, "client-id");
assert_eq!(oauth.client_secret.as_deref(), Some("client-secret"));
}
#[test]
fn context_server_http_input_preserves_existing_oauth_settings() {
let text = context_server_http_input(Some((
ContextServerId("figma".into()),
String::from("https://mcp.figma.com/mcp"),
HashMap::default(),
Some(OAuthClientSettings {
client_id: String::from("client-id"),
client_secret: Some(String::from("client-secret")),
}),
)));
let (_, _, _, oauth) = parse_http_input(&text).unwrap();
let oauth = oauth.expect("oauth should be present");
assert_eq!(oauth.client_id, "client-id");
assert_eq!(oauth.client_secret.as_deref(), Some("client-secret"));
}
}

View file

@ -633,6 +633,26 @@ impl TokenResponse {
}
}
/// An OAuth token error response (RFC 6749 Section 5.2).
#[derive(Debug, Deserialize, PartialEq)]
pub struct OAuthTokenError {
pub error: String,
#[serde(default)]
pub error_description: Option<String>,
}
impl std::fmt::Display for OAuthTokenError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "OAuth token error: {}", self.error)?;
if let Some(description) = &self.error_description {
write!(f, " ({description})")?;
}
Ok(())
}
}
impl std::error::Error for OAuthTokenError {}
/// Build the form-encoded body for an authorization code token exchange.
pub fn token_exchange_params(
code: &str,
@ -640,15 +660,20 @@ pub fn token_exchange_params(
redirect_uri: &str,
code_verifier: &str,
resource: &str,
client_secret: Option<&str>,
) -> Vec<(&'static str, String)> {
vec![
let mut params = vec![
("grant_type", "authorization_code".to_string()),
("code", code.to_string()),
("redirect_uri", redirect_uri.to_string()),
("client_id", client_id.to_string()),
("code_verifier", code_verifier.to_string()),
("resource", resource.to_string()),
]
];
if let Some(secret) = client_secret {
params.push(("client_secret", secret.to_string()));
}
params
}
/// Build the form-encoded body for a token refresh request.
@ -656,13 +681,18 @@ pub fn token_refresh_params(
refresh_token: &str,
client_id: &str,
resource: &str,
client_secret: Option<&str>,
) -> Vec<(&'static str, String)> {
vec![
let mut params = vec![
("grant_type", "refresh_token".to_string()),
("refresh_token", refresh_token.to_string()),
("client_id", client_id.to_string()),
("resource", resource.to_string()),
]
];
if let Some(secret) = client_secret {
params.push(("client_secret", secret.to_string()));
}
params
}
// -- DCR request body (RFC 7591) ---------------------------------------------
@ -782,6 +812,7 @@ pub async fn fetch_auth_server_metadata(
match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
Ok(response) => {
let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
if reported_issuer != *issuer {
bail!(
"Auth server metadata issuer mismatch: expected {}, got {}",
@ -844,15 +875,6 @@ pub async fn discover(
None => bail!("authorization server does not advertise code_challenge_methods_supported"),
}
// Verify there is at least one supported registration strategy before we
// present the server as ready to authenticate.
match determine_registration_strategy(&auth_server_metadata) {
ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
ClientRegistrationStrategy::Unavailable => {
bail!("authorization server supports neither CIMD nor DCR")
}
}
let scopes = select_scopes(www_authenticate, &resource_metadata);
Ok(OAuthDiscovery {
@ -956,8 +978,16 @@ pub async fn exchange_code(
redirect_uri: &str,
code_verifier: &str,
resource: &str,
client_secret: Option<&str>,
) -> Result<OAuthTokens> {
let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
let params = token_exchange_params(
code,
client_id,
redirect_uri,
code_verifier,
resource,
client_secret,
);
post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
}
@ -968,8 +998,9 @@ pub async fn refresh_tokens(
refresh_token: &str,
client_id: &str,
resource: &str,
client_secret: Option<&str>,
) -> Result<OAuthTokens> {
let params = token_refresh_params(refresh_token, client_id, resource);
let params = token_refresh_params(refresh_token, client_id, resource, client_secret);
post_token_request(http_client, token_endpoint, &params).await
}
@ -997,11 +1028,12 @@ async fn post_token_request(
if !response.status().is_success() {
let mut error_body = String::new();
response.body_mut().read_to_string(&mut error_body).await?;
bail!(
"token request failed with status {}: {}",
response.status(),
error_body
);
let status = response.status();
// Try to parse as an OAuth error response (RFC 6749 Section 5.2).
if let Ok(token_error) = serde_json::from_str::<OAuthTokenError>(&error_body) {
return Err(token_error.into());
}
bail!("token request failed with status {status}: {error_body}");
}
let mut response_body = String::new();
@ -1198,7 +1230,7 @@ impl OAuthTokenProvider for McpOAuthTokenProvider {
}
async fn try_refresh(&self) -> Result<bool> {
let (refresh_token, token_endpoint, resource, client_id) = {
let (refresh_token, token_endpoint, resource, client_id, client_secret) = {
let session = self.session.lock();
match session.tokens.refresh_token.clone() {
Some(refresh_token) => (
@ -1206,6 +1238,7 @@ impl OAuthTokenProvider for McpOAuthTokenProvider {
session.token_endpoint.clone(),
session.resource.clone(),
session.client_registration.client_id.clone(),
session.client_registration.client_secret.clone(),
),
None => return Ok(false),
}
@ -1219,6 +1252,7 @@ impl OAuthTokenProvider for McpOAuthTokenProvider {
&refresh_token,
&client_id,
&resource_str,
client_secret.as_deref(),
)
.await
{
@ -1801,6 +1835,7 @@ mod tests {
"http://127.0.0.1:5555/callback",
"verifier_123",
"https://mcp.example.com",
None,
);
let map: std::collections::HashMap<&str, &str> =
params.iter().map(|(k, v)| (*k, v.as_str())).collect();
@ -1815,8 +1850,12 @@ mod tests {
#[test]
fn test_token_refresh_params() {
let params =
token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
let params = token_refresh_params(
"refresh_token_abc",
"client_xyz",
"https://mcp.example.com",
None,
);
let map: std::collections::HashMap<&str, &str> =
params.iter().map(|(k, v)| (*k, v.as_str())).collect();
@ -2422,6 +2461,7 @@ mod tests {
"http://127.0.0.1:9999/callback",
"verifier_abc",
"https://mcp.example.com",
None,
)
.await
.unwrap();
@ -2461,6 +2501,7 @@ mod tests {
"old_refresh_token",
CIMD_URL,
"https://mcp.example.com",
None,
)
.await
.unwrap();
@ -2497,11 +2538,21 @@ mod tests {
"http://127.0.0.1:1/callback",
"verifier",
"https://mcp.example.com",
None,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("400"));
let err = result.unwrap_err();
let token_error = err
.downcast_ref::<OAuthTokenError>()
.expect("expected OAuthTokenError");
assert_eq!(
*token_error,
OAuthTokenError {
error: "invalid_grant".into(),
error_description: None,
}
);
});
}

View file

@ -27,7 +27,7 @@ use util::{ResultExt as _, rel_path::RelPath};
use crate::{
DisableAiSettings, Project,
project_settings::{ContextServerSettings, ProjectSettings},
project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings},
worktree_store::{WorktreeStore, WorktreeStoreEvent},
};
@ -56,6 +56,11 @@ pub enum ContextServerStatus {
/// The server returned 401 and OAuth authorization is needed. The UI
/// should show an "Authenticate" button.
AuthRequired,
/// The server has a pre-registered OAuth client_id, but a client_secret
/// is needed and not available in settings or the keychain.
ClientSecretRequired {
error: Option<Arc<str>>,
},
/// The OAuth browser flow is in progress — the user has been redirected
/// to the authorization server and we're waiting for the callback.
Authenticating,
@ -69,6 +74,11 @@ impl ContextServerStatus {
ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
ContextServerState::ClientSecretRequired { error, .. } => {
ContextServerStatus::ClientSecretRequired {
error: error.clone(),
}
}
ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
}
}
@ -100,6 +110,14 @@ enum ContextServerState {
configuration: Arc<ContextServerConfiguration>,
discovery: Arc<OAuthDiscovery>,
},
/// A pre-registered client_id is configured but no client_secret was found
/// in settings or the keychain.
ClientSecretRequired {
server: Arc<ContextServer>,
configuration: Arc<ContextServerConfiguration>,
discovery: Arc<OAuthDiscovery>,
error: Option<Arc<str>>,
},
/// The OAuth browser flow is in progress. The user has been redirected
/// to the authorization server and we're waiting for the callback.
Authenticating {
@ -117,6 +135,7 @@ impl ContextServerState {
| ContextServerState::Stopped { server, .. }
| ContextServerState::Error { server, .. }
| ContextServerState::AuthRequired { server, .. }
| ContextServerState::ClientSecretRequired { server, .. }
| ContextServerState::Authenticating { server, .. } => server.clone(),
}
}
@ -128,6 +147,7 @@ impl ContextServerState {
| ContextServerState::Stopped { configuration, .. }
| ContextServerState::Error { configuration, .. }
| ContextServerState::AuthRequired { configuration, .. }
| ContextServerState::ClientSecretRequired { configuration, .. }
| ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
}
}
@ -148,6 +168,7 @@ pub enum ContextServerConfiguration {
url: url::Url,
headers: HashMap<String, String>,
timeout: Option<u64>,
oauth: Option<OAuthClientSettings>,
},
}
@ -228,12 +249,14 @@ impl ContextServerConfiguration {
url,
headers: auth,
timeout,
oauth,
} => {
let url = url::Url::parse(&url).log_err()?;
Some(ContextServerConfiguration::Http {
url,
headers: auth,
timeout,
oauth,
})
}
}
@ -841,6 +864,7 @@ impl ContextServerStore {
url,
headers,
timeout,
oauth: _,
} => {
let transport = HttpTransport::new_with_token_provider(
cx.http_client(),
@ -1007,6 +1031,15 @@ impl ContextServerStore {
_ => anyhow::bail!("Server is not in AuthRequired state"),
};
let needs_keychain_check = match configuration.as_ref() {
ContextServerConfiguration::Http {
url,
oauth: Some(oauth_settings),
..
} if oauth_settings.client_secret.is_none() => Some(url.clone()),
_ => None,
};
let id = id.clone();
let task = cx.spawn({
@ -1014,6 +1047,33 @@ impl ContextServerStore {
let server = server.clone();
let configuration = configuration.clone();
async move |this, cx| {
if let Some(server_url) = needs_keychain_check {
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
let has_keychain_secret =
Self::load_client_secret(&credentials_provider, &server_url, cx)
.await
.ok()
.flatten()
.is_some();
if !has_keychain_secret {
this.update(cx, |this, cx| {
this.update_server_state(
id.clone(),
ContextServerState::ClientSecretRequired {
server,
configuration,
discovery,
error: None,
},
cx,
);
})
.log_err();
return;
}
}
let result = Self::run_oauth_flow(
this.clone(),
id.clone(),
@ -1025,15 +1085,13 @@ impl ContextServerStore {
if let Err(err) = &result {
log::error!("{} OAuth authentication failed: {:?}", id, err);
// Transition back to AuthRequired so the user can retry
// rather than landing in a terminal Error state.
this.update(cx, |this, cx| {
this.update_server_state(
id.clone(),
ContextServerState::AuthRequired {
ContextServerState::Error {
server,
configuration,
discovery,
error: format!("{err:#}").into(),
},
cx,
)
@ -1056,6 +1114,121 @@ impl ContextServerStore {
Ok(())
}
/// Store the client secret and proceed with authentication.
pub fn submit_client_secret(
&mut self,
id: &ContextServerId,
secret: String,
cx: &mut Context<Self>,
) -> Result<()> {
let state = self.servers.get(id).context("Context server not found")?;
let (server, configuration, discovery) = match state {
ContextServerState::ClientSecretRequired {
server,
configuration,
discovery,
..
} => (server.clone(), configuration.clone(), discovery.clone()),
_ => anyhow::bail!("Server is not in ClientSecretRequired state"),
};
let server_url = match configuration.as_ref() {
ContextServerConfiguration::Http { url, .. } => url.clone(),
_ => anyhow::bail!("OAuth only supported for HTTP servers"),
};
let id = id.clone();
let task = cx.spawn({
let id = id.clone();
let server = server.clone();
let configuration = configuration.clone();
async move |this, cx| {
// Store the secret if non-empty (empty means public client / skip).
if !secret.is_empty() {
let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
if let Err(err) =
Self::store_client_secret(&credentials_provider, &server_url, &secret, cx)
.await
{
log::error!(
"{} failed to store client secret in keychain: {:?}",
id,
err
);
}
}
let result = Self::run_oauth_flow(
this.clone(),
id.clone(),
discovery.clone(),
configuration.clone(),
cx,
)
.await;
if let Err(err) = &result {
log::error!("{} OAuth authentication failed: {:?}", id, err);
let is_bad_client_credentials = err
.downcast_ref::<oauth::OAuthTokenError>()
.is_some_and(|e| e.error == "unauthorized_client");
if is_bad_client_credentials {
// Clear the bad secret from the keychain so the user
// gets a fresh prompt.
let credentials_provider =
cx.update(|cx| zed_credentials_provider::global(cx));
Self::clear_client_secret(&credentials_provider, &server_url, cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.update_server_state(
id.clone(),
ContextServerState::ClientSecretRequired {
server,
configuration,
discovery,
error: Some(format!("{err:#}").into()),
},
cx,
);
})
.log_err();
} else {
this.update(cx, |this, cx| {
this.update_server_state(
id.clone(),
ContextServerState::Error {
server,
configuration,
error: format!("{err:#}").into(),
},
cx,
)
})
.log_err();
}
}
}
});
self.update_server_state(
id,
ContextServerState::Authenticating {
server,
configuration,
_task: task,
},
cx,
);
Ok(())
}
async fn run_oauth_flow(
this: WeakEntity<Self>,
id: ContextServerId,
@ -1083,10 +1256,30 @@ impl ContextServerStore {
_ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
};
let client_registration =
oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
let client_registration = match configuration.as_ref() {
ContextServerConfiguration::Http {
url,
oauth: Some(oauth_settings),
..
} => {
// Pre-registered client. Resolve the secret from settings, then keychain.
let client_secret = if oauth_settings.client_secret.is_some() {
oauth_settings.client_secret.clone()
} else {
Self::load_client_secret(&credentials_provider, url, cx)
.await
.ok()
.flatten()
};
oauth::OAuthClientRegistration {
client_id: oauth_settings.client_id.clone(),
client_secret,
}
}
_ => oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
.await
.context("Failed to resolve OAuth client registration")?;
.context("Failed to resolve OAuth client registration")?,
};
let auth_url = oauth::build_authorization_url(
&discovery.auth_server_metadata,
@ -1116,6 +1309,7 @@ impl ContextServerStore {
&redirect_uri,
&pkce.verifier,
&resource,
client_registration.client_secret.as_deref(),
)
.await
.context("Failed to exchange authorization code for tokens")?;
@ -1149,6 +1343,7 @@ impl ContextServerStore {
url,
headers,
timeout,
oauth: _,
} => {
let transport = HttpTransport::new_with_token_provider(
http_client.clone(),
@ -1222,6 +1417,46 @@ impl ContextServerStore {
format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
}
fn client_secret_keychain_key(server_url: &url::Url) -> String {
format!(
"mcp-oauth-client-secret:{}",
oauth::canonical_server_uri(server_url)
)
}
async fn load_client_secret(
credentials_provider: &Arc<dyn CredentialsProvider>,
server_url: &url::Url,
cx: &AsyncApp,
) -> Result<Option<String>> {
let key = Self::client_secret_keychain_key(server_url);
match credentials_provider.read_credentials(&key, cx).await? {
Some((_username, secret_bytes)) => Ok(Some(String::from_utf8(secret_bytes)?)),
None => Ok(None),
}
}
pub async fn store_client_secret(
credentials_provider: &Arc<dyn CredentialsProvider>,
server_url: &url::Url,
secret: &str,
cx: &AsyncApp,
) -> Result<()> {
let key = Self::client_secret_keychain_key(server_url);
credentials_provider
.write_credentials(&key, "mcp-oauth-client-secret", secret.as_bytes(), cx)
.await
}
async fn clear_client_secret(
credentials_provider: &Arc<dyn CredentialsProvider>,
server_url: &url::Url,
cx: &AsyncApp,
) -> Result<()> {
let key = Self::client_secret_keychain_key(server_url);
credentials_provider.delete_credentials(&key, cx).await
}
/// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
/// session from the keychain and stop the server.
pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
@ -1241,6 +1476,11 @@ impl ContextServerStore {
if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
log::error!("{} failed to clear OAuth session: {}", id, err);
}
// Also clear any client secret so the user gets a fresh prompt on
// the next authentication attempt.
Self::clear_client_secret(&credentials_provider, &server_url, &cx)
.await
.log_err();
// Trigger server recreation so the next start uses a fresh
// transport without the old (now-invalidated) token provider.
this.update(cx, |this, cx| {
@ -1487,6 +1727,34 @@ async fn resolve_start_failure(
match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
Ok(discovery) => {
use context_server::oauth::{
ClientRegistrationStrategy, determine_registration_strategy,
};
let has_preregistered_client_id = matches!(
configuration.as_ref(),
ContextServerConfiguration::Http { oauth: Some(_), .. }
);
let strategy = determine_registration_strategy(&discovery.auth_server_metadata);
if matches!(strategy, ClientRegistrationStrategy::Unavailable)
&& !has_preregistered_client_id
{
log::error!(
"{id} authorization server supports neither CIMD nor DCR, \
and no pre-registered client_id is configured"
);
return ContextServerState::Error {
configuration,
server,
error: "Authorization server supports neither CIMD nor DCR. \
Configure a pre-registered client_id in your settings \
under the \"oauth\" key."
.into(),
};
}
log::info!(
"{id} requires OAuth authorization (auth server: {})",
discovery.auth_server_metadata.issuer,

View file

@ -201,6 +201,10 @@ pub enum ContextServerSettings {
headers: HashMap<String, String>,
/// Timeout for tool calls in milliseconds.
timeout: Option<u64>,
/// Pre-registered OAuth client credentials for authorization servers that
/// require out-of-band client registration.
#[serde(default, skip_serializing_if = "Option::is_none")]
oauth: Option<OAuthClientSettings>,
},
Extension {
/// Whether the context server is enabled.
@ -243,11 +247,16 @@ impl From<settings::ContextServerSettingsContent> for ContextServerSettings {
url,
headers,
timeout,
oauth,
} => ContextServerSettings::Http {
enabled,
url,
headers,
timeout,
oauth: oauth.map(|o| OAuthClientSettings {
client_id: o.client_id,
client_secret: o.client_secret,
}),
},
}
}
@ -278,16 +287,35 @@ impl Into<settings::ContextServerSettingsContent> for ContextServerSettings {
url,
headers,
timeout,
oauth,
} => settings::ContextServerSettingsContent::Http {
enabled,
url,
headers,
timeout,
oauth: oauth.map(|o| settings::OAuthClientSettings {
client_id: o.client_id,
client_secret: o.client_secret,
}),
},
}
}
}
/// Pre-registered OAuth client credentials for MCP servers that don't support
/// Dynamic Client Registration.
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct OAuthClientSettings {
/// The OAuth client ID obtained from out-of-band registration with the
/// authorization server.
pub client_id: String,
/// The OAuth client secret, if this is a confidential client. For security,
/// prefer providing this interactively; we will prompt and store it in
/// the system keychain.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
}
impl ContextServerSettings {
pub fn default_extension() -> Self {
Self::Extension {

View file

@ -897,6 +897,7 @@ async fn test_remote_context_server(cx: &mut TestAppContext) {
url: server_url.to_string(),
headers: Default::default(),
timeout: None,
oauth: None,
},
)],
cx,
@ -963,6 +964,7 @@ async fn test_context_server_global_timeout(cx: &mut TestAppContext) {
url: url::Url::parse("http://localhost:8080").expect("Failed to parse test URL"),
headers: Default::default(),
timeout: None,
oauth: None,
}),
&mut async_cx,
)
@ -998,6 +1000,7 @@ async fn test_context_server_per_server_timeout_override(cx: &mut TestAppContext
url: "http://localhost:8080".to_string(),
headers: Default::default(),
timeout: Some(120),
oauth: None,
},
)],
)
@ -1021,6 +1024,7 @@ async fn test_context_server_per_server_timeout_override(cx: &mut TestAppContext
url: url::Url::parse("http://localhost:8080").expect("Failed to parse test URL"),
headers: Default::default(),
timeout: Some(120),
oauth: None,
}),
&mut async_cx,
)

View file

@ -388,6 +388,10 @@ pub enum ContextServerSettingsContent {
headers: HashMap<String, String>,
/// Timeout for tool calls in seconds. Defaults to global context_server_timeout if not specified.
timeout: Option<u64>,
/// Pre-registered OAuth client credentials for authorization servers that
/// require out-of-band client registration.
#[serde(default, skip_serializing_if = "Option::is_none")]
oauth: Option<OAuthClientSettings>,
},
Extension {
/// Whether the context server is enabled.
@ -429,6 +433,20 @@ impl ContextServerSettingsContent {
}
}
/// Pre-registered OAuth client credentials for MCP servers that don't support
/// Dynamic Client Registration.
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, MergeFrom, Debug)]
pub struct OAuthClientSettings {
/// The OAuth client ID obtained from out-of-band registration with the
/// authorization server.
pub client_id: String,
/// The OAuth client secret, if this is a confidential client. For security,
/// prefer providing this interactively; we will prompt and store it in
/// the system keychain.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
}
#[with_fallible_options]
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, MergeFrom)]
pub struct ContextServerCommand {

View file

@ -10,6 +10,7 @@ pub enum AiSettingItemStatus {
Running,
Error,
AuthRequired,
ClientSecretRequired,
Authenticating,
}
@ -21,6 +22,7 @@ impl AiSettingItemStatus {
Self::Running => "Server is active.",
Self::Error => "Server has an error.",
Self::AuthRequired => "Authentication required.",
Self::ClientSecretRequired => "Client secret required.",
Self::Authenticating => "Waiting for authorization…",
}
}
@ -31,7 +33,7 @@ impl AiSettingItemStatus {
Self::Starting | Self::Authenticating => Some(Color::Muted),
Self::Running => Some(Color::Success),
Self::Error => Some(Color::Error),
Self::AuthRequired => Some(Color::Warning),
Self::AuthRequired | Self::ClientSecretRequired => Some(Color::Warning),
}
}