chore(proxy/jwks): reduce the rightward drift of jwks renewal (#9853)

I found the rightward drift of the `renew_jwks` function hard to review.

This PR splits out some major logic and uses early returns to make the
happy path more linear.
This commit is contained in:
Conrad Ludgate
2024-11-22 14:51:32 +00:00
committed by GitHub
parent 51d26a261b
commit 8ab96cc71f

View File

@@ -132,6 +132,93 @@ struct JwkSet<'a> {
keys: Vec<&'a RawValue>,
}
/// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
/// Returns `None` and log a warning if there are any errors.
async fn fetch_jwks(
client: &reqwest_middleware::ClientWithMiddleware,
jwks_url: url::Url,
) -> Option<jose_jwk::JwkSet> {
let req = client.get(jwks_url.clone());
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
let resp = req.send().await.and_then(|r| {
r.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
});
let resp = match resp {
Ok(r) => r,
// TODO: should we re-insert JWKs if we want to keep this JWKs URL?
// I expect these failures would be quite sparse.
Err(e) => {
tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
return None;
}
};
let resp: http::Response<reqwest::Body> = resp.into();
let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
Ok(bytes) => bytes,
Err(e) => {
tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
return None;
}
};
let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
Ok(jwks) => jwks,
Err(e) => {
tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
return None;
}
};
// `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
//
// Even though we limit our responses to 64KiB, we could still receive a payload like
// `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
// Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
let mut keys = vec![];
let mut failed = 0;
for key in jwks.keys {
let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
Ok(key) => key,
Err(e) => {
tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
failed += 1;
continue;
}
};
// if `use` (called `cls` in rust) is specified to be something other than signing,
// we can skip storing it.
if key
.prm
.cls
.as_ref()
.is_some_and(|c| *c != jose_jwk::Class::Signing)
{
continue;
}
keys.push(key);
}
keys.shrink_to_fit();
if failed > 0 {
tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
}
if keys.is_empty() {
tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
return None;
}
Some(jose_jwk::JwkSet { keys })
}
impl JwkCacheEntryLock {
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
JwkRenewalPermit::acquire_permit(self).await
@@ -166,87 +253,15 @@ impl JwkCacheEntryLock {
// TODO(conrad): run concurrently
// TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
for rule in rules {
let req = client.get(rule.jwks_url.clone());
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
match req.send().await.and_then(|r| {
r.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
}) {
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
// I expect these failures would be quite sparse.
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();
let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE)
.await
{
Ok(bytes) => bytes,
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
continue;
}
};
match serde_json::from_slice::<JwkSet>(&bytes) {
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
}
Ok(jwks) => {
// size_of::<&RawValue>() == 16
// size_of::<jose_jwk::Jwk>() == 288
// better to not pre-allocate this as it might be pretty large - especially if it has many
// keys we don't want or need.
// trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
// this would consume 8MiB just like that!
let mut keys = vec![];
let mut failed = 0;
for key in jwks.keys {
match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
Ok(key) => {
// if `use` (called `cls` in rust) is specified to be something other than signing,
// we can skip storing it.
if key
.prm
.cls
.as_ref()
.is_some_and(|c| *c != jose_jwk::Class::Signing)
{
continue;
}
keys.push(key);
}
Err(e) => {
tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK");
failed += 1;
}
}
}
keys.shrink_to_fit();
if failed > 0 {
tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs");
}
if keys.is_empty() {
tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body");
continue;
}
let jwks = jose_jwk::JwkSet { keys };
key_sets.insert(
rule.id,
KeySet {
jwks,
audience: rule.audience,
role_names: rule.role_names,
},
);
}
};
}
if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
key_sets.insert(
rule.id,
KeySet {
jwks,
audience: rule.audience,
role_names: rule.role_names,
},
);
}
}