diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index f721d81aa2..517d4fd34b 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -132,6 +132,93 @@ struct JwkSet<'a> { keys: Vec<&'a RawValue>, } +/// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs. +/// Returns `None` and log a warning if there are any errors. +async fn fetch_jwks( + client: &reqwest_middleware::ClientWithMiddleware, + jwks_url: url::Url, +) -> Option { + let req = client.get(jwks_url.clone()); + // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. + let resp = req.send().await.and_then(|r| { + r.error_for_status() + .map_err(reqwest_middleware::Error::Reqwest) + }); + + let resp = match resp { + Ok(r) => r, + // TODO: should we re-insert JWKs if we want to keep this JWKs URL? + // I expect these failures would be quite sparse. + Err(e) => { + tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs"); + return None; + } + }; + + let resp: http::Response = resp.into(); + + let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await { + Ok(bytes) => bytes, + Err(e) => { + tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs"); + return None; + } + }; + + let jwks = match serde_json::from_slice::(&bytes) { + Ok(jwks) => jwks, + Err(e) => { + tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs"); + return None; + } + }; + + // `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need. + // + // Even though we limit our responses to 64KiB, we could still receive a payload like + // `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB. + // Pre-allocating the corresponding `Vec::::with_capacity(30000)` uses 8.2MiB. + let mut keys = vec![]; + + let mut failed = 0; + for key in jwks.keys { + let key = match serde_json::from_str::(key.get()) { + Ok(key) => key, + Err(e) => { + tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK"); + failed += 1; + continue; + } + }; + + // if `use` (called `cls` in rust) is specified to be something other than signing, + // we can skip storing it. + if key + .prm + .cls + .as_ref() + .is_some_and(|c| *c != jose_jwk::Class::Signing) + { + continue; + } + + keys.push(key); + } + + keys.shrink_to_fit(); + + if failed > 0 { + tracing::warn!(url=?jwks_url, failed, "could not decode JWKs"); + } + + if keys.is_empty() { + tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body"); + return None; + } + + Some(jose_jwk::JwkSet { keys }) +} + impl JwkCacheEntryLock { async fn acquire_permit<'a>(self: &'a Arc) -> JwkRenewalPermit<'a> { JwkRenewalPermit::acquire_permit(self).await @@ -166,87 +253,15 @@ impl JwkCacheEntryLock { // TODO(conrad): run concurrently // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284) for rule in rules { - let req = client.get(rule.jwks_url.clone()); - // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`. - // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. - match req.send().await.and_then(|r| { - r.error_for_status() - .map_err(reqwest_middleware::Error::Reqwest) - }) { - // todo: should we re-insert JWKs if we want to keep this JWKs URL? - // I expect these failures would be quite sparse. - Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"), - Ok(r) => { - let resp: http::Response = r.into(); - - let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE) - .await - { - Ok(bytes) => bytes, - Err(e) => { - tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); - continue; - } - }; - - match serde_json::from_slice::(&bytes) { - Err(e) => { - tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); - } - Ok(jwks) => { - // size_of::<&RawValue>() == 16 - // size_of::() == 288 - // better to not pre-allocate this as it might be pretty large - especially if it has many - // keys we don't want or need. - // trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}` - // this would consume 8MiB just like that! - let mut keys = vec![]; - let mut failed = 0; - for key in jwks.keys { - match serde_json::from_str::(key.get()) { - Ok(key) => { - // if `use` (called `cls` in rust) is specified to be something other than signing, - // we can skip storing it. - if key - .prm - .cls - .as_ref() - .is_some_and(|c| *c != jose_jwk::Class::Signing) - { - continue; - } - - keys.push(key); - } - Err(e) => { - tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK"); - failed += 1; - } - } - } - keys.shrink_to_fit(); - - if failed > 0 { - tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs"); - } - - if keys.is_empty() { - tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body"); - continue; - } - - let jwks = jose_jwk::JwkSet { keys }; - key_sets.insert( - rule.id, - KeySet { - jwks, - audience: rule.audience, - role_names: rule.role_names, - }, - ); - } - }; - } + if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await { + key_sets.insert( + rule.id, + KeySet { + jwks, + audience: rule.audience, + role_names: rule.role_names, + }, + ); } }