From 103f34e9545e1b67db0ccaa00307f18aab63ae94 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 7 Aug 2024 12:44:51 +0100 Subject: [PATCH] flesh out JWKs cache --- proxy/src/auth/backend/jwt.rs | 178 ++++++++++++++++++++++++++-------- 1 file changed, 138 insertions(+), 40 deletions(-) diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 20a7ac4cfc..118677f70e 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -1,8 +1,8 @@ use std::{sync::Arc, time::Duration}; use anyhow::{bail, ensure, Context}; -use arc_swap::{access::Access, ArcSwap, Guard}; -use dashmap::{mapref::entry::Entry, DashMap}; +use arc_swap::ArcSwapOption; +use dashmap::DashMap; use hmac::digest::generic_array::GenericArray; // use jose_jwa::S; use jose_jwk::crypto::KeyInfo; @@ -18,8 +18,8 @@ pub struct JWKCache { } pub struct JWKCacheEntryLock { - cached: ArcSwap, - lookup: tokio::sync::Mutex<()>, + cached: ArcSwapOption, + lookup: tokio::sync::Semaphore, } pub struct JWKCacheEntry { @@ -33,34 +33,92 @@ pub struct JWKCacheEntry { key_sets: Vec, } +/// How to get the JWT auth rules +#[allow(async_fn_in_trait)] +pub trait FetchAuthRules { + async fn fetch_auth_rules(&self) -> anyhow::Result; +} + +#[derive(Clone)] +struct FetchAuthFromCplane(EndpointIdInt); + +impl FetchAuthRules for FetchAuthFromCplane { + async fn fetch_auth_rules(&self) -> anyhow::Result { + bail!("not yet implemented") + } +} + +pub struct AuthRules { + jwks_urls: Vec, +} + impl JWKCacheEntryLock { - async fn renew_jwks( + async fn acquire_permit(&self) -> DetachedPermit { + match self.lookup.acquire().await { + Ok(permit) => { + permit.forget(); + DetachedPermit + } + Err(_) => panic!("semaphore should not be closed"), + } + } + + fn try_acquire_permit(&self) -> Option { + match self.lookup.try_acquire() { + Ok(permit) => { + permit.forget(); + Some(DetachedPermit) + } + Err(tokio::sync::TryAcquireError::NoPermits) => None, + Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"), + } + } + + async fn renew_jwks( &self, + permit: DetachedPermit, client: &reqwest::Client, - endpoint: EndpointIdInt, - ) -> Arc { - let _guard = self.lookup.lock().await; + auth_rules: &F, + ) -> anyhow::Result> { + let _permit = permit.attach(&self.lookup); // double check that no one beat us to updating the cache. let now = Instant::now(); let guard = self.cached.load_full(); - let last_update = now.duration_since(guard.last_retrieved); - if last_update < Duration::from_secs(300) { - return guard; + if let Some(cached) = guard { + let last_update = now.duration_since(cached.last_retrieved); + if last_update < Duration::from_secs(300) { + return Ok(cached); + } } - todo!("refetch jwks from cplane"); - - // fetch from cplane - // fetch jwks - // entry.last_retrieved = now; + let rules = auth_rules.fetch_auth_rules().await?; + let mut key_sets = vec![]; + for url in rules.jwks_urls { + match client + .get(url.clone()) + .send() + .await + .and_then(|r| r.error_for_status()) + { + Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"), + Ok(r) => match r.json::().await { + Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"), + Ok(jwks) => key_sets.push(jwks), + }, + } + } + if key_sets.is_empty() { + bail!("no JWKs found") + } let x = Arc::new(JWKCacheEntry { last_retrieved: now, - key_sets: vec![], + key_sets, }); - self.cached.swap(x.clone()); - x + self.cached.swap(Some(x.clone())); + + Ok(x) } } @@ -92,35 +150,57 @@ impl JWKCache { ensure!(header.typ == "JWT"); let kid = header.kid.context("missing key id")?; - let entry = match self.map.get(&endpoint) { - Some(entry) => Arc::clone(&*entry), - None => match self.map.entry(endpoint) { - Entry::Occupied(entry) => Arc::clone(&*entry.into_ref()), - Entry::Vacant(_) => todo!("fetch jwks from cplane"), - }, + // try with just a read lock first + let entry = self.map.get(&endpoint).as_deref().map(Arc::clone); + let entry = match entry { + Some(entry) => entry, + None => { + // acquire a write lock after to insert. + let entry = self.map.entry(endpoint).or_insert_with(|| { + Arc::new(JWKCacheEntryLock { + cached: ArcSwapOption::empty(), + lookup: tokio::sync::Semaphore::new(1), + }) + }); + Arc::clone(&*entry) + } }; + let fetch = FetchAuthFromCplane(endpoint); + // check if the cached JWKs need updating. let mut guard = { let now = Instant::now(); - let guard = entry.cached.load(); - let last_update = now.duration_since(guard.last_retrieved); + let guard = entry.cached.load_full(); - if last_update > MAX_RENEW { - // it's been too long since we checked the keys. wait for them to update. - entry.renew_jwks(&self.client, endpoint).await - } else if last_update > MIN_RENEW { - // every 5 minutes we should spawn a job to eagerly update the token. + if let Some(cached) = guard { + let last_update = now.duration_since(cached.last_retrieved); - let entry = entry.clone(); - let client = self.client.clone(); - tokio::spawn(async move { - entry.renew_jwks(&client, endpoint).await; - }); + if last_update > MAX_RENEW { + let permit = entry.acquire_permit().await; - Guard::into_inner(guard) + // it's been too long since we checked the keys. wait for them to update. + entry.renew_jwks(permit, &self.client, &fetch).await? + } else if last_update > MIN_RENEW { + // every 5 minutes we should spawn a job to eagerly update the token. + if let Some(permit) = entry.try_acquire_permit() { + let entry = entry.clone(); + let client = self.client.clone(); + let fetch = fetch.clone(); + tokio::spawn(async move { + if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await { + tracing::warn!(error=?e, "could not fetch JWKs in background job"); + } + }); + } + + cached + } else { + cached + } } else { - Guard::into_inner(guard) + let permit = entry.acquire_permit().await; + entry.renew_jwks(permit, &self.client, &fetch).await? } }; @@ -137,7 +217,8 @@ impl JWKCache { match jwk { Some(jwk) => break jwk, None if guard.last_retrieved.elapsed() > MIN_RENEW => { - guard = entry.renew_jwks(&self.client, endpoint).await + let permit = entry.acquire_permit().await; + guard = entry.renew_jwks(permit, &self.client, &fetch).await?; } _ => { bail!("jwk not found"); @@ -290,3 +371,20 @@ struct JWTHeader<'a> { /// key id, must be provided for our usecase kid: Option<&'a str>, } + +// semaphore trickery +struct DetachedPermit; + +impl DetachedPermit { + fn attach(self, semaphore: &tokio::sync::Semaphore) -> AttachedPermit { + AttachedPermit(semaphore) + } +} + +struct AttachedPermit<'a>(&'a tokio::sync::Semaphore); + +impl Drop for AttachedPermit<'_> { + fn drop(&mut self) { + self.0.add_permits(1); + } +}