flesh out JWKs cache

This commit is contained in:
Conrad Ludgate
2024-08-07 12:44:51 +01:00
parent 262378e561
commit 103f34e954

View File

@@ -1,8 +1,8 @@
use std::{sync::Arc, time::Duration};
use anyhow::{bail, ensure, Context};
use arc_swap::{access::Access, ArcSwap, Guard};
use dashmap::{mapref::entry::Entry, DashMap};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use hmac::digest::generic_array::GenericArray;
// use jose_jwa::S;
use jose_jwk::crypto::KeyInfo;
@@ -18,8 +18,8 @@ pub struct JWKCache {
}
/// Cache slot stored in the endpoint map: the most recently stored
/// [`JWKCacheEntry`] together with the primitive that `renew_jwks` takes
/// before refreshing it.
pub struct JWKCacheEntryLock {
cached: ArcSwap<JWKCacheEntry>,
lookup: tokio::sync::Mutex<()>,
cached: ArcSwapOption<JWKCacheEntry>,
lookup: tokio::sync::Semaphore,
}
pub struct JWKCacheEntry {
@@ -33,34 +33,92 @@ pub struct JWKCacheEntry {
key_sets: Vec<jose_jwk::JwkSet>,
}
/// How to get the JWT auth rules.
///
/// An implementor asynchronously produces the [`AuthRules`] or an error.
// `async fn` in a public trait cannot express auto-trait bounds (e.g. `Send`)
// on the returned future, which is what the `async_fn_in_trait` lint warns
// about; the lint is explicitly silenced for this trait.
#[allow(async_fn_in_trait)]
pub trait FetchAuthRules {
/// Fetches the current auth rules.
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules>;
}
/// [`FetchAuthRules`] implementation that wraps an endpoint id.
///
/// NOTE(review): the name suggests the rules are meant to come from the
/// control plane ("cplane") for that endpoint — confirm once implemented.
#[derive(Clone)]
struct FetchAuthFromCplane(EndpointIdInt);
impl FetchAuthRules for FetchAuthFromCplane {
// Placeholder: always fails with "not yet implemented"; the wrapped
// endpoint id is not read yet.
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
bail!("not yet implemented")
}
}
/// Auth rules produced by a [`FetchAuthRules`] implementation.
pub struct AuthRules {
/// URLs that JWK sets are requested from.
jwks_urls: Vec<url::Url>,
}
impl JWKCacheEntryLock {
async fn renew_jwks(
/// Waits until the `lookup` semaphore hands out a permit, then detaches it.
///
/// The permit is `forget`-ed, so it is NOT returned to the semaphore
/// automatically; the returned [`DetachedPermit`] has to be re-attached
/// (see `DetachedPermit::attach`) for the permit to be given back.
///
/// # Panics
///
/// Panics if the semaphore has been closed.
async fn acquire_permit(&self) -> DetachedPermit {
    let Ok(held) = self.lookup.acquire().await else {
        panic!("semaphore should not be closed");
    };
    // detach the permit from the semaphore borrow so it can be moved freely.
    held.forget();
    DetachedPermit
}
/// Non-blocking variant of `acquire_permit`.
///
/// Returns `None` when no permit is currently available. On success the
/// permit is `forget`-ed and represented by the returned [`DetachedPermit`],
/// which must be re-attached for the permit to be given back.
///
/// # Panics
///
/// Panics if the semaphore has been closed.
fn try_acquire_permit(&self) -> Option<DetachedPermit> {
    use tokio::sync::TryAcquireError;

    let held = match self.lookup.try_acquire() {
        Ok(held) => held,
        Err(TryAcquireError::NoPermits) => return None,
        Err(TryAcquireError::Closed) => panic!("semaphore should not be closed"),
    };
    // detach the permit from the semaphore borrow so it can be moved freely.
    held.forget();
    Some(DetachedPermit)
}
/// Refreshes the cached JWK sets for this entry.
///
/// Takes a previously acquired permit, re-checks whether the cached entry was
/// retrieved less than 300 seconds ago (returning it unchanged if so), and
/// otherwise fetches the auth rules and requests every JWKS URL they list.
/// URLs that cannot be fetched or decoded are logged and skipped.
///
/// # Errors
///
/// Fails if the auth rules cannot be fetched, or if no JWK set at all could
/// be loaded.
async fn renew_jwks<F: FetchAuthRules>(
&self,
permit: DetachedPermit,
client: &reqwest::Client,
endpoint: EndpointIdInt,
) -> Arc<JWKCacheEntry> {
let _guard = self.lookup.lock().await;
auth_rules: &F,
) -> anyhow::Result<Arc<JWKCacheEntry>> {
// re-attach the permit so it is handed back to the semaphore when this
// function returns, on every exit path.
let _permit = permit.attach(&self.lookup);
// double check that no one beat us to updating the cache.
let now = Instant::now();
let guard = self.cached.load_full();
let last_update = now.duration_since(guard.last_retrieved);
if last_update < Duration::from_secs(300) {
return guard;
if let Some(cached) = guard {
let last_update = now.duration_since(cached.last_retrieved);
if last_update < Duration::from_secs(300) {
return Ok(cached);
}
}
todo!("refetch jwks from cplane");
// fetch from cplane
// fetch jwks
// entry.last_retrieved = now;
// a failure to fetch the auth rules aborts the whole refresh.
let rules = auth_rules.fetch_auth_rules().await?;
let mut key_sets = vec![];
// best effort per URL: a failing URL is logged and the remaining ones are
// still tried.
for url in rules.jwks_urls {
match client
.get(url.clone())
.send()
.await
.and_then(|r| r.error_for_status())
{
Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"),
Ok(r) => match r.json::<jose_jwk::JwkSet>().await {
Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"),
Ok(jwks) => key_sets.push(jwks),
},
}
}
// bail out before the swap so the cache is never replaced by an empty key list.
if key_sets.is_empty() {
bail!("no JWKs found")
}
let x = Arc::new(JWKCacheEntry {
last_retrieved: now,
key_sets: vec![],
key_sets,
});
self.cached.swap(x.clone());
x
self.cached.swap(Some(x.clone()));
Ok(x)
}
}
@@ -92,35 +150,57 @@ impl JWKCache {
ensure!(header.typ == "JWT");
let kid = header.kid.context("missing key id")?;
let entry = match self.map.get(&endpoint) {
Some(entry) => Arc::clone(&*entry),
None => match self.map.entry(endpoint) {
Entry::Occupied(entry) => Arc::clone(&*entry.into_ref()),
Entry::Vacant(_) => todo!("fetch jwks from cplane"),
},
// try with just a read lock first
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
let entry = match entry {
Some(entry) => entry,
None => {
// acquire a write lock after to insert.
let entry = self.map.entry(endpoint).or_insert_with(|| {
Arc::new(JWKCacheEntryLock {
cached: ArcSwapOption::empty(),
lookup: tokio::sync::Semaphore::new(1),
})
});
Arc::clone(&*entry)
}
};
let fetch = FetchAuthFromCplane(endpoint);
// check if the cached JWKs need updating.
let mut guard = {
let now = Instant::now();
let guard = entry.cached.load();
let last_update = now.duration_since(guard.last_retrieved);
let guard = entry.cached.load_full();
if last_update > MAX_RENEW {
// it's been too long since we checked the keys. wait for them to update.
entry.renew_jwks(&self.client, endpoint).await
} else if last_update > MIN_RENEW {
// every 5 minutes we should spawn a job to eagerly update the token.
if let Some(cached) = guard {
let last_update = now.duration_since(cached.last_retrieved);
let entry = entry.clone();
let client = self.client.clone();
tokio::spawn(async move {
entry.renew_jwks(&client, endpoint).await;
});
if last_update > MAX_RENEW {
let permit = entry.acquire_permit().await;
Guard::into_inner(guard)
// it's been too long since we checked the keys. wait for them to update.
entry.renew_jwks(permit, &self.client, &fetch).await?
} else if last_update > MIN_RENEW {
// every 5 minutes we should spawn a job to eagerly update the token.
if let Some(permit) = entry.try_acquire_permit() {
let entry = entry.clone();
let client = self.client.clone();
let fetch = fetch.clone();
tokio::spawn(async move {
if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await {
tracing::warn!(error=?e, "could not fetch JWKs in background job");
}
});
}
cached
} else {
cached
}
} else {
Guard::into_inner(guard)
let permit = entry.acquire_permit().await;
entry.renew_jwks(permit, &self.client, &fetch).await?
}
};
@@ -137,7 +217,8 @@ impl JWKCache {
match jwk {
Some(jwk) => break jwk,
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
guard = entry.renew_jwks(&self.client, endpoint).await
let permit = entry.acquire_permit().await;
guard = entry.renew_jwks(permit, &self.client, &fetch).await?;
}
_ => {
bail!("jwk not found");
@@ -290,3 +371,20 @@ struct JWTHeader<'a> {
/// key id, must be provided for our usecase
kid: Option<&'a str>,
}
// semaphore trickery
/// Token standing in for one semaphore permit that was acquired and then
/// `forget`-ed (see `acquire_permit` / `try_acquire_permit`).
///
/// The token carries no reference to the semaphore, so it can be moved into
/// other futures or tasks. Dropping it without calling [`DetachedPermit::attach`]
/// never gives the permit back, hence the `must_use`.
#[must_use = "dropping a DetachedPermit without attaching it leaks the semaphore permit"]
struct DetachedPermit;

impl DetachedPermit {
    /// Binds the token to `semaphore`, producing a guard that returns one
    /// permit to it when dropped.
    ///
    /// The caller is responsible for passing the same semaphore the permit
    /// was originally taken from; nothing here checks that.
    fn attach(self, semaphore: &tokio::sync::Semaphore) -> AttachedPermit<'_> {
        AttachedPermit(semaphore)
    }
}
/// Guard that gives exactly one permit back to the wrapped semaphore when it
/// is dropped.
///
/// It must be bound to a named variable for as long as the permit should be
/// held: an unbound guard is dropped (and the permit released) immediately.
#[must_use = "the permit is returned to the semaphore as soon as this guard is dropped"]
struct AttachedPermit<'a>(&'a tokio::sync::Semaphore);

impl Drop for AttachedPermit<'_> {
    fn drop(&mut self) {
        // restore the permit that was `forget`-ed when it was acquired.
        self.0.add_permits(1);
    }
}