mock tests for jwk renewal

This commit is contained in:
Conrad Ludgate
2024-08-09 15:32:28 +01:00
parent 0d895ba002
commit db4085fe22

View File

@@ -2,6 +2,7 @@ use std::{sync::Arc, time::Duration};
use anyhow::{bail, ensure, Context};
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use tokio::time::Instant;
@@ -20,6 +21,15 @@ pub struct JWKCacheEntryLock {
lookup: tokio::sync::Semaphore,
}
impl Default for JWKCacheEntryLock {
fn default() -> Self {
JWKCacheEntryLock {
cached: ArcSwapOption::empty(),
lookup: tokio::sync::Semaphore::new(1),
}
}
}
pub struct JWKCacheEntry {
/// Should refetch at least every hour to verify when old keys have been removed.
/// Should refetch when new key IDs are seen only every 5 minutes or so
@@ -28,18 +38,19 @@ pub struct JWKCacheEntry {
// /// jwks urls
// urls: Vec<url::Url>,
/// cplane will return multiple JWKs urls that we need to scrape.
key_sets: Vec<jose_jwk::JwkSet>,
key_sets: ahash::HashMap<url::Url, jose_jwk::JwkSet>,
}
/// How to get the JWT auth rules
#[allow(async_fn_in_trait)]
pub trait FetchAuthRules {
#[async_trait]
pub trait FetchAuthRules: Clone + Send + Sync + 'static {
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules>;
}
#[derive(Clone)]
struct FetchAuthFromCplane(EndpointIdInt);
#[async_trait]
impl FetchAuthRules for FetchAuthFromCplane {
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
bail!("not yet implemented")
@@ -91,24 +102,24 @@ impl JWKCacheEntryLock {
}
let rules = auth_rules.fetch_auth_rules().await?;
let mut key_sets = vec![];
let mut key_sets = ahash::HashMap::with_capacity_and_hasher(
rules.jwks_urls.len(),
ahash::RandomState::new(),
);
for url in rules.jwks_urls {
match client
.get(url.clone())
.send()
.await
.and_then(|r| r.error_for_status())
{
let req = client.get(url.clone());
match req.send().await.and_then(|r| r.error_for_status()) {
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
// I expect these failures would be quite sparse.
Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"),
Ok(r) => match r.json::<jose_jwk::JwkSet>().await {
Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"),
Ok(jwks) => key_sets.push(jwks),
Ok(jwks) => {
key_sets.insert(url, jwks);
}
},
}
}
if key_sets.is_empty() {
bail!("no JWKs found")
}
let x = Arc::new(JWKCacheEntry {
last_retrieved: now,
@@ -118,16 +129,12 @@ impl JWKCacheEntryLock {
Ok(x)
}
}
const MIN_RENEW: Duration = Duration::from_secs(300);
const MAX_RENEW: Duration = Duration::from_secs(3600);
impl JWKCache {
pub async fn check_jwt(
&self,
endpoint: EndpointIdInt,
async fn check_jwt<F: FetchAuthRules>(
self: &Arc<Self>,
jwt: String,
client: &reqwest::Client,
fetch: &F,
) -> Result<(), anyhow::Error> {
// JWT compact form is defined to be
// <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
@@ -148,42 +155,24 @@ impl JWKCache {
ensure!(header.typ == "JWT");
let kid = header.kid.context("missing key id")?;
// try with just a read lock first
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
let entry = match entry {
Some(entry) => entry,
None => {
// acquire a write lock after to insert.
let entry = self.map.entry(endpoint).or_insert_with(|| {
Arc::new(JWKCacheEntryLock {
cached: ArcSwapOption::empty(),
lookup: tokio::sync::Semaphore::new(1),
})
});
Arc::clone(&*entry)
}
};
let fetch = FetchAuthFromCplane(endpoint);
// check if the cached JWKs need updating.
let mut guard = {
let now = Instant::now();
let guard = entry.cached.load_full();
let guard = self.cached.load_full();
if let Some(cached) = guard {
let last_update = now.duration_since(cached.last_retrieved);
if last_update > MAX_RENEW {
let permit = entry.acquire_permit().await;
let permit = self.acquire_permit().await;
// it's been too long since we checked the keys. wait for them to update.
entry.renew_jwks(permit, &self.client, &fetch).await?
self.renew_jwks(permit, client, fetch).await?
} else if last_update > MIN_RENEW {
// every 5 minutes we should spawn a job to eagerly update the token.
if let Some(permit) = entry.try_acquire_permit() {
let entry = entry.clone();
let client = self.client.clone();
if let Some(permit) = self.try_acquire_permit() {
let entry = self.clone();
let client = client.clone();
let fetch = fetch.clone();
tokio::spawn(async move {
if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await {
@@ -197,8 +186,8 @@ impl JWKCache {
cached
}
} else {
let permit = entry.acquire_permit().await;
entry.renew_jwks(permit, &self.client, &fetch).await?
let permit = self.acquire_permit().await;
self.renew_jwks(permit, client, fetch).await?
}
};
@@ -206,7 +195,7 @@ impl JWKCache {
let jwk = loop {
let jwk = guard
.key_sets
.iter()
.values()
.flat_map(|jwks| &jwks.keys)
.find(|jwk| {
jwk.prm.kid.as_deref() == Some(kid) && jwk.key.is_supported(&header.alg)
@@ -215,8 +204,8 @@ impl JWKCache {
match jwk {
Some(jwk) => break jwk,
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
let permit = entry.acquire_permit().await;
guard = entry.renew_jwks(permit, &self.client, &fetch).await?;
let permit = self.acquire_permit().await;
guard = self.renew_jwks(permit, client, fetch).await?;
}
_ => {
bail!("jwk not found");
@@ -242,6 +231,31 @@ impl JWKCache {
}
}
const MIN_RENEW: Duration = Duration::from_secs(300);
const MAX_RENEW: Duration = Duration::from_secs(3600);
impl JWKCache {
pub async fn check_jwt(
&self,
endpoint: EndpointIdInt,
jwt: String,
) -> Result<(), anyhow::Error> {
// try with just a read lock first
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
let entry = match entry {
Some(entry) => entry,
None => {
// acquire a write lock after to insert.
let entry = self.map.entry(endpoint).or_default();
Arc::clone(&*entry)
}
};
let fetch = FetchAuthFromCplane(endpoint);
entry.check_jwt(jwt, &self.client, &fetch).await
}
}
fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
use ecdsa::Signature;
use signature::Verifier;
@@ -311,3 +325,146 @@ impl Drop for AttachedPermit<'_> {
self.0.add_permits(1);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{future::IntoFuture, net::SocketAddr};
use anyhow::Error;
use async_trait::async_trait;
use bytes::Bytes;
use http::Response;
use http_body_util::Full;
use hyper1::service::service_fn;
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use tokio::net::TcpListener;
fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
let sk = p256::SecretKey::random(&mut OsRng);
let pk = sk.public_key().into();
let jwk = jose_jwk::Jwk {
key: jose_jwk::Key::Ec(pk),
prm: jose_jwk::Parameters {
kid: Some(kid),
..Default::default()
},
};
(sk, jwk)
}
fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
let sk = rsa::RsaPrivateKey::new(&mut OsRng, 1024).unwrap();
let pk = sk.to_public_key().into();
let jwk = jose_jwk::Jwk {
key: jose_jwk::Key::Rsa(pk),
prm: jose_jwk::Parameters {
kid: Some(kid),
..Default::default()
},
};
(sk, jwk)
}
#[tokio::test]
async fn renew() {
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
// let reports = Arc::new(Mutex::new(vec![]));
// let reports2 = reports.clone();
let server = hyper1::server::conn::http1::Builder::new();
// let server = hyper1::server::Server::from_tcp(listener)
// .unwrap()
// .serve(make_service_fn(move |_| {
// // let reports = reports.clone();
// async move {
// Ok::<_, Error>(service_fn(move |req| {
// // let reports = reports.clone();
// async move {
// // let bytes = hyper::body::to_bytes(req.into_body()).await?;
// // let events: EventChunk<'static, Event<Ids, String>> =
// // serde_json::from_slice(&bytes)?;
// // reports.lock().unwrap().push(events);
// Ok::<_, Error>(Response::new(Body::from(vec![])))
// }
// }))
// }
// }));
let (rs1, jwk1) = new_rsa_jwk("1".into());
let (rs2, jwk2) = new_rsa_jwk("2".into());
let (ec1, jwk3) = new_ec_jwk("3".into());
let (ec2, jwk4) = new_ec_jwk("4".into());
let foo = jose_jwk::JwkSet {
keys: vec![jwk1, jwk3],
};
let bar = jose_jwk::JwkSet {
keys: vec![jwk2, jwk4],
};
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (s, _) = listener.accept().await.unwrap();
let foo = foo.clone();
let bar = bar.clone();
tokio::spawn(
server
.serve_connection(
TokioIo::new(s),
service_fn(move |req| {
let foo = foo.clone();
let bar = bar.clone();
async move {
let jwks = match req.uri().path() {
"/foo" => &foo,
"/bar" => &bar,
_ => {
return Ok::<_, Error>(
Response::builder()
.status(404)
.body(Full::new(Bytes::new()))
.unwrap(),
)
}
};
Ok::<_, Error>(Response::new(Full::new(Bytes::from(
serde_json::to_vec(jwks).unwrap(),
))))
}
}),
)
.into_future(),
);
}
});
let client = reqwest::Client::new();
#[derive(Clone)]
struct Fetch(SocketAddr);
#[async_trait]
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
Ok(AuthRules {
jwks_urls: vec![
format!("http://{}/foo", self.0).parse().unwrap(),
format!("http://{}/bar", self.0).parse().unwrap(),
],
})
}
}
let jwk_cache = Arc::new(JWKCacheEntryLock::default());
let permit = jwk_cache.acquire_permit().await;
let entry = jwk_cache
.renew_jwks(permit, &client, &Fetch(addr))
.await
.unwrap();
}
}