mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-04 14:00:38 +00:00
chore(proxy/jwks): reduce the rightward drift of jwks renewal (#9853)
I found the rightward drift of the `renew_jwks` function hard to review. This PR splits out some major logic and uses early returns to make the happy path more linear.
This commit is contained in:
@@ -132,6 +132,93 @@ struct JwkSet<'a> {
|
||||
keys: Vec<&'a RawValue>,
|
||||
}
|
||||
|
||||
/// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
|
||||
/// Returns `None` and log a warning if there are any errors.
|
||||
async fn fetch_jwks(
|
||||
client: &reqwest_middleware::ClientWithMiddleware,
|
||||
jwks_url: url::Url,
|
||||
) -> Option<jose_jwk::JwkSet> {
|
||||
let req = client.get(jwks_url.clone());
|
||||
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
|
||||
let resp = req.send().await.and_then(|r| {
|
||||
r.error_for_status()
|
||||
.map_err(reqwest_middleware::Error::Reqwest)
|
||||
});
|
||||
|
||||
let resp = match resp {
|
||||
Ok(r) => r,
|
||||
// TODO: should we re-insert JWKs if we want to keep this JWKs URL?
|
||||
// I expect these failures would be quite sparse.
|
||||
Err(e) => {
|
||||
tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let resp: http::Response<reqwest::Body> = resp.into();
|
||||
|
||||
let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
|
||||
Ok(jwks) => jwks,
|
||||
Err(e) => {
|
||||
tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
|
||||
//
|
||||
// Even though we limit our responses to 64KiB, we could still receive a payload like
|
||||
// `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
|
||||
// Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
|
||||
let mut keys = vec![];
|
||||
|
||||
let mut failed = 0;
|
||||
for key in jwks.keys {
|
||||
let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
|
||||
Ok(key) => key,
|
||||
Err(e) => {
|
||||
tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
|
||||
failed += 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// if `use` (called `cls` in rust) is specified to be something other than signing,
|
||||
// we can skip storing it.
|
||||
if key
|
||||
.prm
|
||||
.cls
|
||||
.as_ref()
|
||||
.is_some_and(|c| *c != jose_jwk::Class::Signing)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
keys.push(key);
|
||||
}
|
||||
|
||||
keys.shrink_to_fit();
|
||||
|
||||
if failed > 0 {
|
||||
tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
|
||||
}
|
||||
|
||||
if keys.is_empty() {
|
||||
tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(jose_jwk::JwkSet { keys })
|
||||
}
|
||||
|
||||
impl JwkCacheEntryLock {
|
||||
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
|
||||
JwkRenewalPermit::acquire_permit(self).await
|
||||
@@ -166,87 +253,15 @@ impl JwkCacheEntryLock {
|
||||
// TODO(conrad): run concurrently
|
||||
// TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
|
||||
for rule in rules {
|
||||
let req = client.get(rule.jwks_url.clone());
|
||||
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
|
||||
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
|
||||
match req.send().await.and_then(|r| {
|
||||
r.error_for_status()
|
||||
.map_err(reqwest_middleware::Error::Reqwest)
|
||||
}) {
|
||||
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
|
||||
// I expect these failures would be quite sparse.
|
||||
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
|
||||
Ok(r) => {
|
||||
let resp: http::Response<reqwest::Body> = r.into();
|
||||
|
||||
let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE)
|
||||
.await
|
||||
{
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match serde_json::from_slice::<JwkSet>(&bytes) {
|
||||
Err(e) => {
|
||||
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
|
||||
}
|
||||
Ok(jwks) => {
|
||||
// size_of::<&RawValue>() == 16
|
||||
// size_of::<jose_jwk::Jwk>() == 288
|
||||
// better to not pre-allocate this as it might be pretty large - especially if it has many
|
||||
// keys we don't want or need.
|
||||
// trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
|
||||
// this would consume 8MiB just like that!
|
||||
let mut keys = vec![];
|
||||
let mut failed = 0;
|
||||
for key in jwks.keys {
|
||||
match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
|
||||
Ok(key) => {
|
||||
// if `use` (called `cls` in rust) is specified to be something other than signing,
|
||||
// we can skip storing it.
|
||||
if key
|
||||
.prm
|
||||
.cls
|
||||
.as_ref()
|
||||
.is_some_and(|c| *c != jose_jwk::Class::Signing)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
keys.push(key);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK");
|
||||
failed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
keys.shrink_to_fit();
|
||||
|
||||
if failed > 0 {
|
||||
tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs");
|
||||
}
|
||||
|
||||
if keys.is_empty() {
|
||||
tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body");
|
||||
continue;
|
||||
}
|
||||
|
||||
let jwks = jose_jwk::JwkSet { keys };
|
||||
key_sets.insert(
|
||||
rule.id,
|
||||
KeySet {
|
||||
jwks,
|
||||
audience: rule.audience,
|
||||
role_names: rule.role_names,
|
||||
},
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
|
||||
key_sets.insert(
|
||||
rule.id,
|
||||
KeySet {
|
||||
jwks,
|
||||
audience: rule.audience,
|
||||
role_names: rule.role_names,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user