diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 6c0e8d31557..d9dc972e24f 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -261,7 +261,8 @@ impl ContextServerRegistry { } ContextServerStatus::Stopped | ContextServerStatus::Error(_) - | ContextServerStatus::AuthRequired => { + | ContextServerStatus::AuthRequired + | ContextServerStatus::ClientSecretRequired { .. } => { if let Some(registered_server) = self.registered_servers.remove(server_id) { if !registered_server.tools.is_empty() { cx.emit(ContextServerRegistryEvent::ToolsChanged); diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index ff5519b7240..3a718c7a9e8 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -3844,6 +3844,7 @@ fn mcp_servers_for_project(project: &Entity, cx: &App) -> Vec Some(acp::McpServer::Http( acp::McpServerHttp::new(id.0.to_string(), url.to_string()).headers( headers diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 67d21211026..eb6ea3e81fc 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -664,8 +664,14 @@ impl AgentConfiguration { None }; let auth_required = matches!(server_status, ContextServerStatus::AuthRequired); + let client_secret_required = matches!( + server_status, + ContextServerStatus::ClientSecretRequired { .. } + ); let authenticating = matches!(server_status, ContextServerStatus::Authenticating); let context_server_store = self.context_server_store.clone(); + let workspace = self.workspace.clone(); + let language_registry = self.language_registry.clone(); let tool_count = self .context_server_registry @@ -685,6 +691,9 @@ impl AgentConfiguration { ContextServerStatus::Error(_) => AiSettingItemStatus::Error, ContextServerStatus::Stopped => AiSettingItemStatus::Stopped, ContextServerStatus::AuthRequired => AiSettingItemStatus::AuthRequired, + ContextServerStatus::ClientSecretRequired { .. } => { + AiSettingItemStatus::ClientSecretRequired + } ContextServerStatus::Authenticating => AiSettingItemStatus::Authenticating, }; @@ -886,7 +895,7 @@ impl AgentConfiguration { ), ) .child( - Button::new("error-logout-server", "Authenticate") + Button::new("authenticate-server", "Authenticate") .style(ButtonStyle::Outlined) .label_size(LabelSize::Small) .on_click({ @@ -900,6 +909,46 @@ impl AgentConfiguration { ) .into_any_element(), ) + } else if client_secret_required { + Some( + feedback_base_container() + .child( + h_flex() + .pr_4() + .min_w_0() + .w_full() + .gap_2() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + Label::new("Enter a client secret to connect this server") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ) + .child( + Button::new("enter-client-secret", "Enter Client Secret") + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .on_click({ + let context_server_id = context_server_id.clone(); + move |_event, window, cx| { + ConfigureContextServerModal::show_modal_for_existing_server( + context_server_id.clone(), + language_registry.clone(), + workspace.clone(), + window, + cx, + ) + .detach(); + } + }), + ) + .into_any_element(), + ) } else if authenticating { Some( h_flex() diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs index 48d01e506bf..5ccc901b4a4 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs @@ -17,7 +17,7 @@ use project::{ ContextServerStatus, ContextServerStore, ServerStatusChangedEvent, registry::ContextServerDescriptorRegistry, }, - project_settings::{ContextServerSettings, ProjectSettings}, + project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings}, worktree_store::WorktreeStore, }; use serde::Deserialize; @@ -43,7 +43,9 @@ enum ConfigurationTarget { id: ContextServerId, url: String, headers: HashMap, + oauth: Option, }, + Extension { id: ContextServerId, repository_url: Option, @@ -121,15 +123,17 @@ impl ConfigurationSource { id, url, headers: auth, + oauth, } => ConfigurationSource::Existing { editor: create_editor( - context_server_http_input(Some((id, url, auth))), + context_server_http_input(Some((id, url, auth, oauth))), jsonc_language, window, cx, ), is_http: true, }, + ConfigurationTarget::Extension { id, repository_url, @@ -168,7 +172,7 @@ impl ConfigurationSource { ConfigurationSource::New { editor, is_http } | ConfigurationSource::Existing { editor, is_http } => { if *is_http { - parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth)| { + parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth, oauth)| { ( id, ContextServerSettings::Http { @@ -176,6 +180,7 @@ impl ConfigurationSource { url, headers: auth, timeout: None, + oauth, }, ) }) @@ -256,11 +261,16 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand) } fn context_server_http_input( - existing: Option<(ContextServerId, String, HashMap)>, + existing: Option<( + ContextServerId, + String, + HashMap, + Option, + )>, ) -> String { - let (name, url, headers) = match existing { - Some((id, url, headers)) => { - let header = if headers.is_empty() { + let (name, url, headers, oauth) = match existing { + Some((id, url, headers, oauth)) => { + let headers = if headers.is_empty() { r#"// "Authorization": "Bearer "#.to_string() } else { let json = serde_json::to_string_pretty(&headers).unwrap(); @@ -274,15 +284,48 @@ fn context_server_http_input( .map(|line| format!(" {}", line)) .collect::() }; - (id.0.to_string(), url, header) + (id.0.to_string(), url, headers, oauth) } None => ( "some-remote-server".to_string(), "https://example.com/mcp".to_string(), r#"// "Authorization": "Bearer "#.to_string(), + None, ), }; + let oauth = oauth.map_or_else( + || { + r#" + /// Uncomment to use a pre-registered OAuth client. You can include the client secret here as well, otherwise it will be prompted interactively and saved in the system keychain. + // "oauth": { + // "client_id": "your-client-id", + // },"# + .to_string() + }, + + |oauth| { + let mut lines = vec![ + String::from("\n \"oauth\": {"), + + format!(" \"client_id\": {},", serde_json::to_string(&oauth.client_id).unwrap()), + ]; + if let Some(client_secret) = oauth.client_secret { + lines.push(format!( + " \"client_secret\": {}", + serde_json::to_string(&client_secret).unwrap() + )); + } else { + lines.push(String::from( + " /// Optional client secret for confidential clients\n // \"client_secret\": \"your-client-secret\"", + )); + } + lines.push(String::from(" },")); + + lines.join("\n") + }, + ); + format!( r#"{{ /// Configure an MCP server that you connect to over HTTP @@ -290,7 +333,7 @@ fn context_server_http_input( /// The name of your remote MCP server "{name}": {{ /// The URL of the remote MCP server - "url": "{url}", + "url": "{url}",{oauth} "headers": {{ /// Any headers to send along {headers} @@ -300,12 +343,21 @@ fn context_server_http_input( ) } -fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap)> { +fn parse_http_input( + text: &str, +) -> Result<( + ContextServerId, + String, + HashMap, + Option, +)> { #[derive(Deserialize)] struct Temp { url: String, #[serde(default)] headers: HashMap, + #[serde(default)] + oauth: Option, } let value: HashMap = serde_json_lenient::from_str(text)?; if value.len() != 1 { @@ -314,7 +366,12 @@ fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap, + }, + Authenticating { + server_id: ContextServerId, + }, Error(SharedString), } @@ -361,10 +426,47 @@ pub struct ConfigureContextServerModal { state: State, original_server_id: Option, scroll_handle: ScrollHandle, + secret_editor: Entity, _auth_subscription: Option, } impl ConfigureContextServerModal { + fn initial_state( + context_server_store: &Entity, + target: &ConfigurationTarget, + cx: &App, + ) -> State { + let Some(server_id) = (match target { + ConfigurationTarget::Existing { id, .. } + | ConfigurationTarget::ExistingHttp { id, .. } + | ConfigurationTarget::Extension { id, .. } => Some(id), + ConfigurationTarget::New => None, + }) else { + return State::Idle; + }; + + match context_server_store.read(cx).status_for_server(server_id) { + Some(ContextServerStatus::AuthRequired) => State::AuthRequired { + server_id: server_id.clone(), + }, + Some(ContextServerStatus::ClientSecretRequired { error }) => { + State::ClientSecretRequired { + server_id: server_id.clone(), + error: error.map(SharedString::from), + } + } + Some(ContextServerStatus::Authenticating) => State::Authenticating { + server_id: server_id.clone(), + }, + Some(ContextServerStatus::Error(error)) => State::Error(error.into()), + + Some(ContextServerStatus::Starting) + | Some(ContextServerStatus::Running) + | Some(ContextServerStatus::Stopped) + | None => State::Idle, + } + } + pub fn register( workspace: &mut Workspace, language_registry: Arc, @@ -426,12 +528,14 @@ impl ConfigureContextServerModal { url, headers, timeout: _, - .. + oauth, } => Some(ConfigurationTarget::ExistingHttp { id: server_id, url, headers, + oauth, }), + ContextServerSettings::Extension { .. } => { match workspace .update(cx, |workspace, cx| { @@ -468,9 +572,10 @@ impl ConfigureContextServerModal { let workspace_handle = cx.weak_entity(); let context_server_store = workspace.project().read(cx).context_server_store(); workspace.toggle_modal(window, cx, |window, cx| Self { - context_server_store, + context_server_store: context_server_store.clone(), workspace: workspace_handle, - state: State::Idle, + state: Self::initial_state(&context_server_store, &target, cx), + original_server_id: match &target { ConfigurationTarget::Existing { id, .. } => Some(id.clone()), ConfigurationTarget::ExistingHttp { id, .. } => Some(id.clone()), @@ -485,6 +590,16 @@ impl ConfigureContextServerModal { cx, ), scroll_handle: ScrollHandle::new(), + secret_editor: cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text( + "Enter client secret (leave empty for public clients)", + window, + cx, + ); + editor.set_masked(true, cx); + editor + }), _auth_subscription: None, }) }) @@ -497,13 +612,12 @@ impl ConfigureContextServerModal { } fn confirm(&mut self, _: &menu::Confirm, cx: &mut Context) { - if matches!( - self.state, - State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. } - ) { + if matches!(self.state, State::Waiting | State::Authenticating { .. }) { return; } + self._auth_subscription = None; + self.state = State::Idle; let Some(workspace) = self.workspace.upgrade() else { return; @@ -519,7 +633,7 @@ impl ConfigureContextServerModal { self.state = State::Waiting; - let existing_server = self.context_server_store.read(cx).get_running_server(&id); + let existing_server = self.context_server_store.read(cx).get_server(&id); if existing_server.is_some() { self.context_server_store.update(cx, |store, cx| { store.stop_server(&id, cx).log_err(); @@ -542,6 +656,13 @@ impl ConfigureContextServerModal { this.state = State::AuthRequired { server_id: id }; cx.notify(); } + Ok(ContextServerStatus::ClientSecretRequired { error }) => { + this.state = State::ClientSecretRequired { + server_id: id, + error: error.map(SharedString::from), + }; + cx.notify(); + } Err(err) => { this.set_error(err, cx); } @@ -581,13 +702,33 @@ impl ConfigureContextServerModal { cx.emit(DismissEvent); } + fn cancel_authentication(&mut self, server_id: &ContextServerId, cx: &mut Context) { + self._auth_subscription = None; + self.context_server_store.update(cx, |store, cx| { + store.stop_server(server_id, cx).log_err(); + }); + self.state = State::Idle; + cx.notify(); + } + fn authenticate(&mut self, server_id: ContextServerId, cx: &mut Context) { self.context_server_store.update(cx, |store, cx| { store.authenticate_server(&server_id, cx).log_err(); }); + self.await_auth_outcome(server_id, cx); + } + fn submit_client_secret(&mut self, server_id: ContextServerId, cx: &mut Context) { + let secret = self.secret_editor.read(cx).text(cx); + self.context_server_store.update(cx, |store, cx| { + store.submit_client_secret(&server_id, secret, cx).log_err(); + }); + self.await_auth_outcome(server_id, cx); + } + + fn await_auth_outcome(&mut self, server_id: ContextServerId, cx: &mut Context) { self.state = State::Authenticating { - _server_id: server_id.clone(), + server_id: server_id.clone(), }; self._auth_subscription = Some(cx.subscribe( @@ -610,6 +751,14 @@ impl ConfigureContextServerModal { }; cx.notify(); } + ContextServerStatus::ClientSecretRequired { error } => { + this._auth_subscription = None; + this.state = State::ClientSecretRequired { + server_id: event.server_id.clone(), + error: error.clone().map(SharedString::from), + }; + cx.notify(); + } ContextServerStatus::Error(error) => { this._auth_subscription = None; this.set_error(error.clone(), cx); @@ -814,10 +963,7 @@ impl ConfigureContextServerModal { fn render_modal_footer(&self, cx: &mut Context) -> ModalFooter { let focus_handle = self.focus_handle(cx); - let is_busy = matches!( - self.state, - State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. } - ); + let is_busy = matches!(self.state, State::Waiting | State::Authenticating { .. }); ModalFooter::new() .start_slot::