collab: Validate access tokens through Cloud (#49535)

This PR updates Collab to make it validate access tokens through Cloud
instead of doing it in-house.

We're reusing the `GET /client/users/me` endpoint—which is what we also
call on the client—to validate the user's access token.

We only need to do this when establishing a WebSocket connection, so the
increased latency of a network hop shouldn't be a problem.

Closes CLO-308.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2026-02-18 18:20:52 -05:00 committed by GitHub
parent af050fc565
commit f07cec59de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 41 additions and 301 deletions

37
Cargo.lock generated
View file

@ -3202,7 +3202,6 @@ dependencies = [
"aws-sdk-kinesis",
"aws-sdk-s3",
"axum",
"base64 0.22.1",
"buffer_diff",
"call",
"channel",
@ -3260,7 +3259,6 @@ dependencies = [
"remote_server",
"reqwest 0.11.27",
"rpc",
"scrypt",
"sea-orm",
"sea-orm-macros",
"semver",
@ -3272,7 +3270,6 @@ dependencies = [
"smol",
"sqlx",
"strum 0.27.2",
"subtle",
"task",
"telemetry_events",
"text",
@ -11463,17 +11460,6 @@ dependencies = [
"subtle",
]
[[package]]
name = "password-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core 0.6.4",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.15"
@ -11559,7 +11545,7 @@ checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
dependencies = [
"digest",
"hmac",
"password-hash 0.4.2",
"password-hash",
"sha2",
]
@ -14515,15 +14501,6 @@ dependencies = [
"serde_json",
]
[[package]]
name = "salsa20"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213"
dependencies = [
"cipher",
]
[[package]]
name = "same-file"
version = "1.0.6"
@ -14666,18 +14643,6 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "scrypt"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f"
dependencies = [
"password-hash 0.5.0",
"pbkdf2 0.12.2",
"salsa20",
"sha2",
]
[[package]]
name = "sct"
version = "0.7.1"

View file

@ -668,7 +668,6 @@ stacksafe = "0.1"
streaming-iterator = "0.1"
strsim = "0.11"
strum = { version = "0.27.2", features = ["derive"] }
subtle = "2.5.0"
syn = { version = "2.0.101", features = ["full", "extra-traits", "visit-mut"] }
sys-locale = "0.3.1"
sysinfo = "0.37.0"

View file

@ -33,7 +33,6 @@ aws-config = { version = "1.1.5" }
aws-sdk-kinesis = "1.51.0"
aws-sdk-s3 = { version = "1.15.0" }
axum = { version = "0.6", features = ["json", "headers", "ws"] }
base64.workspace = true
chrono.workspace = true
clock.workspace = true
cloud_api_types.workspace = true
@ -53,7 +52,6 @@ prost.workspace = true
rand.workspace = true
reqwest = { version = "0.11", features = ["json"] }
rpc.workspace = true
scrypt = "0.11"
# sea-orm and sea-orm-macros versions must match exactly.
sea-orm = { version = "=1.1.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
sea-orm-macros = "=1.1.10"
@ -63,7 +61,6 @@ serde_json.workspace = true
sha2.workspace = true
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
strum.workspace = true
subtle.workspace = true
telemetry_events.workspace = true
text.workspace = true
time.workspace = true

View file

@ -1,26 +1,13 @@
use crate::{
AppState, Error, Result,
db::{AccessTokenId, Database, UserId},
rpc::Principal,
};
use crate::{AppState, Error, db::UserId, rpc::Principal};
use anyhow::Context as _;
use axum::{
http::{self, Request, StatusCode},
middleware::Next,
response::IntoResponse,
};
use base64::prelude::*;
use prometheus::{Histogram, exponential_buckets, register_histogram};
use cloud_api_types::GetAuthenticatedUserResponse;
pub use rpc::auth::random_token;
use scrypt::{
Scrypt,
password_hash::{PasswordHash, PasswordVerifier},
};
use serde::{Deserialize, Serialize};
use sha2::Digest;
use std::sync::OnceLock;
use std::{sync::Arc, time::Instant};
use subtle::ConstantTimeEq;
use std::sync::Arc;
/// Validates the authorization header and adds an Extension<Principal> to the request.
/// Authorization: <user-id> <token>
@ -64,11 +51,23 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
)
})?;
let validate_result = verify_access_token(access_token, user_id, &state.db).await;
let http_client = state.http_client.clone().expect("no HTTP client");
let response = http_client
.get(format!("{}/client/users/me", state.config.zed_cloud_url()))
.header("Content-Type", "application/json")
.header("Authorization", format!("{user_id} {access_token}"))
.send()
.await
.context("failed to validate access token")?;
if let Ok(response) = response.error_for_status() {
let response_body: GetAuthenticatedUserResponse = response
.json()
.await
.context("failed to parse response body")?;
let user_id = UserId(response_body.user.id);
if let Ok(validate_result) = validate_result
&& validate_result.is_valid
{
let user = state
.db
.get_user_by_id(user_id)
@ -84,68 +83,3 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
"invalid credentials".to_string(),
))
}
#[derive(Serialize, Deserialize)]
pub struct AccessTokenJson {
pub version: usize,
pub id: AccessTokenId,
pub token: String,
}
/// Hashing prevents anyone with access to the database being able to login.
/// As the token is randomly generated, we don't need to worry about scrypt-style
/// protection.
pub fn hash_access_token(token: &str) -> String {
let digest = sha2::Sha256::digest(token);
format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
}
pub struct VerifyAccessTokenResult {
pub is_valid: bool,
}
/// Checks that the given access token is valid for the given user.
pub async fn verify_access_token(
token: &str,
user_id: UserId,
db: &Arc<Database>,
) -> Result<VerifyAccessTokenResult> {
static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
register_histogram!(
"access_token_hashing_time",
"time spent hashing access tokens",
exponential_buckets(10.0, 2.0, 10).unwrap(),
)
.unwrap()
});
let token: AccessTokenJson = serde_json::from_str(token)?;
let db_token = db.get_access_token(token.id).await?;
if db_token.user_id != user_id {
return Err(anyhow::anyhow!("no such access token"))?;
}
let t0 = Instant::now();
let is_valid = if db_token.hash.starts_with("$scrypt$") {
let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
Scrypt
.verify_password(token.token.as_bytes(), &db_hash)
.is_ok()
} else {
let token_hash = hash_access_token(&token.token);
db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
};
let duration = t0.elapsed();
log::info!("hashed access token in {:?}", duration);
metric_access_token_hashing_time.observe(duration.as_millis() as f64);
if is_valid && db_token.hash.starts_with("$scrypt$") {
let new_hash = hash_access_token(&token.token);
db.update_access_token_hash(db_token.id, &new_hash).await?;
}
Ok(VerifyAccessTokenResult { is_valid })
}

View file

@ -18,6 +18,9 @@ use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
@ -150,6 +153,14 @@ impl Config {
}
}
/// Returns the base Zed Cloud URL.
pub fn zed_cloud_url(&self) -> &str {
match self.zed_environment.as_ref() {
"development" => "http://localhost:8787",
_ => "https://cloud.zed.dev",
}
}
#[cfg(feature = "test-support")]
pub fn test() -> Self {
Self {
@ -199,6 +210,7 @@ impl ServiceMode {
pub struct AppState {
pub db: Arc<Database>,
pub http_client: Option<reqwest::Client>,
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub executor: Executor,
@ -228,9 +240,16 @@ impl AppState {
None
};
let user_agent = format!("Collab/{VERSION} ({})", REVISION.unwrap_or("unknown"));
let http_client = reqwest::Client::builder()
.user_agent(user_agent)
.build()
.context("failed to construct HTTP client")?;
let db = Arc::new(db);
let this = Self {
db: db.clone(),
http_client: Some(http_client),
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
executor,

View file

@ -7,12 +7,12 @@ use axum::{
routing::get,
};
use collab::ServiceMode;
use collab::api::CloudflareIpCountryHeader;
use collab::{
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
executor::Executor,
};
use collab::{REVISION, ServiceMode, VERSION};
use db::Database;
use std::{
env::args,
@ -28,9 +28,6 @@ use tracing_subscriber::{
};
use util::ResultExt as _;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
#[expect(clippy::result_large_err)]
#[tokio::main]
async fn main() -> Result<()> {

View file

@ -50,175 +50,3 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
cx.read(|cx| room.read(cx).channel_id())
}
mod auth_token_tests {
use collab::auth::{
AccessTokenJson, VerifyAccessTokenResult, hash_access_token, verify_access_token,
};
use rand::prelude::*;
use scrypt::Scrypt;
use scrypt::password_hash::{PasswordHasher, SaltString};
use sea_orm::EntityTrait;
use collab::db::{Database, NewUserParams, UserId, access_token};
use collab::*;
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
const VERSION: usize = 1;
let access_token = ::rpc::auth::random_token();
let access_token_hash = hash_access_token(&access_token);
let id = db
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
.await?;
Ok(serde_json::to_string(&AccessTokenJson {
version: VERSION,
id,
token: access_token,
})?)
}
#[gpui::test]
async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
let test_db = crate::db_tests::TestDb::sqlite(cx.executor());
let db = test_db.db();
let user = db
.create_user(
"example@example.com",
None,
false,
NewUserParams {
github_login: "example".into(),
github_user_id: 1,
},
)
.await
.unwrap();
let token = create_access_token(db, user.user_id).await.unwrap();
assert!(matches!(
verify_access_token(&token, user.user_id, db).await.unwrap(),
VerifyAccessTokenResult { is_valid: true }
));
let old_token = create_previous_access_token(user.user_id, db)
.await
.unwrap();
let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
.unwrap()
.id;
let hash = db
.transaction(|tx| async move {
Ok(access_token::Entity::find_by_id(old_token_id)
.one(&*tx)
.await?)
})
.await
.unwrap()
.unwrap()
.hash;
assert!(hash.starts_with("$scrypt$"));
assert!(matches!(
verify_access_token(&old_token, user.user_id, db)
.await
.unwrap(),
VerifyAccessTokenResult { is_valid: true }
));
let hash = db
.transaction(|tx| async move {
Ok(access_token::Entity::find_by_id(old_token_id)
.one(&*tx)
.await?)
})
.await
.unwrap()
.unwrap()
.hash;
assert!(hash.starts_with("$sha256$"));
assert!(matches!(
verify_access_token(&old_token, user.user_id, db)
.await
.unwrap(),
VerifyAccessTokenResult { is_valid: true }
));
assert!(matches!(
verify_access_token(&token, user.user_id, db).await.unwrap(),
VerifyAccessTokenResult { is_valid: true }
));
}
async fn create_previous_access_token(user_id: UserId, db: &Database) -> Result<String> {
let access_token = collab::auth::random_token();
let access_token_hash = previous_hash_access_token(&access_token)?;
let id = db
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
.await?;
Ok(serde_json::to_string(&AccessTokenJson {
version: 1,
id,
token: access_token,
})?)
}
#[expect(clippy::result_large_err)]
fn previous_hash_access_token(token: &str) -> Result<String> {
// Avoid slow hashing in debug mode.
let params = if cfg!(debug_assertions) {
scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
} else {
scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
};
Ok(Scrypt
.hash_password_customized(
token.as_bytes(),
None,
None,
params,
&SaltString::generate(PasswordHashRngCompat::new()),
)
.map_err(anyhow::Error::new)?
.to_string())
}
// TODO: remove once we password_hash v0.6 is released.
struct PasswordHashRngCompat(rand::rngs::ThreadRng);
impl PasswordHashRngCompat {
fn new() -> Self {
Self(rand::rng())
}
}
impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat {
fn next_u32(&mut self) -> u32 {
self.0.next_u32()
}
fn next_u64(&mut self) -> u64 {
self.0.next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
self.0.fill_bytes(dest);
}
fn try_fill_bytes(
&mut self,
dest: &mut [u8],
) -> Result<(), scrypt::password_hash::rand_core::Error> {
self.fill_bytes(dest);
Ok(())
}
}
impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {}
}

View file

@ -564,6 +564,7 @@ impl TestServer {
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
http_client: None,
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
blob_store_client: None,
executor,