From bc64e1f9556228a95b8aa772680352688310387f Mon Sep 17 00:00:00 2001 From: Anthony Eid <56899983+Anthony-Eid@users.noreply.github.com> Date: Thu, 28 May 2026 11:00:31 -0400 Subject: [PATCH] copilot: Fix auth db fallback (#57764) This PR uses https://github.com/zed-industries/zed/pull/57758 as a base and adds tests, cleans up the comments, and checks changes the database query used in auth.db to include oauth key. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - Fixed GitHub Copilot Chat showing an empty model dropdown for users on newer Copilot SDK builds --------- Co-authored-by: Alexander Shlemov Co-authored-by: cameron --- Cargo.lock | 2 + crates/copilot/src/copilot.rs | 21 ++- crates/copilot_chat/Cargo.toml | 2 + crates/copilot_chat/src/copilot_chat.rs | 208 +++++++++++++++++++++--- 4 files changed, 210 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0411aa04340..0cc384eef22 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3880,6 +3880,8 @@ dependencies = [ "serde", "serde_json", "settings", + "sqlez", + "tempfile", ] [[package]] diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index a48bf3c1a43..6936a5a416c 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -511,8 +511,14 @@ impl Copilot { }; } - if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) { - env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token); + for env_var in [ + copilot_chat::COPILOT_OAUTH_ENV_VAR, + copilot_chat::GITHUB_COPILOT_OAUTH_ENV_VAR, + ] { + if let Ok(oauth_token) = env::var(env_var) { + env.insert(env_var.to_string(), oauth_token); + break; + } } if env.is_empty() { None } else { Some(env) } @@ -1259,6 +1265,7 @@ impl Copilot { | request::SignInStatus::AlreadySignedIn { .. } => { server.sign_in_status = SignInStatus::Authorized; cx.emit(Event::CopilotAuthSignedIn); + notify_copilot_chat_auth_changed(cx); for buffer in self.buffers.iter().cloned().collect::>() { if let Some(buffer) = buffer.upgrade() { self.register_buffer(&buffer, cx); @@ -1278,6 +1285,7 @@ impl Copilot { }; } cx.emit(Event::CopilotAuthSignedOut); + notify_copilot_chat_auth_changed(cx); for buffer in self.buffers.iter().cloned().collect::>() { self.unregister_buffer(&buffer); } @@ -1381,6 +1389,15 @@ fn notify_did_change_config_to_server( Ok(()) } +/// Notify Copilot Chat after the Copilot LSP reports an auth state change. +/// This replaces watching the SDK's token files, which is unreliable for +/// SQLite backed auth because writes may go through WAL files. +fn notify_copilot_chat_auth_changed(cx: &mut Context) { + if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) { + copilot_chat.update(cx, |chat, cx| chat.reload_auth(cx)); + } +} + async fn clear_copilot_dir() { remove_matching(paths::copilot_dir(), |_| true).await } diff --git a/crates/copilot_chat/Cargo.toml b/crates/copilot_chat/Cargo.toml index 79159d59cc0..c6e6253bf45 100644 --- a/crates/copilot_chat/Cargo.toml +++ b/crates/copilot_chat/Cargo.toml @@ -34,7 +34,9 @@ paths.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +sqlez.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } serde_json.workspace = true +tempfile.workspace = true diff --git a/crates/copilot_chat/src/copilot_chat.rs b/crates/copilot_chat/src/copilot_chat.rs index ab5c08b6174..4d0e5e6c46e 100644 --- a/crates/copilot_chat/src/copilot_chat.rs +++ b/crates/copilot_chat/src/copilot_chat.rs @@ -1,6 +1,6 @@ pub mod responses; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::OnceLock; @@ -17,9 +17,10 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use paths::home_dir; use serde::{Deserialize, Serialize}; -use settings::watch_config_dir; - +// The Copilot language server unofficially supports both token env vars: +// https://github.com/github/copilot-language-server-release/issues/3#issuecomment-2699433055 pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN"; +pub const GITHUB_COPILOT_OAUTH_ENV_VAR: &str = "GITHUB_COPILOT_TOKEN"; const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com"; #[derive(Default, Clone, Debug, PartialEq)] @@ -501,6 +502,7 @@ pub struct CopilotChat { configuration: CopilotChatConfiguration, models: Option>, client: Arc, + fs: Arc, } pub fn init( @@ -529,11 +531,19 @@ pub fn copilot_chat_config_dir() -> &'static PathBuf { }) } +/// Legacy JSON token-storage paths used by older Copilot SDK builds. +/// TODO(copilot): once Copilot SDK supports `auth.db`, remove these paths. fn copilot_chat_config_paths() -> [PathBuf; 2] { let base_dir = copilot_chat_config_dir(); [base_dir.join("hosts.json"), base_dir.join("apps.json")] } +fn oauth_token_from_env() -> Option { + std::env::var(COPILOT_OAUTH_ENV_VAR) + .ok() + .or_else(|| std::env::var(GITHUB_COPILOT_OAUTH_ENV_VAR).ok()) +} + impl CopilotChat { pub fn global(cx: &App) -> Option> { cx.try_global::() @@ -546,40 +556,42 @@ impl CopilotChat { configuration: CopilotChatConfiguration, cx: &mut Context, ) -> Self { - let config_paths: HashSet = copilot_chat_config_paths().into_iter().collect(); - let dir_path = copilot_chat_config_dir(); - - cx.spawn(async move |this, cx| { - let mut parent_watch_rx = watch_config_dir( - cx.background_executor(), - fs.clone(), - dir_path.clone(), - config_paths, - ); - while let Some(contents) = parent_watch_rx.next().await { + // Initial async scan of token sources. Live reload is driven by the + // Copilot LSP's auth status notifications instead of watching files, + // because SQLite WAL writes can make directory watchers racy. + cx.spawn({ + let fs = fs.clone(); + async move |this, cx| { let oauth_domain = this.read_with(cx, |this, _| this.configuration.oauth_domain())?; - let oauth_token = extract_oauth_token(contents, &oauth_domain); + let config_paths: HashSet = + copilot_chat_config_paths().into_iter().collect(); + let auth_db_path = copilot_chat_config_dir().join("auth.db"); - this.update(cx, |this, cx| { - this.oauth_token = oauth_token.clone(); - cx.notify(); - })?; + let oauth_token = + read_oauth_token(&fs, &config_paths, &oauth_domain, &auth_db_path, cx).await; if oauth_token.is_some() { + this.update(cx, |this, cx| { + this.oauth_token = oauth_token; + cx.notify(); + })?; Self::update_models(&this, cx).await?; } + anyhow::Ok(()) } - anyhow::Ok(()) }) .detach_and_log_err(cx); + // Initial state uses env var because it's cheap. The others do IO, so + // are on the background. let this = Self { - oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(), + oauth_token: oauth_token_from_env(), api_endpoint: None, models: None, configuration, client, + fs, }; if this.oauth_token.is_some() { @@ -764,6 +776,39 @@ impl CopilotChat { .detach(); } } + + pub fn reload_auth(&mut self, cx: &mut Context) { + let fs = self.fs.clone(); + let oauth_domain = self.configuration.oauth_domain(); + cx.spawn(async move |this, cx| { + let config_paths: HashSet = copilot_chat_config_paths().into_iter().collect(); + let auth_db_path = copilot_chat_config_dir().join("auth.db"); + + let new_token = + read_oauth_token(&fs, &config_paths, &oauth_domain, &auth_db_path, cx).await; + + let token_present = this.update(cx, |this, cx| { + let changed = this.oauth_token != new_token; + if changed { + this.oauth_token = new_token.clone(); + if new_token.is_none() { + // Sign-out: drop derived state so a future sign-in + // re-discovers the endpoint and re-fetches models. + this.api_endpoint = None; + this.models = None; + } + cx.notify(); + } + new_token.is_some() + })?; + + if token_present { + Self::update_models(&this, cx).await?; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } } async fn get_models( @@ -917,6 +962,40 @@ async fn request_models( Ok(models) } +async fn read_oauth_token( + fs: &Arc, + config_paths: &HashSet, + oauth_domain: &str, + auth_db_path: &std::path::Path, + cx: &AsyncApp, +) -> Option { + if let Some(token) = oauth_token_from_env() { + return Some(token); + } + + let token_from_db = cx + .background_spawn({ + let auth_db_path = auth_db_path.to_path_buf(); + let oauth_domain = oauth_domain.to_string(); + async move { extract_oauth_token_from_db(&auth_db_path, &oauth_domain) } + }) + .await; + + if let Some(token) = token_from_db { + return Some(token); + } + + for file_path in config_paths { + if let Ok(contents) = fs.load(file_path).await { + if let Some(token) = extract_oauth_token(contents, oauth_domain) { + return Some(token); + } + } + } + + None +} + fn extract_oauth_token(contents: String, domain: &str) -> Option { serde_json::from_str::(&contents) .map(|v| { @@ -934,6 +1013,36 @@ fn extract_oauth_token(contents: String, domain: &str) -> Option { .flatten() } +fn extract_oauth_token_from_db(db_path: &Path, auth_authority: &str) -> Option { + if !db_path.exists() { + return None; + } + + let db = sqlez::connection::Connection::open_file(db_path.to_str()?); + + let token_bytes: Option> = db + .select_row_bound::<&str, Vec>( + "SELECT token_ciphertext FROM oauth_tokens WHERE auth_authority = ? ORDER BY last_used_at DESC, token_id DESC LIMIT 1", + ) + .ok() + .and_then(|mut select| select(auth_authority).ok().flatten()); + + let token = token_bytes.and_then(|bytes| String::from_utf8(bytes).ok())?; + + if token.starts_with("ghu_") + && token.len() >= 36 + && token.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') + { + log::debug!("Copilot OAuth token loaded from auth.db"); + Some(token) + } else { + log::warn!( + "Copilot auth.db: token does not match expected GitHub OAuth format (ghu_)" + ); + None + } +} + async fn stream_completion( client: Arc, oauth_token: String, @@ -1751,4 +1860,61 @@ mod tests { "\"none\"" ); } + + #[test] + fn test_extract_oauth_token_from_db_matches_auth_authority_and_recency() { + let dir = tempfile::tempdir().unwrap(); + let db_path = dir.path().join("auth.db"); + let older_github_token = "ghu_oldergithubtokenvalue000000000000"; + let newer_github_token = "ghu_newergithubtokenvalue000000000000"; + let enterprise_token = "ghu_enterprisetokenvalue0000000000000"; + + let connection = sqlez::connection::Connection::open_file(db_path.to_str().unwrap()); + connection + .exec( + "CREATE TABLE oauth_tokens ( + token_id INTEGER PRIMARY KEY AUTOINCREMENT, + auth_authority TEXT NOT NULL, + token_ciphertext BLOB NOT NULL, + last_used_at INTEGER NOT NULL + );", + ) + .unwrap()() + .unwrap(); + + { + let mut insert_token = connection + .exec_bound::<(&str, Vec, i64)>( + "INSERT INTO oauth_tokens (auth_authority, token_ciphertext, last_used_at) VALUES (?, ?, ?);", + ) + .unwrap(); + insert_token(("github.com", older_github_token.as_bytes().to_vec(), 10)).unwrap(); + insert_token(( + "github.enterprise.test", + enterprise_token.as_bytes().to_vec(), + 30, + )) + .unwrap(); + insert_token(("github.com", newer_github_token.as_bytes().to_vec(), 20)).unwrap(); + } + drop(connection); + + assert_eq!( + extract_oauth_token_from_db(&db_path, "github.com").as_deref(), + Some(newer_github_token) + ); + assert_eq!( + extract_oauth_token_from_db(&db_path, "github.enterprise.test").as_deref(), + Some(enterprise_token) + ); + } + + #[test] + fn test_extract_oauth_token_from_db_missing_db_does_not_create_file() { + let dir = tempfile::tempdir().unwrap(); + let db_path = dir.path().join("auth.db"); + + assert_eq!(extract_oauth_token_from_db(&db_path, "github.com"), None); + assert!(!db_path.exists()); + } }