copilot: Fix auth db fallback (#57764)

This PR uses https://github.com/zed-industries/zed/pull/57758 as a base
and adds tests, cleans up the comments, and checks changes the database
query used in auth.db to include oauth key.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- Fixed GitHub Copilot Chat showing an empty model dropdown for users on
newer Copilot SDK builds

---------

Co-authored-by: Alexander Shlemov <eodus@users.noreply.github.com>
Co-authored-by: cameron <cameron.studdstreet@gmail.com>
This commit is contained in:
Anthony Eid 2026-05-28 11:00:31 -04:00 committed by GitHub
parent d74e47ea51
commit bc64e1f955
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 210 additions and 23 deletions

2
Cargo.lock generated
View file

@ -3880,6 +3880,8 @@ dependencies = [
"serde",
"serde_json",
"settings",
"sqlez",
"tempfile",
]
[[package]]

View file

@ -511,8 +511,14 @@ impl Copilot {
};
}
if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) {
env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token);
for env_var in [
copilot_chat::COPILOT_OAUTH_ENV_VAR,
copilot_chat::GITHUB_COPILOT_OAUTH_ENV_VAR,
] {
if let Ok(oauth_token) = env::var(env_var) {
env.insert(env_var.to_string(), oauth_token);
break;
}
}
if env.is_empty() { None } else { Some(env) }
@ -1259,6 +1265,7 @@ impl Copilot {
| request::SignInStatus::AlreadySignedIn { .. } => {
server.sign_in_status = SignInStatus::Authorized;
cx.emit(Event::CopilotAuthSignedIn);
notify_copilot_chat_auth_changed(cx);
for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
if let Some(buffer) = buffer.upgrade() {
self.register_buffer(&buffer, cx);
@ -1278,6 +1285,7 @@ impl Copilot {
};
}
cx.emit(Event::CopilotAuthSignedOut);
notify_copilot_chat_auth_changed(cx);
for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
self.unregister_buffer(&buffer);
}
@ -1381,6 +1389,15 @@ fn notify_did_change_config_to_server(
Ok(())
}
/// Notify Copilot Chat after the Copilot LSP reports an auth state change.
/// This replaces watching the SDK's token files, which is unreliable for
/// SQLite backed auth because writes may go through WAL files.
fn notify_copilot_chat_auth_changed(cx: &mut Context<Copilot>) {
if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
copilot_chat.update(cx, |chat, cx| chat.reload_auth(cx));
}
}
async fn clear_copilot_dir() {
remove_matching(paths::copilot_dir(), |_| true).await
}

View file

@ -34,7 +34,9 @@ paths.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
sqlez.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
serde_json.workspace = true
tempfile.workspace = true

View file

@ -1,6 +1,6 @@
pub mod responses;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::OnceLock;
@ -17,9 +17,10 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use paths::home_dir;
use serde::{Deserialize, Serialize};
use settings::watch_config_dir;
// The Copilot language server unofficially supports both token env vars:
// https://github.com/github/copilot-language-server-release/issues/3#issuecomment-2699433055
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
pub const GITHUB_COPILOT_OAUTH_ENV_VAR: &str = "GITHUB_COPILOT_TOKEN";
const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
#[derive(Default, Clone, Debug, PartialEq)]
@ -501,6 +502,7 @@ pub struct CopilotChat {
configuration: CopilotChatConfiguration,
models: Option<Vec<Model>>,
client: Arc<dyn HttpClient>,
fs: Arc<dyn Fs>,
}
pub fn init(
@ -529,11 +531,19 @@ pub fn copilot_chat_config_dir() -> &'static PathBuf {
})
}
/// Legacy JSON token-storage paths used by older Copilot SDK builds.
/// TODO(copilot): once Copilot SDK supports `auth.db`, remove these paths.
fn copilot_chat_config_paths() -> [PathBuf; 2] {
let base_dir = copilot_chat_config_dir();
[base_dir.join("hosts.json"), base_dir.join("apps.json")]
}
fn oauth_token_from_env() -> Option<String> {
std::env::var(COPILOT_OAUTH_ENV_VAR)
.ok()
.or_else(|| std::env::var(GITHUB_COPILOT_OAUTH_ENV_VAR).ok())
}
impl CopilotChat {
pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
cx.try_global::<GlobalCopilotChat>()
@ -546,40 +556,42 @@ impl CopilotChat {
configuration: CopilotChatConfiguration,
cx: &mut Context<Self>,
) -> Self {
let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
let dir_path = copilot_chat_config_dir();
cx.spawn(async move |this, cx| {
let mut parent_watch_rx = watch_config_dir(
cx.background_executor(),
fs.clone(),
dir_path.clone(),
config_paths,
);
while let Some(contents) = parent_watch_rx.next().await {
// Initial async scan of token sources. Live reload is driven by the
// Copilot LSP's auth status notifications instead of watching files,
// because SQLite WAL writes can make directory watchers racy.
cx.spawn({
let fs = fs.clone();
async move |this, cx| {
let oauth_domain =
this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
let oauth_token = extract_oauth_token(contents, &oauth_domain);
let config_paths: HashSet<PathBuf> =
copilot_chat_config_paths().into_iter().collect();
let auth_db_path = copilot_chat_config_dir().join("auth.db");
this.update(cx, |this, cx| {
this.oauth_token = oauth_token.clone();
cx.notify();
})?;
let oauth_token =
read_oauth_token(&fs, &config_paths, &oauth_domain, &auth_db_path, cx).await;
if oauth_token.is_some() {
this.update(cx, |this, cx| {
this.oauth_token = oauth_token;
cx.notify();
})?;
Self::update_models(&this, cx).await?;
}
anyhow::Ok(())
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
// Initial state uses env var because it's cheap. The others do IO, so
// are on the background.
let this = Self {
oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
oauth_token: oauth_token_from_env(),
api_endpoint: None,
models: None,
configuration,
client,
fs,
};
if this.oauth_token.is_some() {
@ -764,6 +776,39 @@ impl CopilotChat {
.detach();
}
}
pub fn reload_auth(&mut self, cx: &mut Context<Self>) {
let fs = self.fs.clone();
let oauth_domain = self.configuration.oauth_domain();
cx.spawn(async move |this, cx| {
let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
let auth_db_path = copilot_chat_config_dir().join("auth.db");
let new_token =
read_oauth_token(&fs, &config_paths, &oauth_domain, &auth_db_path, cx).await;
let token_present = this.update(cx, |this, cx| {
let changed = this.oauth_token != new_token;
if changed {
this.oauth_token = new_token.clone();
if new_token.is_none() {
// Sign-out: drop derived state so a future sign-in
// re-discovers the endpoint and re-fetches models.
this.api_endpoint = None;
this.models = None;
}
cx.notify();
}
new_token.is_some()
})?;
if token_present {
Self::update_models(&this, cx).await?;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
}
async fn get_models(
@ -917,6 +962,40 @@ async fn request_models(
Ok(models)
}
async fn read_oauth_token(
fs: &Arc<dyn Fs>,
config_paths: &HashSet<PathBuf>,
oauth_domain: &str,
auth_db_path: &std::path::Path,
cx: &AsyncApp,
) -> Option<String> {
if let Some(token) = oauth_token_from_env() {
return Some(token);
}
let token_from_db = cx
.background_spawn({
let auth_db_path = auth_db_path.to_path_buf();
let oauth_domain = oauth_domain.to_string();
async move { extract_oauth_token_from_db(&auth_db_path, &oauth_domain) }
})
.await;
if let Some(token) = token_from_db {
return Some(token);
}
for file_path in config_paths {
if let Ok(contents) = fs.load(file_path).await {
if let Some(token) = extract_oauth_token(contents, oauth_domain) {
return Some(token);
}
}
}
None
}
fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
serde_json::from_str::<serde_json::Value>(&contents)
.map(|v| {
@ -934,6 +1013,36 @@ fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
.flatten()
}
fn extract_oauth_token_from_db(db_path: &Path, auth_authority: &str) -> Option<String> {
if !db_path.exists() {
return None;
}
let db = sqlez::connection::Connection::open_file(db_path.to_str()?);
let token_bytes: Option<Vec<u8>> = db
.select_row_bound::<&str, Vec<u8>>(
"SELECT token_ciphertext FROM oauth_tokens WHERE auth_authority = ? ORDER BY last_used_at DESC, token_id DESC LIMIT 1",
)
.ok()
.and_then(|mut select| select(auth_authority).ok().flatten());
let token = token_bytes.and_then(|bytes| String::from_utf8(bytes).ok())?;
if token.starts_with("ghu_")
&& token.len() >= 36
&& token.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
{
log::debug!("Copilot OAuth token loaded from auth.db");
Some(token)
} else {
log::warn!(
"Copilot auth.db: token does not match expected GitHub OAuth format (ghu_<alphanumeric>)"
);
None
}
}
async fn stream_completion(
client: Arc<dyn HttpClient>,
oauth_token: String,
@ -1751,4 +1860,61 @@ mod tests {
"\"none\""
);
}
#[test]
fn test_extract_oauth_token_from_db_matches_auth_authority_and_recency() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("auth.db");
let older_github_token = "ghu_oldergithubtokenvalue000000000000";
let newer_github_token = "ghu_newergithubtokenvalue000000000000";
let enterprise_token = "ghu_enterprisetokenvalue0000000000000";
let connection = sqlez::connection::Connection::open_file(db_path.to_str().unwrap());
connection
.exec(
"CREATE TABLE oauth_tokens (
token_id INTEGER PRIMARY KEY AUTOINCREMENT,
auth_authority TEXT NOT NULL,
token_ciphertext BLOB NOT NULL,
last_used_at INTEGER NOT NULL
);",
)
.unwrap()()
.unwrap();
{
let mut insert_token = connection
.exec_bound::<(&str, Vec<u8>, i64)>(
"INSERT INTO oauth_tokens (auth_authority, token_ciphertext, last_used_at) VALUES (?, ?, ?);",
)
.unwrap();
insert_token(("github.com", older_github_token.as_bytes().to_vec(), 10)).unwrap();
insert_token((
"github.enterprise.test",
enterprise_token.as_bytes().to_vec(),
30,
))
.unwrap();
insert_token(("github.com", newer_github_token.as_bytes().to_vec(), 20)).unwrap();
}
drop(connection);
assert_eq!(
extract_oauth_token_from_db(&db_path, "github.com").as_deref(),
Some(newer_github_token)
);
assert_eq!(
extract_oauth_token_from_db(&db_path, "github.enterprise.test").as_deref(),
Some(enterprise_token)
);
}
#[test]
fn test_extract_oauth_token_from_db_missing_db_does_not_create_file() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("auth.db");
assert_eq!(extract_oauth_token_from_db(&db_path, "github.com"), None);
assert!(!db_path.exists());
}
}