From c1cb7a0fa0d0bb6b58aa0f3e0979905476a19225 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 23 Aug 2024 18:01:02 +0100 Subject: [PATCH] proxy: flesh out JWT verification code (#8805) This change adds in the necessary verification steps for the JWT payload, and adds per-role querying of JWKs as needed for #8736 --- proxy/src/auth/backend/jwt.rs | 295 +++++++++++++++++++++++----------- 1 file changed, 203 insertions(+), 92 deletions(-) diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index e021a7e23f..49d5de16c3 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -1,15 +1,21 @@ -use std::{future::Future, sync::Arc, time::Duration}; +use std::{ + future::Future, + sync::Arc, + time::{Duration, SystemTime}, +}; use anyhow::{bail, ensure, Context}; use arc_swap::ArcSwapOption; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; +use serde::{Deserialize, Deserializer}; use signature::Verifier; use tokio::time::Instant; -use crate::{http::parse_json_body_with_limit, intern::EndpointIdInt}; +use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName}; // TODO(conrad): make these configurable. 
+const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30); const MIN_RENEW: Duration = Duration::from_secs(30); const AUTO_RENEW: Duration = Duration::from_secs(300); const MAX_RENEW: Duration = Duration::from_secs(3600); @@ -17,30 +23,56 @@ const MAX_JWK_BODY_SIZE: usize = 64 * 1024; /// How to get the JWT auth rules pub trait FetchAuthRules: Clone + Send + Sync + 'static { - fn fetch_auth_rules(&self) -> impl Future> + Send; + fn fetch_auth_rules( + &self, + role_name: RoleName, + ) -> impl Future>> + Send; } -#[derive(Clone)] -struct FetchAuthRulesFromCplane { - #[allow(dead_code)] - endpoint: EndpointIdInt, -} - -impl FetchAuthRules for FetchAuthRulesFromCplane { - async fn fetch_auth_rules(&self) -> anyhow::Result { - Err(anyhow::anyhow!("not yet implemented")) - } -} - -pub struct AuthRules { - jwks_urls: Vec, +pub struct AuthRule { + pub id: String, + pub jwks_url: url::Url, + pub audience: Option, } #[derive(Default)] pub struct JwkCache { client: reqwest::Client, - map: DashMap>, + map: DashMap<(EndpointId, RoleName), Arc>, +} + +pub struct JwkCacheEntry { + /// Should refetch at least every hour to verify when old keys have been removed. + /// Should refetch when new key IDs are seen only every 5 minutes or so + last_retrieved: Instant, + + /// cplane will return multiple JWKs urls that we need to scrape. 
+ key_sets: ahash::HashMap, +} + +impl JwkCacheEntry { + fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> { + self.key_sets.values().find_map(|key_set| { + key_set + .find_key(key_id) + .map(|jwk| (jwk, key_set.audience.as_deref())) + }) + } +} + +struct KeySet { + jwks: jose_jwk::JwkSet, + audience: Option, +} + +impl KeySet { + fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> { + self.jwks + .keys + .iter() + .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id)) + } } pub struct JwkCacheEntryLock { @@ -57,15 +89,6 @@ impl Default for JwkCacheEntryLock { } } -pub struct JwkCacheEntry { - /// Should refetch at least every hour to verify when old keys have been removed. - /// Should refetch when new key IDs are seen only every 5 minutes or so - last_retrieved: Instant, - - /// cplane will return multiple JWKs urls that we need to scrape. - key_sets: ahash::HashMap, -} - impl JwkCacheEntryLock { async fn acquire_permit<'a>(self: &'a Arc) -> JwkRenewalPermit<'a> { JwkRenewalPermit::acquire_permit(self).await @@ -79,6 +102,7 @@ impl JwkCacheEntryLock { &self, _permit: JwkRenewalPermit<'_>, client: &reqwest::Client, + role_name: RoleName, auth_rules: &F, ) -> anyhow::Result> { // double check that no one beat us to updating the cache. 
@@ -91,20 +115,19 @@ impl JwkCacheEntryLock { } } - let rules = auth_rules.fetch_auth_rules().await?; - let mut key_sets = ahash::HashMap::with_capacity_and_hasher( - rules.jwks_urls.len(), - ahash::RandomState::new(), - ); + let rules = auth_rules.fetch_auth_rules(role_name).await?; + let mut key_sets = + ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new()); // TODO(conrad): run concurrently // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284) - for url in rules.jwks_urls { - let req = client.get(url.clone()); + for rule in rules { + let req = client.get(rule.jwks_url.clone()); // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`. + // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. match req.send().await.and_then(|r| r.error_for_status()) { // todo: should we re-insert JWKs if we want to keep this JWKs URL? // I expect these failures would be quite sparse. 
- Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"), + Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"), Ok(r) => { let resp: http::Response = r.into(); match parse_json_body_with_limit::( @@ -113,9 +136,17 @@ impl JwkCacheEntryLock { ) .await { - Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"), + Err(e) => { + tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); + } Ok(jwks) => { - key_sets.insert(url, jwks); + key_sets.insert( + rule.id, + KeySet { + jwks, + audience: rule.audience, + }, + ); } } } @@ -133,7 +164,9 @@ impl JwkCacheEntryLock { async fn get_or_update_jwk_cache( self: &Arc, + ctx: &RequestMonitoring, client: &reqwest::Client, + role_name: RoleName, fetch: &F, ) -> Result, anyhow::Error> { let now = Instant::now(); @@ -141,18 +174,20 @@ impl JwkCacheEntryLock { // if we have no cached JWKs, try and get some let Some(cached) = guard else { + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let permit = self.acquire_permit().await; - return self.renew_jwks(permit, client, fetch).await; + return self.renew_jwks(permit, client, role_name, fetch).await; }; let last_update = now.duration_since(cached.last_retrieved); // check if the cached JWKs need updating. if last_update > MAX_RENEW { + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let permit = self.acquire_permit().await; // it's been too long since we checked the keys. wait for them to update. - return self.renew_jwks(permit, client, fetch).await; + return self.renew_jwks(permit, client, role_name, fetch).await; } // every 5 minutes we should spawn a job to eagerly update the token. 
@@ -164,7 +199,7 @@ impl JwkCacheEntryLock { let client = client.clone(); let fetch = fetch.clone(); tokio::spawn(async move { - if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await { + if let Err(e) = entry.renew_jwks(permit, &client, role_name, &fetch).await { tracing::warn!(error=?e, "could not fetch JWKs in background job"); } }); @@ -178,8 +213,10 @@ impl JwkCacheEntryLock { async fn check_jwt( self: &Arc, - jwt: String, + ctx: &RequestMonitoring, + jwt: &str, client: &reqwest::Client, + role_name: RoleName, fetch: &F, ) -> Result<(), anyhow::Error> { // JWT compact form is defined to be @@ -189,36 +226,36 @@ impl JwkCacheEntryLock { let (header_payload, signature) = jwt .rsplit_once(".") .context("Provided authentication token is not a valid JWT encoding")?; - let (header, _payload) = header_payload + let (header, payload) = header_payload .split_once(".") .context("Provided authentication token is not a valid JWT encoding")?; let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD) .context("Provided authentication token is not a valid JWT encoding")?; - let header = serde_json::from_slice::>(&header) + let header = serde_json::from_slice::>(&header) .context("Provided authentication token is not a valid JWT encoding")?; let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD) .context("Provided authentication token is not a valid JWT encoding")?; ensure!(header.typ == "JWT"); - let kid = header.kid.context("missing key id")?; + let kid = header.key_id.context("missing key id")?; - let mut guard = self.get_or_update_jwk_cache(client, fetch).await?; + let mut guard = self + .get_or_update_jwk_cache(ctx, client, role_name.clone(), fetch) + .await?; // get the key from the JWKs if possible. If not, wait for the keys to update. 
- let jwk = loop { - let jwk = guard - .key_sets - .values() - .flat_map(|jwks| &jwks.keys) - .find(|jwk| jwk.prm.kid.as_deref() == Some(kid)); - - match jwk { + let (jwk, expected_audience) = loop { + match guard.find_jwk_and_audience(kid) { Some(jwk) => break jwk, None if guard.last_retrieved.elapsed() > MIN_RENEW => { + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let permit = self.acquire_permit().await; - guard = self.renew_jwks(permit, client, fetch).await?; + guard = self + .renew_jwks(permit, client, role_name.clone(), fetch) + .await?; } _ => { bail!("jwk not found"); @@ -227,7 +264,7 @@ impl JwkCacheEntryLock { }; ensure!( - jwk.is_supported(&header.alg), + jwk.is_supported(&header.algorithm), "signature algorithm not supported" ); @@ -241,31 +278,60 @@ impl JwkCacheEntryLock { key => bail!("unsupported key type {key:?}"), }; - // TODO(conrad): verify iss, exp, nbf, etc... + let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD) + .context("Provided authentication token is not a valid JWT encoding")?; + let payload = serde_json::from_slice::>(&payload) + .context("Provided authentication token is not a valid JWT encoding")?; + + tracing::debug!(?payload, "JWT signature valid with claims"); + + match (expected_audience, payload.audience) { + // check the audience matches + (Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"), + // the audience is expected but is missing + (Some(_), None) => bail!("invalid JWT token audience"), + // we don't care for the audience field + (None, _) => {} + } + + let now = SystemTime::now(); + + if let Some(exp) = payload.expiration { + ensure!(now < exp + CLOCK_SKEW_LEEWAY); + } + + if let Some(nbf) = payload.not_before { + ensure!(nbf < now + CLOCK_SKEW_LEEWAY); + } Ok(()) } } impl JwkCache { - pub async fn check_jwt( + pub async fn check_jwt( &self, - endpoint: EndpointIdInt, - jwt: String, + ctx: &RequestMonitoring, + endpoint: EndpointId, + 
 role_name: RoleName, + fetch: &F, + jwt: &str, ) -> Result<(), anyhow::Error> { // try with just a read lock first - let entry = self.map.get(&endpoint).as_deref().map(Arc::clone); + let key = (endpoint, role_name.clone()); + let entry = self.map.get(&key).as_deref().map(Arc::clone); let entry = match entry { Some(entry) => entry, None => { // acquire a write lock after to insert. - let entry = self.map.entry(endpoint).or_default(); + let entry = self.map.entry(key).or_default(); Arc::clone(&*entry) } }; - let fetch = FetchAuthRulesFromCplane { endpoint }; - entry.check_jwt(jwt, &self.client, &fetch).await + entry + .check_jwt(ctx, jwt, &self.client, role_name, fetch) + .await } } @@ -315,13 +381,49 @@ fn verify_rsa_signature( /// #[derive(serde::Deserialize, serde::Serialize)] -struct JWTHeader<'a> { +struct JwtHeader<'a> { /// must be "JWT" + #[serde(rename = "typ")] typ: &'a str, /// must be a supported alg - alg: jose_jwa::Algorithm, + #[serde(rename = "alg")] + algorithm: jose_jwa::Algorithm, /// key id, must be provided for our usecase - kid: Option<&'a str>, + #[serde(rename = "kid")] + key_id: Option<&'a str>, +} + +/// +#[derive(serde::Deserialize, serde::Serialize, Debug)] +struct JwtPayload<'a> { + /// Audience - Recipient for which the JWT is intended + #[serde(rename = "aud")] + audience: Option<&'a str>, + /// Expiration - Time after which the JWT expires + #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)] + expiration: Option, + /// Not before - Time before which the JWT must not be accepted + #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)] + not_before: Option, + + // the following entries are only extracted for the sake of debug logging. 
+ /// Issuer of the JWT + #[serde(rename = "iss")] + issuer: Option<&'a str>, + /// Subject of the JWT (the user) + #[serde(rename = "sub")] + subject: Option<&'a str>, + /// Unique token identifier + #[serde(rename = "jti")] + jwt_id: Option<&'a str>, + /// Unique session identifier + #[serde(rename = "sid")] + session_id: Option<&'a str>, +} + +fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + let d = >::deserialize(d)?; + Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n))) } struct JwkRenewalPermit<'a> { @@ -388,6 +490,8 @@ impl Drop for JwkRenewalPermit<'_> { #[cfg(test)] mod tests { + use crate::RoleName; + use super::*; use std::{future::IntoFuture, net::SocketAddr, time::SystemTime}; @@ -431,10 +535,10 @@ mod tests { } fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String { - let header = JWTHeader { + let header = JwtHeader { typ: "JWT", - alg: jose_jwa::Algorithm::Signing(sig), - kid: Some(&kid), + algorithm: jose_jwa::Algorithm::Signing(sig), + key_id: Some(&kid), }; let body = typed_json::json! 
{{ "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600, @@ -524,33 +628,40 @@ mod tests { struct Fetch(SocketAddr); impl FetchAuthRules for Fetch { - async fn fetch_auth_rules(&self) -> anyhow::Result { - Ok(AuthRules { - jwks_urls: vec![ - format!("http://{}/foo", self.0).parse().unwrap(), - format!("http://{}/bar", self.0).parse().unwrap(), - ], - }) + async fn fetch_auth_rules( + &self, + _role_name: RoleName, + ) -> anyhow::Result> { + Ok(vec![ + AuthRule { + id: "foo".to_owned(), + jwks_url: format!("http://{}/foo", self.0).parse().unwrap(), + audience: None, + }, + AuthRule { + id: "bar".to_owned(), + jwks_url: format!("http://{}/bar", self.0).parse().unwrap(), + audience: None, + }, + ]) } } + let role_name = RoleName::from("user"); + let jwk_cache = Arc::new(JwkCacheEntryLock::default()); - jwk_cache - .check_jwt(jwt1, &client, &Fetch(addr)) - .await - .unwrap(); - jwk_cache - .check_jwt(jwt2, &client, &Fetch(addr)) - .await - .unwrap(); - jwk_cache - .check_jwt(jwt3, &client, &Fetch(addr)) - .await - .unwrap(); - jwk_cache - .check_jwt(jwt4, &client, &Fetch(addr)) - .await - .unwrap(); + for token in [jwt1, jwt2, jwt3, jwt4] { + jwk_cache + .check_jwt( + &RequestMonitoring::test(), + &token, + &client, + role_name.clone(), + &Fetch(addr), + ) + .await + .unwrap(); + } } }