diff --git a/Cargo.lock b/Cargo.lock index 7d18f44aec..00d58be2d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4743,6 +4743,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "rustls 0.22.4", + "rustls-native-certs 0.7.0", "rustls-pemfile 2.1.1", "rustls-pki-types", "serde", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index efd336dbea..1665d6361a 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -60,7 +60,7 @@ prometheus.workspace = true rand.workspace = true regex.workspace = true remote_storage = { version = "0.1", path = "../libs/remote_storage/" } -reqwest.workspace = true +reqwest = { workspace = true, features = ["rustls-tls-native-roots"] } reqwest-middleware = { workspace = true, features = ["json"] } reqwest-retry.workspace = true reqwest-tracing.workspace = true diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 83c3617612..bfc674139b 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -7,8 +7,11 @@ use arc_swap::ArcSwapOption; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; use reqwest::{redirect, Client}; +use reqwest_retry::policies::ExponentialBackoff; +use reqwest_retry::RetryTransientMiddleware; use serde::de::Visitor; use serde::{Deserialize, Deserializer}; +use serde_json::value::RawValue; use signature::Verifier; use thiserror::Error; use tokio::time::Instant; @@ -16,7 +19,7 @@ use tokio::time::Instant; use crate::auth::backend::ComputeCredentialKeys; use crate::context::RequestMonitoring; use crate::control_plane::errors::GetEndpointJwksError; -use crate::http::parse_json_body_with_limit; +use crate::http::read_body_with_limit; use crate::intern::RoleNameInt; use crate::types::{EndpointId, RoleName}; @@ -28,6 +31,10 @@ const MAX_RENEW: Duration = Duration::from_secs(3600); const MAX_JWK_BODY_SIZE: usize = 64 * 1024; const JWKS_USER_AGENT: &str = "neon-proxy"; +const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2); +const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5); +const JWKS_FETCH_RETRIES: u32 = 3; + /// How to get the JWT auth rules pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static { fn fetch_auth_rules( @@ -55,7 +62,7 @@ pub(crate) struct AuthRule { } pub struct JwkCache { - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, map: DashMap<(EndpointId, RoleName), Arc>, } @@ -117,6 +124,14 @@ impl Default for JwkCacheEntryLock { } } +#[derive(Deserialize)] +struct JwkSet<'a> { + /// we parse into raw-value because not all keys in a JWKS are ones + /// we can parse directly, so we parse them lazily. + #[serde(borrow)] + keys: Vec<&'a RawValue>, +} + impl JwkCacheEntryLock { async fn acquire_permit<'a>(self: &'a Arc) -> JwkRenewalPermit<'a> { JwkRenewalPermit::acquire_permit(self).await @@ -130,7 +145,7 @@ impl JwkCacheEntryLock { &self, _permit: JwkRenewalPermit<'_>, ctx: &RequestMonitoring, - client: &reqwest::Client, + client: &reqwest_middleware::ClientWithMiddleware, endpoint: EndpointId, auth_rules: &F, ) -> Result, JwtError> { @@ -154,22 +169,73 @@ impl JwkCacheEntryLock { let req = client.get(rule.jwks_url.clone()); // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`. // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. - match req.send().await.and_then(|r| r.error_for_status()) { + match req.send().await.and_then(|r| { + r.error_for_status() + .map_err(reqwest_middleware::Error::Reqwest) + }) { // todo: should we re-insert JWKs if we want to keep this JWKs URL? // I expect these failures would be quite sparse. Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"), Ok(r) => { let resp: http::Response = r.into(); - match parse_json_body_with_limit::( - resp.into_body(), - MAX_JWK_BODY_SIZE, - ) - .await + + let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE) + .await { + Ok(bytes) => bytes, + Err(e) => { + tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); + continue; + } + }; + + match serde_json::from_slice::(&bytes) { Err(e) => { tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); } Ok(jwks) => { + // size_of::<&RawValue>() == 16 + // size_of::() == 288 + // better to not pre-allocate this as it might be pretty large - especially if it has many + // keys we don't want or need. + // trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}` + // this would consume 8MiB just like that! + let mut keys = vec![]; + let mut failed = 0; + for key in jwks.keys { + match serde_json::from_str::(key.get()) { + Ok(key) => { + // if `use` (called `cls` in rust) is specified to be something other than signing, + // we can skip storing it. + if key + .prm + .cls + .as_ref() + .is_some_and(|c| *c != jose_jwk::Class::Signing) + { + continue; + } + + keys.push(key); + } + Err(e) => { + tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK"); + failed += 1; + } + } + } + keys.shrink_to_fit(); + + if failed > 0 { + tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs"); + } + + if keys.is_empty() { + tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body"); + continue; + } + + let jwks = jose_jwk::JwkSet { keys }; key_sets.insert( rule.id, KeySet { @@ -179,7 +245,7 @@ impl JwkCacheEntryLock { }, ); } - } + }; } } } @@ -196,7 +262,7 @@ impl JwkCacheEntryLock { async fn get_or_update_jwk_cache( self: &Arc, ctx: &RequestMonitoring, - client: &reqwest::Client, + client: &reqwest_middleware::ClientWithMiddleware, endpoint: EndpointId, fetch: &F, ) -> Result, JwtError> { @@ -250,7 +316,7 @@ impl JwkCacheEntryLock { self: &Arc, ctx: &RequestMonitoring, jwt: &str, - client: &reqwest::Client, + client: &reqwest_middleware::ClientWithMiddleware, endpoint: EndpointId, role_name: &RoleName, fetch: &F, @@ -369,8 +435,19 @@ impl Default for JwkCache { let client = Client::builder() .user_agent(JWKS_USER_AGENT) .redirect(redirect::Policy::none()) + .tls_built_in_native_certs(true) + .connect_timeout(JWKS_CONNECT_TIMEOUT) + .timeout(JWKS_FETCH_TIMEOUT) .build() - .expect("using &str and standard redirect::Policy"); + .expect("client config should be valid"); + + // Retry up to 3 times with increasing intervals between attempts. + let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES); + + let client = reqwest_middleware::ClientBuilder::new(client) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build(); + JwkCache { client, map: DashMap::default(), @@ -1209,4 +1286,63 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL } } } + + #[tokio::test] + async fn check_jwk_keycloak_regression() { + let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into()); + let valid_jwk = serde_json::to_value(valid_jwk).unwrap(); + + // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones. + // This is taken directly from keycloak. + let invalid_jwk = serde_json::json! { + { + "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ", + "kty": "RSA", + "alg": "RSA-OAEP", + "use": "enc", + "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw", + "e": "AQAB", + "x5c": [ + "MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI=" + ], + "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs", + "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU" + } + }; + + let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }}; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let role_name = RoleName::from("anonymous"); + let role = RoleNameInt::from(&role_name); + + let rules = vec![AuthRule { + id: "foo".to_owned(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: None, + role_names: vec![role], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let endpoint = EndpointId::from("ep"); + + let token = new_rsa_jwt("rs1".into(), rs); + + jwk_cache + .check_jwt( + &RequestMonitoring::test(), + endpoint.clone(), + &role_name, + &fetch, + &token, + ) + .await + .unwrap(); + } } diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index f1b632e704..b1642cedb3 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -6,7 +6,6 @@ pub mod health_server; use std::time::Duration; -use anyhow::bail; use bytes::Bytes; use http::Method; use http_body_util::BodyExt; @@ -16,7 +15,7 @@ use reqwest_middleware::RequestBuilder; pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error}; pub(crate) use reqwest_retry::policies::ExponentialBackoff; pub(crate) use reqwest_retry::RetryTransientMiddleware; -use serde::de::DeserializeOwned; +use thiserror::Error; use crate::metrics::{ConsoleRequest, Metrics}; use crate::url::ApiUrl; @@ -122,10 +121,19 @@ impl Endpoint { } } -pub(crate) async fn parse_json_body_with_limit( +#[derive(Error, Debug)] +pub(crate) enum ReadBodyError { + #[error("Content length exceeds limit of {limit} bytes")] + BodyTooLarge { limit: usize }, + + #[error(transparent)] + Read(#[from] reqwest::Error), +} + +pub(crate) async fn read_body_with_limit( mut b: impl Body + Unpin, limit: usize, -) -> anyhow::Result { +) -> Result, ReadBodyError> { // We could use `b.limited().collect().await.to_bytes()` here // but this ends up being slightly more efficient as far as I can tell. @@ -133,20 +141,20 @@ pub(crate) async fn parse_json_body_with_limit( // in reqwest, this value is influenced by the Content-Length header. let lower_bound = match usize::try_from(b.size_hint().lower()) { Ok(bound) if bound <= limit => bound, - _ => bail!("Content length exceeds limit of {limit} bytes"), + _ => return Err(ReadBodyError::BodyTooLarge { limit }), }; let mut bytes = Vec::with_capacity(lower_bound); while let Some(frame) = b.frame().await.transpose()? { if let Ok(data) = frame.into_data() { if bytes.len() + data.len() > limit { - bail!("Content length exceeds limit of {limit} bytes") + return Err(ReadBodyError::BodyTooLarge { limit }); } bytes.extend_from_slice(&data); } } - Ok(serde_json::from_slice::(&bytes)?) + Ok(bytes) } #[cfg(test)] diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index 00a8ac4768..61c39c32c9 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -16,8 +16,7 @@ use super::http_conn_pool::ClientDataHttp; use super::local_conn_pool::ClientDataLocal; use crate::auth::backend::ComputeUserInfo; use crate::context::RequestMonitoring; -use crate::control_plane::messages::ColdStartInfo; -use crate::control_plane::messages::MetricsAuxInfo; +use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::types::{DbName, EndpointCacheKey, RoleName}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 56be70abec..a1d4473b01 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -7,7 +7,6 @@ use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; use rand::Rng; -use std::result::Result::Ok; use tokio::net::TcpStream; use tracing::{debug, error, info, info_span, Instrument}; diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 02deecd385..ae4018a884 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -64,7 +64,7 @@ rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] } regex-syntax = { version = "0.8" } -reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] } +reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "rustls-tls-native-roots", "stream"] } rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std", "tls12"] } scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] }