mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-16 01:42:55 +00:00
mock tests for jwk renewal
This commit is contained in:
@@ -2,6 +2,7 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use jose_jwk::crypto::KeyInfo;
|
||||
use tokio::time::Instant;
|
||||
@@ -20,6 +21,15 @@ pub struct JWKCacheEntryLock {
|
||||
lookup: tokio::sync::Semaphore,
|
||||
}
|
||||
|
||||
impl Default for JWKCacheEntryLock {
|
||||
fn default() -> Self {
|
||||
JWKCacheEntryLock {
|
||||
cached: ArcSwapOption::empty(),
|
||||
lookup: tokio::sync::Semaphore::new(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JWKCacheEntry {
|
||||
/// Should refetch at least every hour to verify when old keys have been removed.
|
||||
/// Should refetch when new key IDs are seen only every 5 minutes or so
|
||||
@@ -28,18 +38,19 @@ pub struct JWKCacheEntry {
|
||||
// /// jwks urls
|
||||
// urls: Vec<url::Url>,
|
||||
/// cplane will return multiple JWKs urls that we need to scrape.
|
||||
key_sets: Vec<jose_jwk::JwkSet>,
|
||||
key_sets: ahash::HashMap<url::Url, jose_jwk::JwkSet>,
|
||||
}
|
||||
|
||||
/// How to get the JWT auth rules
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait FetchAuthRules {
|
||||
#[async_trait]
|
||||
pub trait FetchAuthRules: Clone + Send + Sync + 'static {
|
||||
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FetchAuthFromCplane(EndpointIdInt);
|
||||
|
||||
#[async_trait]
|
||||
impl FetchAuthRules for FetchAuthFromCplane {
|
||||
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
|
||||
bail!("not yet implemented")
|
||||
@@ -91,24 +102,24 @@ impl JWKCacheEntryLock {
|
||||
}
|
||||
|
||||
let rules = auth_rules.fetch_auth_rules().await?;
|
||||
let mut key_sets = vec![];
|
||||
let mut key_sets = ahash::HashMap::with_capacity_and_hasher(
|
||||
rules.jwks_urls.len(),
|
||||
ahash::RandomState::new(),
|
||||
);
|
||||
for url in rules.jwks_urls {
|
||||
match client
|
||||
.get(url.clone())
|
||||
.send()
|
||||
.await
|
||||
.and_then(|r| r.error_for_status())
|
||||
{
|
||||
let req = client.get(url.clone());
|
||||
match req.send().await.and_then(|r| r.error_for_status()) {
|
||||
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
|
||||
// I expect these failures would be quite sparse.
|
||||
Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"),
|
||||
Ok(r) => match r.json::<jose_jwk::JwkSet>().await {
|
||||
Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"),
|
||||
Ok(jwks) => key_sets.push(jwks),
|
||||
Ok(jwks) => {
|
||||
key_sets.insert(url, jwks);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
if key_sets.is_empty() {
|
||||
bail!("no JWKs found")
|
||||
}
|
||||
|
||||
let x = Arc::new(JWKCacheEntry {
|
||||
last_retrieved: now,
|
||||
@@ -118,16 +129,12 @@ impl JWKCacheEntryLock {
|
||||
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
const MIN_RENEW: Duration = Duration::from_secs(300);
|
||||
const MAX_RENEW: Duration = Duration::from_secs(3600);
|
||||
|
||||
impl JWKCache {
|
||||
pub async fn check_jwt(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
async fn check_jwt<F: FetchAuthRules>(
|
||||
self: &Arc<Self>,
|
||||
jwt: String,
|
||||
client: &reqwest::Client,
|
||||
fetch: &F,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// JWT compact form is defined to be
|
||||
// <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
|
||||
@@ -148,42 +155,24 @@ impl JWKCache {
|
||||
ensure!(header.typ == "JWT");
|
||||
let kid = header.kid.context("missing key id")?;
|
||||
|
||||
// try with just a read lock first
|
||||
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
|
||||
let entry = match entry {
|
||||
Some(entry) => entry,
|
||||
None => {
|
||||
// acquire a write lock after to insert.
|
||||
let entry = self.map.entry(endpoint).or_insert_with(|| {
|
||||
Arc::new(JWKCacheEntryLock {
|
||||
cached: ArcSwapOption::empty(),
|
||||
lookup: tokio::sync::Semaphore::new(1),
|
||||
})
|
||||
});
|
||||
Arc::clone(&*entry)
|
||||
}
|
||||
};
|
||||
|
||||
let fetch = FetchAuthFromCplane(endpoint);
|
||||
|
||||
// check if the cached JWKs need updating.
|
||||
let mut guard = {
|
||||
let now = Instant::now();
|
||||
let guard = entry.cached.load_full();
|
||||
let guard = self.cached.load_full();
|
||||
|
||||
if let Some(cached) = guard {
|
||||
let last_update = now.duration_since(cached.last_retrieved);
|
||||
|
||||
if last_update > MAX_RENEW {
|
||||
let permit = entry.acquire_permit().await;
|
||||
let permit = self.acquire_permit().await;
|
||||
|
||||
// it's been too long since we checked the keys. wait for them to update.
|
||||
entry.renew_jwks(permit, &self.client, &fetch).await?
|
||||
self.renew_jwks(permit, client, fetch).await?
|
||||
} else if last_update > MIN_RENEW {
|
||||
// every 5 minutes we should spawn a job to eagerly update the token.
|
||||
if let Some(permit) = entry.try_acquire_permit() {
|
||||
let entry = entry.clone();
|
||||
let client = self.client.clone();
|
||||
if let Some(permit) = self.try_acquire_permit() {
|
||||
let entry = self.clone();
|
||||
let client = client.clone();
|
||||
let fetch = fetch.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await {
|
||||
@@ -197,8 +186,8 @@ impl JWKCache {
|
||||
cached
|
||||
}
|
||||
} else {
|
||||
let permit = entry.acquire_permit().await;
|
||||
entry.renew_jwks(permit, &self.client, &fetch).await?
|
||||
let permit = self.acquire_permit().await;
|
||||
self.renew_jwks(permit, client, fetch).await?
|
||||
}
|
||||
};
|
||||
|
||||
@@ -206,7 +195,7 @@ impl JWKCache {
|
||||
let jwk = loop {
|
||||
let jwk = guard
|
||||
.key_sets
|
||||
.iter()
|
||||
.values()
|
||||
.flat_map(|jwks| &jwks.keys)
|
||||
.find(|jwk| {
|
||||
jwk.prm.kid.as_deref() == Some(kid) && jwk.key.is_supported(&header.alg)
|
||||
@@ -215,8 +204,8 @@ impl JWKCache {
|
||||
match jwk {
|
||||
Some(jwk) => break jwk,
|
||||
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
|
||||
let permit = entry.acquire_permit().await;
|
||||
guard = entry.renew_jwks(permit, &self.client, &fetch).await?;
|
||||
let permit = self.acquire_permit().await;
|
||||
guard = self.renew_jwks(permit, client, fetch).await?;
|
||||
}
|
||||
_ => {
|
||||
bail!("jwk not found");
|
||||
@@ -242,6 +231,31 @@ impl JWKCache {
|
||||
}
|
||||
}
|
||||
|
||||
const MIN_RENEW: Duration = Duration::from_secs(300);
|
||||
const MAX_RENEW: Duration = Duration::from_secs(3600);
|
||||
|
||||
impl JWKCache {
|
||||
pub async fn check_jwt(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
jwt: String,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// try with just a read lock first
|
||||
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
|
||||
let entry = match entry {
|
||||
Some(entry) => entry,
|
||||
None => {
|
||||
// acquire a write lock after to insert.
|
||||
let entry = self.map.entry(endpoint).or_default();
|
||||
Arc::clone(&*entry)
|
||||
}
|
||||
};
|
||||
|
||||
let fetch = FetchAuthFromCplane(endpoint);
|
||||
entry.check_jwt(jwt, &self.client, &fetch).await
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
|
||||
use ecdsa::Signature;
|
||||
use signature::Verifier;
|
||||
@@ -311,3 +325,146 @@ impl Drop for AttachedPermit<'_> {
|
||||
self.0.add_permits(1);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use std::{future::IntoFuture, net::SocketAddr};
|
||||
|
||||
use anyhow::Error;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use http::Response;
|
||||
use http_body_util::Full;
|
||||
use hyper1::service::service_fn;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rand::rngs::OsRng;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
|
||||
let sk = p256::SecretKey::random(&mut OsRng);
|
||||
let pk = sk.public_key().into();
|
||||
let jwk = jose_jwk::Jwk {
|
||||
key: jose_jwk::Key::Ec(pk),
|
||||
prm: jose_jwk::Parameters {
|
||||
kid: Some(kid),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
(sk, jwk)
|
||||
}
|
||||
fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
|
||||
let sk = rsa::RsaPrivateKey::new(&mut OsRng, 1024).unwrap();
|
||||
let pk = sk.to_public_key().into();
|
||||
let jwk = jose_jwk::Jwk {
|
||||
key: jose_jwk::Key::Rsa(pk),
|
||||
prm: jose_jwk::Parameters {
|
||||
kid: Some(kid),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
(sk, jwk)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn renew() {
|
||||
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
|
||||
|
||||
// let reports = Arc::new(Mutex::new(vec![]));
|
||||
// let reports2 = reports.clone();
|
||||
|
||||
let server = hyper1::server::conn::http1::Builder::new();
|
||||
// let server = hyper1::server::Server::from_tcp(listener)
|
||||
// .unwrap()
|
||||
// .serve(make_service_fn(move |_| {
|
||||
// // let reports = reports.clone();
|
||||
// async move {
|
||||
// Ok::<_, Error>(service_fn(move |req| {
|
||||
// // let reports = reports.clone();
|
||||
// async move {
|
||||
// // let bytes = hyper::body::to_bytes(req.into_body()).await?;
|
||||
// // let events: EventChunk<'static, Event<Ids, String>> =
|
||||
// // serde_json::from_slice(&bytes)?;
|
||||
// // reports.lock().unwrap().push(events);
|
||||
// Ok::<_, Error>(Response::new(Body::from(vec![])))
|
||||
// }
|
||||
// }))
|
||||
// }
|
||||
// }));
|
||||
|
||||
let (rs1, jwk1) = new_rsa_jwk("1".into());
|
||||
let (rs2, jwk2) = new_rsa_jwk("2".into());
|
||||
let (ec1, jwk3) = new_ec_jwk("3".into());
|
||||
let (ec2, jwk4) = new_ec_jwk("4".into());
|
||||
|
||||
let foo = jose_jwk::JwkSet {
|
||||
keys: vec![jwk1, jwk3],
|
||||
};
|
||||
let bar = jose_jwk::JwkSet {
|
||||
keys: vec![jwk2, jwk4],
|
||||
};
|
||||
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (s, _) = listener.accept().await.unwrap();
|
||||
let foo = foo.clone();
|
||||
let bar = bar.clone();
|
||||
tokio::spawn(
|
||||
server
|
||||
.serve_connection(
|
||||
TokioIo::new(s),
|
||||
service_fn(move |req| {
|
||||
let foo = foo.clone();
|
||||
let bar = bar.clone();
|
||||
async move {
|
||||
let jwks = match req.uri().path() {
|
||||
"/foo" => &foo,
|
||||
"/bar" => &bar,
|
||||
_ => {
|
||||
return Ok::<_, Error>(
|
||||
Response::builder()
|
||||
.status(404)
|
||||
.body(Full::new(Bytes::new()))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok::<_, Error>(Response::new(Full::new(Bytes::from(
|
||||
serde_json::to_vec(jwks).unwrap(),
|
||||
))))
|
||||
}
|
||||
}),
|
||||
)
|
||||
.into_future(),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Fetch(SocketAddr);
|
||||
|
||||
#[async_trait]
|
||||
impl FetchAuthRules for Fetch {
|
||||
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
|
||||
Ok(AuthRules {
|
||||
jwks_urls: vec![
|
||||
format!("http://{}/foo", self.0).parse().unwrap(),
|
||||
format!("http://{}/bar", self.0).parse().unwrap(),
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let jwk_cache = Arc::new(JWKCacheEntryLock::default());
|
||||
let permit = jwk_cache.acquire_permit().await;
|
||||
let entry = jwk_cache
|
||||
.renew_jwks(permit, &client, &Fetch(addr))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user