mirror of
https://github.com/zed-industries/zed.git
synced 2026-06-01 03:14:56 +07:00
collab: Validate access tokens through Cloud (#49535)
This PR updates Collab to make it validate access tokens through Cloud instead of doing it in-house. We're reusing the `GET /client/users/me` endpoint—which is what we also call on the client—to validate the user's access token. We only need to do this when establishing a WebSocket connection, so the increased latency of a network hop shouldn't be a problem. Closes CLO-308. Release Notes: - N/A
This commit is contained in:
parent
af050fc565
commit
f07cec59de
8 changed files with 41 additions and 301 deletions
37
Cargo.lock
generated
37
Cargo.lock
generated
|
|
@ -3202,7 +3202,6 @@ dependencies = [
|
|||
"aws-sdk-kinesis",
|
||||
"aws-sdk-s3",
|
||||
"axum",
|
||||
"base64 0.22.1",
|
||||
"buffer_diff",
|
||||
"call",
|
||||
"channel",
|
||||
|
|
@ -3260,7 +3259,6 @@ dependencies = [
|
|||
"remote_server",
|
||||
"reqwest 0.11.27",
|
||||
"rpc",
|
||||
"scrypt",
|
||||
"sea-orm",
|
||||
"sea-orm-macros",
|
||||
"semver",
|
||||
|
|
@ -3272,7 +3270,6 @@ dependencies = [
|
|||
"smol",
|
||||
"sqlx",
|
||||
"strum 0.27.2",
|
||||
"subtle",
|
||||
"task",
|
||||
"telemetry_events",
|
||||
"text",
|
||||
|
|
@ -11463,17 +11460,6 @@ dependencies = [
|
|||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "password-hash"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
|
||||
dependencies = [
|
||||
"base64ct",
|
||||
"rand_core 0.6.4",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
|
|
@ -11559,7 +11545,7 @@ checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
|
|||
dependencies = [
|
||||
"digest",
|
||||
"hmac",
|
||||
"password-hash 0.4.2",
|
||||
"password-hash",
|
||||
"sha2",
|
||||
]
|
||||
|
||||
|
|
@ -14515,15 +14501,6 @@ dependencies = [
|
|||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "salsa20"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213"
|
||||
dependencies = [
|
||||
"cipher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
|
|
@ -14666,18 +14643,6 @@ dependencies = [
|
|||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scrypt"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f"
|
||||
dependencies = [
|
||||
"password-hash 0.5.0",
|
||||
"pbkdf2 0.12.2",
|
||||
"salsa20",
|
||||
"sha2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.7.1"
|
||||
|
|
|
|||
|
|
@ -668,7 +668,6 @@ stacksafe = "0.1"
|
|||
streaming-iterator = "0.1"
|
||||
strsim = "0.11"
|
||||
strum = { version = "0.27.2", features = ["derive"] }
|
||||
subtle = "2.5.0"
|
||||
syn = { version = "2.0.101", features = ["full", "extra-traits", "visit-mut"] }
|
||||
sys-locale = "0.3.1"
|
||||
sysinfo = "0.37.0"
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ aws-config = { version = "1.1.5" }
|
|||
aws-sdk-kinesis = "1.51.0"
|
||||
aws-sdk-s3 = { version = "1.15.0" }
|
||||
axum = { version = "0.6", features = ["json", "headers", "ws"] }
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
|
|
@ -53,7 +52,6 @@ prost.workspace = true
|
|||
rand.workspace = true
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
rpc.workspace = true
|
||||
scrypt = "0.11"
|
||||
# sea-orm and sea-orm-macros versions must match exactly.
|
||||
sea-orm = { version = "=1.1.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
|
||||
sea-orm-macros = "=1.1.10"
|
||||
|
|
@ -63,7 +61,6 @@ serde_json.workspace = true
|
|||
sha2.workspace = true
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
|
||||
strum.workspace = true
|
||||
subtle.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
time.workspace = true
|
||||
|
|
|
|||
|
|
@ -1,26 +1,13 @@
|
|||
use crate::{
|
||||
AppState, Error, Result,
|
||||
db::{AccessTokenId, Database, UserId},
|
||||
rpc::Principal,
|
||||
};
|
||||
use crate::{AppState, Error, db::UserId, rpc::Principal};
|
||||
use anyhow::Context as _;
|
||||
use axum::{
|
||||
http::{self, Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use base64::prelude::*;
|
||||
use prometheus::{Histogram, exponential_buckets, register_histogram};
|
||||
use cloud_api_types::GetAuthenticatedUserResponse;
|
||||
pub use rpc::auth::random_token;
|
||||
use scrypt::{
|
||||
Scrypt,
|
||||
password_hash::{PasswordHash, PasswordVerifier},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Digest;
|
||||
use std::sync::OnceLock;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
use subtle::ConstantTimeEq;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Validates the authorization header and adds an Extension<Principal> to the request.
|
||||
/// Authorization: <user-id> <token>
|
||||
|
|
@ -64,11 +51,23 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||
)
|
||||
})?;
|
||||
|
||||
let validate_result = verify_access_token(access_token, user_id, &state.db).await;
|
||||
let http_client = state.http_client.clone().expect("no HTTP client");
|
||||
|
||||
let response = http_client
|
||||
.get(format!("{}/client/users/me", state.config.zed_cloud_url()))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("{user_id} {access_token}"))
|
||||
.send()
|
||||
.await
|
||||
.context("failed to validate access token")?;
|
||||
if let Ok(response) = response.error_for_status() {
|
||||
let response_body: GetAuthenticatedUserResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("failed to parse response body")?;
|
||||
|
||||
let user_id = UserId(response_body.user.id);
|
||||
|
||||
if let Ok(validate_result) = validate_result
|
||||
&& validate_result.is_valid
|
||||
{
|
||||
let user = state
|
||||
.db
|
||||
.get_user_by_id(user_id)
|
||||
|
|
@ -84,68 +83,3 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||
"invalid credentials".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AccessTokenJson {
|
||||
pub version: usize,
|
||||
pub id: AccessTokenId,
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
/// Hashing prevents anyone with access to the database being able to login.
|
||||
/// As the token is randomly generated, we don't need to worry about scrypt-style
|
||||
/// protection.
|
||||
pub fn hash_access_token(token: &str) -> String {
|
||||
let digest = sha2::Sha256::digest(token);
|
||||
format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
|
||||
}
|
||||
|
||||
pub struct VerifyAccessTokenResult {
|
||||
pub is_valid: bool,
|
||||
}
|
||||
|
||||
/// Checks that the given access token is valid for the given user.
|
||||
pub async fn verify_access_token(
|
||||
token: &str,
|
||||
user_id: UserId,
|
||||
db: &Arc<Database>,
|
||||
) -> Result<VerifyAccessTokenResult> {
|
||||
static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
|
||||
let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
|
||||
register_histogram!(
|
||||
"access_token_hashing_time",
|
||||
"time spent hashing access tokens",
|
||||
exponential_buckets(10.0, 2.0, 10).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let token: AccessTokenJson = serde_json::from_str(token)?;
|
||||
|
||||
let db_token = db.get_access_token(token.id).await?;
|
||||
if db_token.user_id != user_id {
|
||||
return Err(anyhow::anyhow!("no such access token"))?;
|
||||
}
|
||||
let t0 = Instant::now();
|
||||
|
||||
let is_valid = if db_token.hash.starts_with("$scrypt$") {
|
||||
let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
|
||||
Scrypt
|
||||
.verify_password(token.token.as_bytes(), &db_hash)
|
||||
.is_ok()
|
||||
} else {
|
||||
let token_hash = hash_access_token(&token.token);
|
||||
db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
|
||||
};
|
||||
|
||||
let duration = t0.elapsed();
|
||||
log::info!("hashed access token in {:?}", duration);
|
||||
metric_access_token_hashing_time.observe(duration.as_millis() as f64);
|
||||
|
||||
if is_valid && db_token.hash.starts_with("$scrypt$") {
|
||||
let new_hash = hash_access_token(&token.token);
|
||||
db.update_access_token_hash(db_token.id, &new_hash).await?;
|
||||
}
|
||||
|
||||
Ok(VerifyAccessTokenResult { is_valid })
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ use serde::Deserialize;
|
|||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
pub const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
|
|
@ -150,6 +153,14 @@ impl Config {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the base Zed Cloud URL.
|
||||
pub fn zed_cloud_url(&self) -> &str {
|
||||
match self.zed_environment.as_ref() {
|
||||
"development" => "http://localhost:8787",
|
||||
_ => "https://cloud.zed.dev",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
pub fn test() -> Self {
|
||||
Self {
|
||||
|
|
@ -199,6 +210,7 @@ impl ServiceMode {
|
|||
|
||||
pub struct AppState {
|
||||
pub db: Arc<Database>,
|
||||
pub http_client: Option<reqwest::Client>,
|
||||
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub executor: Executor,
|
||||
|
|
@ -228,9 +240,16 @@ impl AppState {
|
|||
None
|
||||
};
|
||||
|
||||
let user_agent = format!("Collab/{VERSION} ({})", REVISION.unwrap_or("unknown"));
|
||||
let http_client = reqwest::Client::builder()
|
||||
.user_agent(user_agent)
|
||||
.build()
|
||||
.context("failed to construct HTTP client")?;
|
||||
|
||||
let db = Arc::new(db);
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
http_client: Some(http_client),
|
||||
livekit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
executor,
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ use axum::{
|
|||
routing::get,
|
||||
};
|
||||
|
||||
use collab::ServiceMode;
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::{
|
||||
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
|
||||
executor::Executor,
|
||||
};
|
||||
use collab::{REVISION, ServiceMode, VERSION};
|
||||
use db::Database;
|
||||
use std::{
|
||||
env::args,
|
||||
|
|
@ -28,9 +28,6 @@ use tracing_subscriber::{
|
|||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
|
||||
|
||||
#[expect(clippy::result_large_err)]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
|
|
|
|||
|
|
@ -50,175 +50,3 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
|
|||
fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
|
||||
cx.read(|cx| room.read(cx).channel_id())
|
||||
}
|
||||
|
||||
mod auth_token_tests {
|
||||
use collab::auth::{
|
||||
AccessTokenJson, VerifyAccessTokenResult, hash_access_token, verify_access_token,
|
||||
};
|
||||
use rand::prelude::*;
|
||||
use scrypt::Scrypt;
|
||||
use scrypt::password_hash::{PasswordHasher, SaltString};
|
||||
use sea_orm::EntityTrait;
|
||||
|
||||
use collab::db::{Database, NewUserParams, UserId, access_token};
|
||||
use collab::*;
|
||||
|
||||
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
|
||||
|
||||
async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
|
||||
const VERSION: usize = 1;
|
||||
let access_token = ::rpc::auth::random_token();
|
||||
let access_token_hash = hash_access_token(&access_token);
|
||||
let id = db
|
||||
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
|
||||
.await?;
|
||||
Ok(serde_json::to_string(&AccessTokenJson {
|
||||
version: VERSION,
|
||||
id,
|
||||
token: access_token,
|
||||
})?)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
|
||||
let test_db = crate::db_tests::TestDb::sqlite(cx.executor());
|
||||
let db = test_db.db();
|
||||
|
||||
let user = db
|
||||
.create_user(
|
||||
"example@example.com",
|
||||
None,
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "example".into(),
|
||||
github_user_id: 1,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let token = create_access_token(db, user.user_id).await.unwrap();
|
||||
assert!(matches!(
|
||||
verify_access_token(&token, user.user_id, db).await.unwrap(),
|
||||
VerifyAccessTokenResult { is_valid: true }
|
||||
));
|
||||
|
||||
let old_token = create_previous_access_token(user.user_id, db)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
|
||||
.unwrap()
|
||||
.id;
|
||||
|
||||
let hash = db
|
||||
.transaction(|tx| async move {
|
||||
Ok(access_token::Entity::find_by_id(old_token_id)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.hash;
|
||||
assert!(hash.starts_with("$scrypt$"));
|
||||
|
||||
assert!(matches!(
|
||||
verify_access_token(&old_token, user.user_id, db)
|
||||
.await
|
||||
.unwrap(),
|
||||
VerifyAccessTokenResult { is_valid: true }
|
||||
));
|
||||
|
||||
let hash = db
|
||||
.transaction(|tx| async move {
|
||||
Ok(access_token::Entity::find_by_id(old_token_id)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.hash;
|
||||
assert!(hash.starts_with("$sha256$"));
|
||||
|
||||
assert!(matches!(
|
||||
verify_access_token(&old_token, user.user_id, db)
|
||||
.await
|
||||
.unwrap(),
|
||||
VerifyAccessTokenResult { is_valid: true }
|
||||
));
|
||||
|
||||
assert!(matches!(
|
||||
verify_access_token(&token, user.user_id, db).await.unwrap(),
|
||||
VerifyAccessTokenResult { is_valid: true }
|
||||
));
|
||||
}
|
||||
|
||||
async fn create_previous_access_token(user_id: UserId, db: &Database) -> Result<String> {
|
||||
let access_token = collab::auth::random_token();
|
||||
let access_token_hash = previous_hash_access_token(&access_token)?;
|
||||
let id = db
|
||||
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
|
||||
.await?;
|
||||
Ok(serde_json::to_string(&AccessTokenJson {
|
||||
version: 1,
|
||||
id,
|
||||
token: access_token,
|
||||
})?)
|
||||
}
|
||||
|
||||
#[expect(clippy::result_large_err)]
|
||||
fn previous_hash_access_token(token: &str) -> Result<String> {
|
||||
// Avoid slow hashing in debug mode.
|
||||
let params = if cfg!(debug_assertions) {
|
||||
scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
|
||||
} else {
|
||||
scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
|
||||
};
|
||||
|
||||
Ok(Scrypt
|
||||
.hash_password_customized(
|
||||
token.as_bytes(),
|
||||
None,
|
||||
None,
|
||||
params,
|
||||
&SaltString::generate(PasswordHashRngCompat::new()),
|
||||
)
|
||||
.map_err(anyhow::Error::new)?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
// TODO: remove once we password_hash v0.6 is released.
|
||||
struct PasswordHashRngCompat(rand::rngs::ThreadRng);
|
||||
|
||||
impl PasswordHashRngCompat {
|
||||
fn new() -> Self {
|
||||
Self(rand::rng())
|
||||
}
|
||||
}
|
||||
|
||||
impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat {
|
||||
fn next_u32(&mut self) -> u32 {
|
||||
self.0.next_u32()
|
||||
}
|
||||
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
self.0.next_u64()
|
||||
}
|
||||
|
||||
fn fill_bytes(&mut self, dest: &mut [u8]) {
|
||||
self.0.fill_bytes(dest);
|
||||
}
|
||||
|
||||
fn try_fill_bytes(
|
||||
&mut self,
|
||||
dest: &mut [u8],
|
||||
) -> Result<(), scrypt::password_hash::rand_core::Error> {
|
||||
self.fill_bytes(dest);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -564,6 +564,7 @@ impl TestServer {
|
|||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
http_client: None,
|
||||
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
executor,
|
||||
|
|
|
|||
Loading…
Reference in a new issue