mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-16 09:52:54 +00:00
flesh out JWKs cache
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use arc_swap::{access::Access, ArcSwap, Guard};
|
||||
use dashmap::{mapref::entry::Entry, DashMap};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use dashmap::DashMap;
|
||||
use hmac::digest::generic_array::GenericArray;
|
||||
// use jose_jwa::S;
|
||||
use jose_jwk::crypto::KeyInfo;
|
||||
@@ -18,8 +18,8 @@ pub struct JWKCache {
|
||||
}
|
||||
|
||||
pub struct JWKCacheEntryLock {
|
||||
cached: ArcSwap<JWKCacheEntry>,
|
||||
lookup: tokio::sync::Mutex<()>,
|
||||
cached: ArcSwapOption<JWKCacheEntry>,
|
||||
lookup: tokio::sync::Semaphore,
|
||||
}
|
||||
|
||||
pub struct JWKCacheEntry {
|
||||
@@ -33,34 +33,92 @@ pub struct JWKCacheEntry {
|
||||
key_sets: Vec<jose_jwk::JwkSet>,
|
||||
}
|
||||
|
||||
/// How to get the JWT auth rules
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait FetchAuthRules {
|
||||
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FetchAuthFromCplane(EndpointIdInt);
|
||||
|
||||
impl FetchAuthRules for FetchAuthFromCplane {
|
||||
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
|
||||
bail!("not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthRules {
|
||||
jwks_urls: Vec<url::Url>,
|
||||
}
|
||||
|
||||
impl JWKCacheEntryLock {
|
||||
async fn renew_jwks(
|
||||
async fn acquire_permit(&self) -> DetachedPermit {
|
||||
match self.lookup.acquire().await {
|
||||
Ok(permit) => {
|
||||
permit.forget();
|
||||
DetachedPermit
|
||||
}
|
||||
Err(_) => panic!("semaphore should not be closed"),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_acquire_permit(&self) -> Option<DetachedPermit> {
|
||||
match self.lookup.try_acquire() {
|
||||
Ok(permit) => {
|
||||
permit.forget();
|
||||
Some(DetachedPermit)
|
||||
}
|
||||
Err(tokio::sync::TryAcquireError::NoPermits) => None,
|
||||
Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn renew_jwks<F: FetchAuthRules>(
|
||||
&self,
|
||||
permit: DetachedPermit,
|
||||
client: &reqwest::Client,
|
||||
endpoint: EndpointIdInt,
|
||||
) -> Arc<JWKCacheEntry> {
|
||||
let _guard = self.lookup.lock().await;
|
||||
auth_rules: &F,
|
||||
) -> anyhow::Result<Arc<JWKCacheEntry>> {
|
||||
let _permit = permit.attach(&self.lookup);
|
||||
|
||||
// double check that no one beat us to updating the cache.
|
||||
let now = Instant::now();
|
||||
let guard = self.cached.load_full();
|
||||
let last_update = now.duration_since(guard.last_retrieved);
|
||||
if last_update < Duration::from_secs(300) {
|
||||
return guard;
|
||||
if let Some(cached) = guard {
|
||||
let last_update = now.duration_since(cached.last_retrieved);
|
||||
if last_update < Duration::from_secs(300) {
|
||||
return Ok(cached);
|
||||
}
|
||||
}
|
||||
|
||||
todo!("refetch jwks from cplane");
|
||||
|
||||
// fetch from cplane
|
||||
// fetch jwks
|
||||
// entry.last_retrieved = now;
|
||||
let rules = auth_rules.fetch_auth_rules().await?;
|
||||
let mut key_sets = vec![];
|
||||
for url in rules.jwks_urls {
|
||||
match client
|
||||
.get(url.clone())
|
||||
.send()
|
||||
.await
|
||||
.and_then(|r| r.error_for_status())
|
||||
{
|
||||
Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"),
|
||||
Ok(r) => match r.json::<jose_jwk::JwkSet>().await {
|
||||
Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"),
|
||||
Ok(jwks) => key_sets.push(jwks),
|
||||
},
|
||||
}
|
||||
}
|
||||
if key_sets.is_empty() {
|
||||
bail!("no JWKs found")
|
||||
}
|
||||
|
||||
let x = Arc::new(JWKCacheEntry {
|
||||
last_retrieved: now,
|
||||
key_sets: vec![],
|
||||
key_sets,
|
||||
});
|
||||
self.cached.swap(x.clone());
|
||||
x
|
||||
self.cached.swap(Some(x.clone()));
|
||||
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,35 +150,57 @@ impl JWKCache {
|
||||
ensure!(header.typ == "JWT");
|
||||
let kid = header.kid.context("missing key id")?;
|
||||
|
||||
let entry = match self.map.get(&endpoint) {
|
||||
Some(entry) => Arc::clone(&*entry),
|
||||
None => match self.map.entry(endpoint) {
|
||||
Entry::Occupied(entry) => Arc::clone(&*entry.into_ref()),
|
||||
Entry::Vacant(_) => todo!("fetch jwks from cplane"),
|
||||
},
|
||||
// try with just a read lock first
|
||||
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
|
||||
let entry = match entry {
|
||||
Some(entry) => entry,
|
||||
None => {
|
||||
// acquire a write lock after to insert.
|
||||
let entry = self.map.entry(endpoint).or_insert_with(|| {
|
||||
Arc::new(JWKCacheEntryLock {
|
||||
cached: ArcSwapOption::empty(),
|
||||
lookup: tokio::sync::Semaphore::new(1),
|
||||
})
|
||||
});
|
||||
Arc::clone(&*entry)
|
||||
}
|
||||
};
|
||||
|
||||
let fetch = FetchAuthFromCplane(endpoint);
|
||||
|
||||
// check if the cached JWKs need updating.
|
||||
let mut guard = {
|
||||
let now = Instant::now();
|
||||
let guard = entry.cached.load();
|
||||
let last_update = now.duration_since(guard.last_retrieved);
|
||||
let guard = entry.cached.load_full();
|
||||
|
||||
if last_update > MAX_RENEW {
|
||||
// it's been too long since we checked the keys. wait for them to update.
|
||||
entry.renew_jwks(&self.client, endpoint).await
|
||||
} else if last_update > MIN_RENEW {
|
||||
// every 5 minutes we should spawn a job to eagerly update the token.
|
||||
if let Some(cached) = guard {
|
||||
let last_update = now.duration_since(cached.last_retrieved);
|
||||
|
||||
let entry = entry.clone();
|
||||
let client = self.client.clone();
|
||||
tokio::spawn(async move {
|
||||
entry.renew_jwks(&client, endpoint).await;
|
||||
});
|
||||
if last_update > MAX_RENEW {
|
||||
let permit = entry.acquire_permit().await;
|
||||
|
||||
Guard::into_inner(guard)
|
||||
// it's been too long since we checked the keys. wait for them to update.
|
||||
entry.renew_jwks(permit, &self.client, &fetch).await?
|
||||
} else if last_update > MIN_RENEW {
|
||||
// every 5 minutes we should spawn a job to eagerly update the token.
|
||||
if let Some(permit) = entry.try_acquire_permit() {
|
||||
let entry = entry.clone();
|
||||
let client = self.client.clone();
|
||||
let fetch = fetch.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await {
|
||||
tracing::warn!(error=?e, "could not fetch JWKs in background job");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
cached
|
||||
} else {
|
||||
cached
|
||||
}
|
||||
} else {
|
||||
Guard::into_inner(guard)
|
||||
let permit = entry.acquire_permit().await;
|
||||
entry.renew_jwks(permit, &self.client, &fetch).await?
|
||||
}
|
||||
};
|
||||
|
||||
@@ -137,7 +217,8 @@ impl JWKCache {
|
||||
match jwk {
|
||||
Some(jwk) => break jwk,
|
||||
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
|
||||
guard = entry.renew_jwks(&self.client, endpoint).await
|
||||
let permit = entry.acquire_permit().await;
|
||||
guard = entry.renew_jwks(permit, &self.client, &fetch).await?;
|
||||
}
|
||||
_ => {
|
||||
bail!("jwk not found");
|
||||
@@ -290,3 +371,20 @@ struct JWTHeader<'a> {
|
||||
/// key id, must be provided for our usecase
|
||||
kid: Option<&'a str>,
|
||||
}
|
||||
|
||||
// semaphore trickery
|
||||
struct DetachedPermit;
|
||||
|
||||
impl DetachedPermit {
|
||||
fn attach(self, semaphore: &tokio::sync::Semaphore) -> AttachedPermit {
|
||||
AttachedPermit(semaphore)
|
||||
}
|
||||
}
|
||||
|
||||
struct AttachedPermit<'a>(&'a tokio::sync::Semaphore);
|
||||
|
||||
impl Drop for AttachedPermit<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.0.add_permits(1);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user