mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-31 19:05:00 +07:00
copilot: Fix auth db fallback (#57764)
This PR uses https://github.com/zed-industries/zed/pull/57758 as a base and adds tests, cleans up the comments, and checks changes the database query used in auth.db to include oauth key. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - Fixed GitHub Copilot Chat showing an empty model dropdown for users on newer Copilot SDK builds --------- Co-authored-by: Alexander Shlemov <eodus@users.noreply.github.com> Co-authored-by: cameron <cameron.studdstreet@gmail.com>
This commit is contained in:
parent
d74e47ea51
commit
bc64e1f955
4 changed files with 210 additions and 23 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -3880,6 +3880,8 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"sqlez",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -511,8 +511,14 @@ impl Copilot {
|
|||
};
|
||||
}
|
||||
|
||||
if let Ok(oauth_token) = env::var(copilot_chat::COPILOT_OAUTH_ENV_VAR) {
|
||||
env.insert(copilot_chat::COPILOT_OAUTH_ENV_VAR.to_string(), oauth_token);
|
||||
for env_var in [
|
||||
copilot_chat::COPILOT_OAUTH_ENV_VAR,
|
||||
copilot_chat::GITHUB_COPILOT_OAUTH_ENV_VAR,
|
||||
] {
|
||||
if let Ok(oauth_token) = env::var(env_var) {
|
||||
env.insert(env_var.to_string(), oauth_token);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if env.is_empty() { None } else { Some(env) }
|
||||
|
|
@ -1259,6 +1265,7 @@ impl Copilot {
|
|||
| request::SignInStatus::AlreadySignedIn { .. } => {
|
||||
server.sign_in_status = SignInStatus::Authorized;
|
||||
cx.emit(Event::CopilotAuthSignedIn);
|
||||
notify_copilot_chat_auth_changed(cx);
|
||||
for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
|
||||
if let Some(buffer) = buffer.upgrade() {
|
||||
self.register_buffer(&buffer, cx);
|
||||
|
|
@ -1278,6 +1285,7 @@ impl Copilot {
|
|||
};
|
||||
}
|
||||
cx.emit(Event::CopilotAuthSignedOut);
|
||||
notify_copilot_chat_auth_changed(cx);
|
||||
for buffer in self.buffers.iter().cloned().collect::<Vec<_>>() {
|
||||
self.unregister_buffer(&buffer);
|
||||
}
|
||||
|
|
@ -1381,6 +1389,15 @@ fn notify_did_change_config_to_server(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Notify Copilot Chat after the Copilot LSP reports an auth state change.
|
||||
/// This replaces watching the SDK's token files, which is unreliable for
|
||||
/// SQLite backed auth because writes may go through WAL files.
|
||||
fn notify_copilot_chat_auth_changed(cx: &mut Context<Copilot>) {
|
||||
if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
|
||||
copilot_chat.update(cx, |chat, cx| chat.reload_auth(cx));
|
||||
}
|
||||
}
|
||||
|
||||
async fn clear_copilot_dir() {
|
||||
remove_matching(paths::copilot_dir(), |_| true).await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ paths.workspace = true
|
|||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
sqlez.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
serde_json.workspace = true
|
||||
tempfile.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
pub mod responses;
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
|
|
@ -17,9 +17,10 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|||
use paths::home_dir;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use settings::watch_config_dir;
|
||||
|
||||
// The Copilot language server unofficially supports both token env vars:
|
||||
// https://github.com/github/copilot-language-server-release/issues/3#issuecomment-2699433055
|
||||
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
|
||||
pub const GITHUB_COPILOT_OAUTH_ENV_VAR: &str = "GITHUB_COPILOT_TOKEN";
|
||||
const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
|
|
@ -501,6 +502,7 @@ pub struct CopilotChat {
|
|||
configuration: CopilotChatConfiguration,
|
||||
models: Option<Vec<Model>>,
|
||||
client: Arc<dyn HttpClient>,
|
||||
fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
||||
pub fn init(
|
||||
|
|
@ -529,11 +531,19 @@ pub fn copilot_chat_config_dir() -> &'static PathBuf {
|
|||
})
|
||||
}
|
||||
|
||||
/// Legacy JSON token-storage paths used by older Copilot SDK builds.
|
||||
/// TODO(copilot): once Copilot SDK supports `auth.db`, remove these paths.
|
||||
fn copilot_chat_config_paths() -> [PathBuf; 2] {
|
||||
let base_dir = copilot_chat_config_dir();
|
||||
[base_dir.join("hosts.json"), base_dir.join("apps.json")]
|
||||
}
|
||||
|
||||
fn oauth_token_from_env() -> Option<String> {
|
||||
std::env::var(COPILOT_OAUTH_ENV_VAR)
|
||||
.ok()
|
||||
.or_else(|| std::env::var(GITHUB_COPILOT_OAUTH_ENV_VAR).ok())
|
||||
}
|
||||
|
||||
impl CopilotChat {
|
||||
pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
|
||||
cx.try_global::<GlobalCopilotChat>()
|
||||
|
|
@ -546,40 +556,42 @@ impl CopilotChat {
|
|||
configuration: CopilotChatConfiguration,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
|
||||
let dir_path = copilot_chat_config_dir();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let mut parent_watch_rx = watch_config_dir(
|
||||
cx.background_executor(),
|
||||
fs.clone(),
|
||||
dir_path.clone(),
|
||||
config_paths,
|
||||
);
|
||||
while let Some(contents) = parent_watch_rx.next().await {
|
||||
// Initial async scan of token sources. Live reload is driven by the
|
||||
// Copilot LSP's auth status notifications instead of watching files,
|
||||
// because SQLite WAL writes can make directory watchers racy.
|
||||
cx.spawn({
|
||||
let fs = fs.clone();
|
||||
async move |this, cx| {
|
||||
let oauth_domain =
|
||||
this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
|
||||
let oauth_token = extract_oauth_token(contents, &oauth_domain);
|
||||
let config_paths: HashSet<PathBuf> =
|
||||
copilot_chat_config_paths().into_iter().collect();
|
||||
let auth_db_path = copilot_chat_config_dir().join("auth.db");
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.oauth_token = oauth_token.clone();
|
||||
cx.notify();
|
||||
})?;
|
||||
let oauth_token =
|
||||
read_oauth_token(&fs, &config_paths, &oauth_domain, &auth_db_path, cx).await;
|
||||
|
||||
if oauth_token.is_some() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.oauth_token = oauth_token;
|
||||
cx.notify();
|
||||
})?;
|
||||
Self::update_models(&this, cx).await?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
// Initial state uses env var because it's cheap. The others do IO, so
|
||||
// are on the background.
|
||||
let this = Self {
|
||||
oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
|
||||
oauth_token: oauth_token_from_env(),
|
||||
api_endpoint: None,
|
||||
models: None,
|
||||
configuration,
|
||||
client,
|
||||
fs,
|
||||
};
|
||||
|
||||
if this.oauth_token.is_some() {
|
||||
|
|
@ -764,6 +776,39 @@ impl CopilotChat {
|
|||
.detach();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_auth(&mut self, cx: &mut Context<Self>) {
|
||||
let fs = self.fs.clone();
|
||||
let oauth_domain = self.configuration.oauth_domain();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
|
||||
let auth_db_path = copilot_chat_config_dir().join("auth.db");
|
||||
|
||||
let new_token =
|
||||
read_oauth_token(&fs, &config_paths, &oauth_domain, &auth_db_path, cx).await;
|
||||
|
||||
let token_present = this.update(cx, |this, cx| {
|
||||
let changed = this.oauth_token != new_token;
|
||||
if changed {
|
||||
this.oauth_token = new_token.clone();
|
||||
if new_token.is_none() {
|
||||
// Sign-out: drop derived state so a future sign-in
|
||||
// re-discovers the endpoint and re-fetches models.
|
||||
this.api_endpoint = None;
|
||||
this.models = None;
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
new_token.is_some()
|
||||
})?;
|
||||
|
||||
if token_present {
|
||||
Self::update_models(&this, cx).await?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_models(
|
||||
|
|
@ -917,6 +962,40 @@ async fn request_models(
|
|||
Ok(models)
|
||||
}
|
||||
|
||||
async fn read_oauth_token(
|
||||
fs: &Arc<dyn Fs>,
|
||||
config_paths: &HashSet<PathBuf>,
|
||||
oauth_domain: &str,
|
||||
auth_db_path: &std::path::Path,
|
||||
cx: &AsyncApp,
|
||||
) -> Option<String> {
|
||||
if let Some(token) = oauth_token_from_env() {
|
||||
return Some(token);
|
||||
}
|
||||
|
||||
let token_from_db = cx
|
||||
.background_spawn({
|
||||
let auth_db_path = auth_db_path.to_path_buf();
|
||||
let oauth_domain = oauth_domain.to_string();
|
||||
async move { extract_oauth_token_from_db(&auth_db_path, &oauth_domain) }
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(token) = token_from_db {
|
||||
return Some(token);
|
||||
}
|
||||
|
||||
for file_path in config_paths {
|
||||
if let Ok(contents) = fs.load(file_path).await {
|
||||
if let Some(token) = extract_oauth_token(contents, oauth_domain) {
|
||||
return Some(token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
|
||||
serde_json::from_str::<serde_json::Value>(&contents)
|
||||
.map(|v| {
|
||||
|
|
@ -934,6 +1013,36 @@ fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
|
|||
.flatten()
|
||||
}
|
||||
|
||||
fn extract_oauth_token_from_db(db_path: &Path, auth_authority: &str) -> Option<String> {
|
||||
if !db_path.exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let db = sqlez::connection::Connection::open_file(db_path.to_str()?);
|
||||
|
||||
let token_bytes: Option<Vec<u8>> = db
|
||||
.select_row_bound::<&str, Vec<u8>>(
|
||||
"SELECT token_ciphertext FROM oauth_tokens WHERE auth_authority = ? ORDER BY last_used_at DESC, token_id DESC LIMIT 1",
|
||||
)
|
||||
.ok()
|
||||
.and_then(|mut select| select(auth_authority).ok().flatten());
|
||||
|
||||
let token = token_bytes.and_then(|bytes| String::from_utf8(bytes).ok())?;
|
||||
|
||||
if token.starts_with("ghu_")
|
||||
&& token.len() >= 36
|
||||
&& token.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
{
|
||||
log::debug!("Copilot OAuth token loaded from auth.db");
|
||||
Some(token)
|
||||
} else {
|
||||
log::warn!(
|
||||
"Copilot auth.db: token does not match expected GitHub OAuth format (ghu_<alphanumeric>)"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream_completion(
|
||||
client: Arc<dyn HttpClient>,
|
||||
oauth_token: String,
|
||||
|
|
@ -1751,4 +1860,61 @@ mod tests {
|
|||
"\"none\""
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_from_db_matches_auth_authority_and_recency() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("auth.db");
|
||||
let older_github_token = "ghu_oldergithubtokenvalue000000000000";
|
||||
let newer_github_token = "ghu_newergithubtokenvalue000000000000";
|
||||
let enterprise_token = "ghu_enterprisetokenvalue0000000000000";
|
||||
|
||||
let connection = sqlez::connection::Connection::open_file(db_path.to_str().unwrap());
|
||||
connection
|
||||
.exec(
|
||||
"CREATE TABLE oauth_tokens (
|
||||
token_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
auth_authority TEXT NOT NULL,
|
||||
token_ciphertext BLOB NOT NULL,
|
||||
last_used_at INTEGER NOT NULL
|
||||
);",
|
||||
)
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let mut insert_token = connection
|
||||
.exec_bound::<(&str, Vec<u8>, i64)>(
|
||||
"INSERT INTO oauth_tokens (auth_authority, token_ciphertext, last_used_at) VALUES (?, ?, ?);",
|
||||
)
|
||||
.unwrap();
|
||||
insert_token(("github.com", older_github_token.as_bytes().to_vec(), 10)).unwrap();
|
||||
insert_token((
|
||||
"github.enterprise.test",
|
||||
enterprise_token.as_bytes().to_vec(),
|
||||
30,
|
||||
))
|
||||
.unwrap();
|
||||
insert_token(("github.com", newer_github_token.as_bytes().to_vec(), 20)).unwrap();
|
||||
}
|
||||
drop(connection);
|
||||
|
||||
assert_eq!(
|
||||
extract_oauth_token_from_db(&db_path, "github.com").as_deref(),
|
||||
Some(newer_github_token)
|
||||
);
|
||||
assert_eq!(
|
||||
extract_oauth_token_from_db(&db_path, "github.enterprise.test").as_deref(),
|
||||
Some(enterprise_token)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_from_db_missing_db_does_not_create_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("auth.db");
|
||||
|
||||
assert_eq!(extract_oauth_token_from_db(&db_path, "github.com"), None);
|
||||
assert!(!db_path.exists());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue