From 0d832bc6d5498f545c5f05ba1f1fc84285434eb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Houl=C3=A9?= <13155277+tomhoule@users.noreply.github.com> Date: Tue, 19 May 2026 19:45:07 +0200 Subject: [PATCH] Implement MCP OAuth client preregistration (#52900) In the interactive MCP OAuth flow, the MCP client registers itself with the authorization in one of three ways: - Client ID Metadata Document aka CIMD (recommended default). This is already implemented: https://zed.dev/oauth/client-metadata.json. - Dynamic Client Registration (DCR). This is the traditional method. Also already implemented in Zed. - Pre-registration: the client is registered out of band, typically in the IdP or SaaS provider's UI. You get a client id and maybe a client secret, that have to be provided by the MCP client when it wants to exchange an access token. This is what this pull request is about. This PR has two main parts: - Allow users to configure a client id and optional client secret for an MCP server in their configuration, under a new `oauth` key, and take it into account - Make the MCP server state and the configuration modal aware of the intermediate states (client secret missing) and error cases stemming from client pre-registration. The client secret can be stored either in the system keychain or in plain text in the MCP server configuration. The UI tries to steer user towards the more secure option: the keychain. Screenshot 2026-04-10 at 16 48 06 Screenshot 2026-04-10 at 16 47 07 Screenshot 2026-04-10 at 16 47 23 Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes https://github.com/issues/assigned?issue=zed-industries%7Czed%7C52198 **Note for the reviewer: I know how busy the AI team is at the moment so please treat this as low priority, we don't have signal that this is a highly desired feature. It's a rather large PR, so I'm happy to pair review / walk through it.** Release Notes: - Added support for OAuth client pre-registration (client id, client secret) to the built-in MCP client. --- .../src/tools/context_server_registry.rs | 3 +- crates/agent_servers/src/acp.rs | 1 + crates/agent_ui/src/agent_configuration.rs | 51 ++- .../configure_context_server_modal.rs | 368 ++++++++++++++++-- crates/context_server/src/oauth.rs | 101 +++-- crates/project/src/context_server_store.rs | 284 +++++++++++++- crates/project/src/project_settings.rs | 28 ++ .../tests/integration/context_server_store.rs | 4 + crates/settings_content/src/project.rs | 18 + .../ui/src/components/ai/ai_setting_item.rs | 4 +- 10 files changed, 797 insertions(+), 65 deletions(-) diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 6c0e8d31557..d9dc972e24f 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -261,7 +261,8 @@ impl ContextServerRegistry { } ContextServerStatus::Stopped | ContextServerStatus::Error(_) - | ContextServerStatus::AuthRequired => { + | ContextServerStatus::AuthRequired + | ContextServerStatus::ClientSecretRequired { .. } => { if let Some(registered_server) = self.registered_servers.remove(server_id) { if !registered_server.tools.is_empty() { cx.emit(ContextServerRegistryEvent::ToolsChanged); diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index ff5519b7240..3a718c7a9e8 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -3844,6 +3844,7 @@ fn mcp_servers_for_project(project: &Entity, cx: &App) -> Vec Some(acp::McpServer::Http( acp::McpServerHttp::new(id.0.to_string(), url.to_string()).headers( headers diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 67d21211026..eb6ea3e81fc 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -664,8 +664,14 @@ impl AgentConfiguration { None }; let auth_required = matches!(server_status, ContextServerStatus::AuthRequired); + let client_secret_required = matches!( + server_status, + ContextServerStatus::ClientSecretRequired { .. } + ); let authenticating = matches!(server_status, ContextServerStatus::Authenticating); let context_server_store = self.context_server_store.clone(); + let workspace = self.workspace.clone(); + let language_registry = self.language_registry.clone(); let tool_count = self .context_server_registry @@ -685,6 +691,9 @@ impl AgentConfiguration { ContextServerStatus::Error(_) => AiSettingItemStatus::Error, ContextServerStatus::Stopped => AiSettingItemStatus::Stopped, ContextServerStatus::AuthRequired => AiSettingItemStatus::AuthRequired, + ContextServerStatus::ClientSecretRequired { .. } => { + AiSettingItemStatus::ClientSecretRequired + } ContextServerStatus::Authenticating => AiSettingItemStatus::Authenticating, }; @@ -886,7 +895,7 @@ impl AgentConfiguration { ), ) .child( - Button::new("error-logout-server", "Authenticate") + Button::new("authenticate-server", "Authenticate") .style(ButtonStyle::Outlined) .label_size(LabelSize::Small) .on_click({ @@ -900,6 +909,46 @@ impl AgentConfiguration { ) .into_any_element(), ) + } else if client_secret_required { + Some( + feedback_base_container() + .child( + h_flex() + .pr_4() + .min_w_0() + .w_full() + .gap_2() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + Label::new("Enter a client secret to connect this server") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ) + .child( + Button::new("enter-client-secret", "Enter Client Secret") + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .on_click({ + let context_server_id = context_server_id.clone(); + move |_event, window, cx| { + ConfigureContextServerModal::show_modal_for_existing_server( + context_server_id.clone(), + language_registry.clone(), + workspace.clone(), + window, + cx, + ) + .detach(); + } + }), + ) + .into_any_element(), + ) } else if authenticating { Some( h_flex() diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs index 48d01e506bf..5ccc901b4a4 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs @@ -17,7 +17,7 @@ use project::{ ContextServerStatus, ContextServerStore, ServerStatusChangedEvent, registry::ContextServerDescriptorRegistry, }, - project_settings::{ContextServerSettings, ProjectSettings}, + project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings}, worktree_store::WorktreeStore, }; use serde::Deserialize; @@ -43,7 +43,9 @@ enum ConfigurationTarget { id: ContextServerId, url: String, headers: HashMap, + oauth: Option, }, + Extension { id: ContextServerId, repository_url: Option, @@ -121,15 +123,17 @@ impl ConfigurationSource { id, url, headers: auth, + oauth, } => ConfigurationSource::Existing { editor: create_editor( - context_server_http_input(Some((id, url, auth))), + context_server_http_input(Some((id, url, auth, oauth))), jsonc_language, window, cx, ), is_http: true, }, + ConfigurationTarget::Extension { id, repository_url, @@ -168,7 +172,7 @@ impl ConfigurationSource { ConfigurationSource::New { editor, is_http } | ConfigurationSource::Existing { editor, is_http } => { if *is_http { - parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth)| { + parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth, oauth)| { ( id, ContextServerSettings::Http { @@ -176,6 +180,7 @@ impl ConfigurationSource { url, headers: auth, timeout: None, + oauth, }, ) }) @@ -256,11 +261,16 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand) } fn context_server_http_input( - existing: Option<(ContextServerId, String, HashMap)>, + existing: Option<( + ContextServerId, + String, + HashMap, + Option, + )>, ) -> String { - let (name, url, headers) = match existing { - Some((id, url, headers)) => { - let header = if headers.is_empty() { + let (name, url, headers, oauth) = match existing { + Some((id, url, headers, oauth)) => { + let headers = if headers.is_empty() { r#"// "Authorization": "Bearer "#.to_string() } else { let json = serde_json::to_string_pretty(&headers).unwrap(); @@ -274,15 +284,48 @@ fn context_server_http_input( .map(|line| format!(" {}", line)) .collect::() }; - (id.0.to_string(), url, header) + (id.0.to_string(), url, headers, oauth) } None => ( "some-remote-server".to_string(), "https://example.com/mcp".to_string(), r#"// "Authorization": "Bearer "#.to_string(), + None, ), }; + let oauth = oauth.map_or_else( + || { + r#" + /// Uncomment to use a pre-registered OAuth client. You can include the client secret here as well, otherwise it will be prompted interactively and saved in the system keychain. + // "oauth": { + // "client_id": "your-client-id", + // },"# + .to_string() + }, + + |oauth| { + let mut lines = vec![ + String::from("\n \"oauth\": {"), + + format!(" \"client_id\": {},", serde_json::to_string(&oauth.client_id).unwrap()), + ]; + if let Some(client_secret) = oauth.client_secret { + lines.push(format!( + " \"client_secret\": {}", + serde_json::to_string(&client_secret).unwrap() + )); + } else { + lines.push(String::from( + " /// Optional client secret for confidential clients\n // \"client_secret\": \"your-client-secret\"", + )); + } + lines.push(String::from(" },")); + + lines.join("\n") + }, + ); + format!( r#"{{ /// Configure an MCP server that you connect to over HTTP @@ -290,7 +333,7 @@ fn context_server_http_input( /// The name of your remote MCP server "{name}": {{ /// The URL of the remote MCP server - "url": "{url}", + "url": "{url}",{oauth} "headers": {{ /// Any headers to send along {headers} @@ -300,12 +343,21 @@ fn context_server_http_input( ) } -fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap)> { +fn parse_http_input( + text: &str, +) -> Result<( + ContextServerId, + String, + HashMap, + Option, +)> { #[derive(Deserialize)] struct Temp { url: String, #[serde(default)] headers: HashMap, + #[serde(default)] + oauth: Option, } let value: HashMap = serde_json_lenient::from_str(text)?; if value.len() != 1 { @@ -314,7 +366,12 @@ fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap, + }, + Authenticating { + server_id: ContextServerId, + }, Error(SharedString), } @@ -361,10 +426,47 @@ pub struct ConfigureContextServerModal { state: State, original_server_id: Option, scroll_handle: ScrollHandle, + secret_editor: Entity, _auth_subscription: Option, } impl ConfigureContextServerModal { + fn initial_state( + context_server_store: &Entity, + target: &ConfigurationTarget, + cx: &App, + ) -> State { + let Some(server_id) = (match target { + ConfigurationTarget::Existing { id, .. } + | ConfigurationTarget::ExistingHttp { id, .. } + | ConfigurationTarget::Extension { id, .. } => Some(id), + ConfigurationTarget::New => None, + }) else { + return State::Idle; + }; + + match context_server_store.read(cx).status_for_server(server_id) { + Some(ContextServerStatus::AuthRequired) => State::AuthRequired { + server_id: server_id.clone(), + }, + Some(ContextServerStatus::ClientSecretRequired { error }) => { + State::ClientSecretRequired { + server_id: server_id.clone(), + error: error.map(SharedString::from), + } + } + Some(ContextServerStatus::Authenticating) => State::Authenticating { + server_id: server_id.clone(), + }, + Some(ContextServerStatus::Error(error)) => State::Error(error.into()), + + Some(ContextServerStatus::Starting) + | Some(ContextServerStatus::Running) + | Some(ContextServerStatus::Stopped) + | None => State::Idle, + } + } + pub fn register( workspace: &mut Workspace, language_registry: Arc, @@ -426,12 +528,14 @@ impl ConfigureContextServerModal { url, headers, timeout: _, - .. + oauth, } => Some(ConfigurationTarget::ExistingHttp { id: server_id, url, headers, + oauth, }), + ContextServerSettings::Extension { .. } => { match workspace .update(cx, |workspace, cx| { @@ -468,9 +572,10 @@ impl ConfigureContextServerModal { let workspace_handle = cx.weak_entity(); let context_server_store = workspace.project().read(cx).context_server_store(); workspace.toggle_modal(window, cx, |window, cx| Self { - context_server_store, + context_server_store: context_server_store.clone(), workspace: workspace_handle, - state: State::Idle, + state: Self::initial_state(&context_server_store, &target, cx), + original_server_id: match &target { ConfigurationTarget::Existing { id, .. } => Some(id.clone()), ConfigurationTarget::ExistingHttp { id, .. } => Some(id.clone()), @@ -485,6 +590,16 @@ impl ConfigureContextServerModal { cx, ), scroll_handle: ScrollHandle::new(), + secret_editor: cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text( + "Enter client secret (leave empty for public clients)", + window, + cx, + ); + editor.set_masked(true, cx); + editor + }), _auth_subscription: None, }) }) @@ -497,13 +612,12 @@ impl ConfigureContextServerModal { } fn confirm(&mut self, _: &menu::Confirm, cx: &mut Context) { - if matches!( - self.state, - State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. } - ) { + if matches!(self.state, State::Waiting | State::Authenticating { .. }) { return; } + self._auth_subscription = None; + self.state = State::Idle; let Some(workspace) = self.workspace.upgrade() else { return; @@ -519,7 +633,7 @@ impl ConfigureContextServerModal { self.state = State::Waiting; - let existing_server = self.context_server_store.read(cx).get_running_server(&id); + let existing_server = self.context_server_store.read(cx).get_server(&id); if existing_server.is_some() { self.context_server_store.update(cx, |store, cx| { store.stop_server(&id, cx).log_err(); @@ -542,6 +656,13 @@ impl ConfigureContextServerModal { this.state = State::AuthRequired { server_id: id }; cx.notify(); } + Ok(ContextServerStatus::ClientSecretRequired { error }) => { + this.state = State::ClientSecretRequired { + server_id: id, + error: error.map(SharedString::from), + }; + cx.notify(); + } Err(err) => { this.set_error(err, cx); } @@ -581,13 +702,33 @@ impl ConfigureContextServerModal { cx.emit(DismissEvent); } + fn cancel_authentication(&mut self, server_id: &ContextServerId, cx: &mut Context) { + self._auth_subscription = None; + self.context_server_store.update(cx, |store, cx| { + store.stop_server(server_id, cx).log_err(); + }); + self.state = State::Idle; + cx.notify(); + } + fn authenticate(&mut self, server_id: ContextServerId, cx: &mut Context) { self.context_server_store.update(cx, |store, cx| { store.authenticate_server(&server_id, cx).log_err(); }); + self.await_auth_outcome(server_id, cx); + } + fn submit_client_secret(&mut self, server_id: ContextServerId, cx: &mut Context) { + let secret = self.secret_editor.read(cx).text(cx); + self.context_server_store.update(cx, |store, cx| { + store.submit_client_secret(&server_id, secret, cx).log_err(); + }); + self.await_auth_outcome(server_id, cx); + } + + fn await_auth_outcome(&mut self, server_id: ContextServerId, cx: &mut Context) { self.state = State::Authenticating { - _server_id: server_id.clone(), + server_id: server_id.clone(), }; self._auth_subscription = Some(cx.subscribe( @@ -610,6 +751,14 @@ impl ConfigureContextServerModal { }; cx.notify(); } + ContextServerStatus::ClientSecretRequired { error } => { + this._auth_subscription = None; + this.state = State::ClientSecretRequired { + server_id: event.server_id.clone(), + error: error.clone().map(SharedString::from), + }; + cx.notify(); + } ContextServerStatus::Error(error) => { this._auth_subscription = None; this.set_error(error.clone(), cx); @@ -814,10 +963,7 @@ impl ConfigureContextServerModal { fn render_modal_footer(&self, cx: &mut Context) -> ModalFooter { let focus_handle = self.focus_handle(cx); - let is_busy = matches!( - self.state, - State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. } - ); + let is_busy = matches!(self.state, State::Waiting | State::Authenticating { .. }); ModalFooter::new() .start_slot::