From db4085fe22321948738a7cdde20dc7fba9d92ba2 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 9 Aug 2024 15:32:28 +0100 Subject: [PATCH] mock tests for jwk renewal --- proxy/src/auth/backend/jwt.rs | 259 +++++++++++++++++++++++++++------- 1 file changed, 208 insertions(+), 51 deletions(-) diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 0994beca23..6585d6f539 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -2,6 +2,7 @@ use std::{sync::Arc, time::Duration}; use anyhow::{bail, ensure, Context}; use arc_swap::ArcSwapOption; +use async_trait::async_trait; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; use tokio::time::Instant; @@ -20,6 +21,15 @@ pub struct JWKCacheEntryLock { lookup: tokio::sync::Semaphore, } +impl Default for JWKCacheEntryLock { + fn default() -> Self { + JWKCacheEntryLock { + cached: ArcSwapOption::empty(), + lookup: tokio::sync::Semaphore::new(1), + } + } +} + pub struct JWKCacheEntry { /// Should refetch at least every hour to verify when old keys have been removed. /// Should refetch when new key IDs are seen only every 5 minutes or so @@ -28,18 +38,19 @@ pub struct JWKCacheEntry { // /// jwks urls // urls: Vec, /// cplane will return multiple JWKs urls that we need to scrape. - key_sets: Vec, + key_sets: ahash::HashMap, } /// How to get the JWT auth rules -#[allow(async_fn_in_trait)] -pub trait FetchAuthRules { +#[async_trait] +pub trait FetchAuthRules: Clone + Send + Sync + 'static { async fn fetch_auth_rules(&self) -> anyhow::Result; } #[derive(Clone)] struct FetchAuthFromCplane(EndpointIdInt); +#[async_trait] impl FetchAuthRules for FetchAuthFromCplane { async fn fetch_auth_rules(&self) -> anyhow::Result { bail!("not yet implemented") @@ -91,24 +102,24 @@ impl JWKCacheEntryLock { } let rules = auth_rules.fetch_auth_rules().await?; - let mut key_sets = vec![]; + let mut key_sets = ahash::HashMap::with_capacity_and_hasher( + rules.jwks_urls.len(), + ahash::RandomState::new(), + ); for url in rules.jwks_urls { - match client - .get(url.clone()) - .send() - .await - .and_then(|r| r.error_for_status()) - { + let req = client.get(url.clone()); + match req.send().await.and_then(|r| r.error_for_status()) { + // todo: should we re-insert JWKs if we want to keep this JWKs URL? + // I expect these failures would be quite sparse. Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"), Ok(r) => match r.json::().await { Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"), - Ok(jwks) => key_sets.push(jwks), + Ok(jwks) => { + key_sets.insert(url, jwks); + } }, } } - if key_sets.is_empty() { - bail!("no JWKs found") - } let x = Arc::new(JWKCacheEntry { last_retrieved: now, @@ -118,16 +129,12 @@ impl JWKCacheEntryLock { Ok(x) } -} -const MIN_RENEW: Duration = Duration::from_secs(300); -const MAX_RENEW: Duration = Duration::from_secs(3600); - -impl JWKCache { - pub async fn check_jwt( - &self, - endpoint: EndpointIdInt, + async fn check_jwt( + self: &Arc, jwt: String, + client: &reqwest::Client, + fetch: &F, ) -> Result<(), anyhow::Error> { // JWT compact form is defined to be // || . || || . || @@ -148,42 +155,24 @@ impl JWKCache { ensure!(header.typ == "JWT"); let kid = header.kid.context("missing key id")?; - // try with just a read lock first - let entry = self.map.get(&endpoint).as_deref().map(Arc::clone); - let entry = match entry { - Some(entry) => entry, - None => { - // acquire a write lock after to insert. - let entry = self.map.entry(endpoint).or_insert_with(|| { - Arc::new(JWKCacheEntryLock { - cached: ArcSwapOption::empty(), - lookup: tokio::sync::Semaphore::new(1), - }) - }); - Arc::clone(&*entry) - } - }; - - let fetch = FetchAuthFromCplane(endpoint); - // check if the cached JWKs need updating. let mut guard = { let now = Instant::now(); - let guard = entry.cached.load_full(); + let guard = self.cached.load_full(); if let Some(cached) = guard { let last_update = now.duration_since(cached.last_retrieved); if last_update > MAX_RENEW { - let permit = entry.acquire_permit().await; + let permit = self.acquire_permit().await; // it's been too long since we checked the keys. wait for them to update. - entry.renew_jwks(permit, &self.client, &fetch).await? + self.renew_jwks(permit, client, fetch).await? } else if last_update > MIN_RENEW { // every 5 minutes we should spawn a job to eagerly update the token. - if let Some(permit) = entry.try_acquire_permit() { - let entry = entry.clone(); - let client = self.client.clone(); + if let Some(permit) = self.try_acquire_permit() { + let entry = self.clone(); + let client = client.clone(); let fetch = fetch.clone(); tokio::spawn(async move { if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await { @@ -197,8 +186,8 @@ impl JWKCache { cached } } else { - let permit = entry.acquire_permit().await; - entry.renew_jwks(permit, &self.client, &fetch).await? + let permit = self.acquire_permit().await; + self.renew_jwks(permit, client, fetch).await? } }; @@ -206,7 +195,7 @@ impl JWKCache { let jwk = loop { let jwk = guard .key_sets - .iter() + .values() .flat_map(|jwks| &jwks.keys) .find(|jwk| { jwk.prm.kid.as_deref() == Some(kid) && jwk.key.is_supported(&header.alg) @@ -215,8 +204,8 @@ impl JWKCache { match jwk { Some(jwk) => break jwk, None if guard.last_retrieved.elapsed() > MIN_RENEW => { - let permit = entry.acquire_permit().await; - guard = entry.renew_jwks(permit, &self.client, &fetch).await?; + let permit = self.acquire_permit().await; + guard = self.renew_jwks(permit, client, fetch).await?; } _ => { bail!("jwk not found"); @@ -242,6 +231,31 @@ impl JWKCache { } } +const MIN_RENEW: Duration = Duration::from_secs(300); +const MAX_RENEW: Duration = Duration::from_secs(3600); + +impl JWKCache { + pub async fn check_jwt( + &self, + endpoint: EndpointIdInt, + jwt: String, + ) -> Result<(), anyhow::Error> { + // try with just a read lock first + let entry = self.map.get(&endpoint).as_deref().map(Arc::clone); + let entry = match entry { + Some(entry) => entry, + None => { + // acquire a write lock after to insert. + let entry = self.map.entry(endpoint).or_default(); + Arc::clone(&*entry) + } + }; + + let fetch = FetchAuthFromCplane(endpoint); + entry.check_jwt(jwt, &self.client, &fetch).await + } +} + fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> { use ecdsa::Signature; use signature::Verifier; @@ -311,3 +325,146 @@ impl Drop for AttachedPermit<'_> { self.0.add_permits(1); } } + +#[cfg(test)] +mod tests { + use super::*; + + use std::{future::IntoFuture, net::SocketAddr}; + + use anyhow::Error; + use async_trait::async_trait; + use bytes::Bytes; + use http::Response; + use http_body_util::Full; + use hyper1::service::service_fn; + use hyper_util::rt::TokioIo; + use rand::rngs::OsRng; + use tokio::net::TcpListener; + + fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) { + let sk = p256::SecretKey::random(&mut OsRng); + let pk = sk.public_key().into(); + let jwk = jose_jwk::Jwk { + key: jose_jwk::Key::Ec(pk), + prm: jose_jwk::Parameters { + kid: Some(kid), + ..Default::default() + }, + }; + (sk, jwk) + } + fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) { + let sk = rsa::RsaPrivateKey::new(&mut OsRng, 1024).unwrap(); + let pk = sk.to_public_key().into(); + let jwk = jose_jwk::Jwk { + key: jose_jwk::Key::Rsa(pk), + prm: jose_jwk::Parameters { + kid: Some(kid), + ..Default::default() + }, + }; + (sk, jwk) + } + + #[tokio::test] + async fn renew() { + let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); + + // let reports = Arc::new(Mutex::new(vec![])); + // let reports2 = reports.clone(); + + let server = hyper1::server::conn::http1::Builder::new(); + // let server = hyper1::server::Server::from_tcp(listener) + // .unwrap() + // .serve(make_service_fn(move |_| { + // // let reports = reports.clone(); + // async move { + // Ok::<_, Error>(service_fn(move |req| { + // // let reports = reports.clone(); + // async move { + // // let bytes = hyper::body::to_bytes(req.into_body()).await?; + // // let events: EventChunk<'static, Event> = + // // serde_json::from_slice(&bytes)?; + // // reports.lock().unwrap().push(events); + // Ok::<_, Error>(Response::new(Body::from(vec![]))) + // } + // })) + // } + // })); + + let (rs1, jwk1) = new_rsa_jwk("1".into()); + let (rs2, jwk2) = new_rsa_jwk("2".into()); + let (ec1, jwk3) = new_ec_jwk("3".into()); + let (ec2, jwk4) = new_ec_jwk("4".into()); + + let foo = jose_jwk::JwkSet { + keys: vec![jwk1, jwk3], + }; + let bar = jose_jwk::JwkSet { + keys: vec![jwk2, jwk4], + }; + + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + loop { + let (s, _) = listener.accept().await.unwrap(); + let foo = foo.clone(); + let bar = bar.clone(); + tokio::spawn( + server + .serve_connection( + TokioIo::new(s), + service_fn(move |req| { + let foo = foo.clone(); + let bar = bar.clone(); + async move { + let jwks = match req.uri().path() { + "/foo" => &foo, + "/bar" => &bar, + _ => { + return Ok::<_, Error>( + Response::builder() + .status(404) + .body(Full::new(Bytes::new())) + .unwrap(), + ) + } + }; + + Ok::<_, Error>(Response::new(Full::new(Bytes::from( + serde_json::to_vec(jwks).unwrap(), + )))) + } + }), + ) + .into_future(), + ); + } + }); + + let client = reqwest::Client::new(); + + #[derive(Clone)] + struct Fetch(SocketAddr); + + #[async_trait] + impl FetchAuthRules for Fetch { + async fn fetch_auth_rules(&self) -> anyhow::Result { + Ok(AuthRules { + jwks_urls: vec![ + format!("http://{}/foo", self.0).parse().unwrap(), + format!("http://{}/bar", self.0).parse().unwrap(), + ], + }) + } + } + + let jwk_cache = Arc::new(JWKCacheEntryLock::default()); + let permit = jwk_cache.acquire_permit().await; + let entry = jwk_cache + .renew_jwks(permit, &client, &Fetch(addr)) + .await + .unwrap(); + } +}