From 7cb234929648897323d6f08f842e3458121244eb Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 12 Aug 2024 11:48:57 +0100 Subject: [PATCH] add jwks size limiter --- Cargo.lock | 7 +++++ proxy/src/auth/backend/jwt.rs | 49 +++++++++++++++++++++++++++++++---- workspace_hack/Cargo.toml | 12 +++++++-- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4905e08012..dee15b6aa7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7614,13 +7614,17 @@ dependencies = [ "clap", "clap_builder", "crossbeam-utils", + "crypto-bigint 0.5.5", + "der 0.7.8", "deranged", + "digest", "either", "fail", "futures-channel", "futures-executor", "futures-io", "futures-util", + "generic-array", "getrandom 0.2.11", "hashbrown 0.14.5", "hex", @@ -7628,6 +7632,7 @@ dependencies = [ "hyper 0.14.26", "indexmap 1.9.3", "itertools 0.10.5", + "lazy_static", "libc", "log", "memchr", @@ -7651,7 +7656,9 @@ dependencies = [ "serde", "serde_json", "sha2", + "signature 2.2.0", "smallvec", + "spki 0.7.3", "subtle", "syn 1.0.109", "syn 2.0.52", diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index a63f1c1f0b..9fa2da5e0f 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -3,8 +3,10 @@ use std::{sync::Arc, time::Duration}; use anyhow::{bail, ensure, Context}; use arc_swap::ArcSwapOption; use async_trait::async_trait; +use bytes::Bytes; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; +use serde::de::DeserializeOwned; use signature::Verifier; use tokio::time::Instant; @@ -116,12 +118,27 @@ impl JWKCacheEntryLock { // todo: should we re-insert JWKs if we want to keep this JWKs URL? // I expect these failures would be quite sparse. Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"), - Ok(r) => match r.json::().await { - Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"), - Ok(jwks) => { - key_sets.insert(url, jwks); + Ok(r) => { + if r.content_length() + .is_some_and(|l| l > MAX_JWK_BODY_SIZE as u64) + { + tracing::warn!( + ?url, + error = "JWKs response too large", + "could not decode JWKs" + ); + continue; } - }, + + let resp: http::Response = r.into(); + match parse_json::(resp.into_body(), MAX_JWK_BODY_SIZE).await + { + Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"), + Ok(jwks) => { + key_sets.insert(url, jwks); + } + } + } } } @@ -343,6 +360,28 @@ impl Drop for AttachedPermit<'_> { } } +const MAX_JWK_BODY_SIZE: usize = 64 * 1024; +pub async fn parse_json( + mut b: impl hyper1::body::Body + Unpin, + limit: usize, +) -> anyhow::Result +where + D: DeserializeOwned, +{ + use http_body_util::BodyExt; + let mut bytes = vec![]; + while let Some(frame) = b.frame().await.transpose()? { + if let Ok(data) = frame.into_data() { + if bytes.len() + data.len() > limit { + bail!("overflow") + } + bytes.extend_from_slice(&data); + } + } + + Ok(serde_json::from_slice::(&bytes)?) +} + #[cfg(test)] mod tests { use super::*; diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 832fe06bf6..2d9b372654 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -30,13 +30,17 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd clap = { version = "4", features = ["derive", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] } crossbeam-utils = { version = "0.8" } +crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] } +der = { version = "0.7", default-features = false, features = ["oid", "pem", "std"] } deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] } +digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures-channel = { version = "0.3", features = ["sink"] } futures-executor = { version = "0.3" } futures-io = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } +generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } @@ -44,6 +48,7 @@ hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper = { version = "0.14", features = ["full"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } +lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] } memchr = { version = "2" } @@ -64,8 +69,10 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] } scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] } serde_json = { version = "1", features = ["raw_value"] } -sha2 = { version = "0.10", features = ["asm"] } +sha2 = { version = "0.10", features = ["asm", "oid"] } +signature = { version = "2", default-features = false, features = ["digest", "rand_core", "std"] } smallvec = { version = "1", default-features = false, features = ["const_new", "write"] } +spki = { version = "0.7", default-features = false, features = ["pem", "std"] } subtle = { version = "2" } sync_wrapper = { version = "0.1", default-features = false, features = ["futures"] } tikv-jemalloc-sys = { version = "0.5" } @@ -81,7 +88,7 @@ tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" } url = { version = "2", features = ["serde"] } uuid = { version = "1", features = ["serde", "v4", "v7"] } -zeroize = { version = "1", features = ["derive"] } +zeroize = { version = "1", features = ["derive", "serde"] } zstd = { version = "0.13" } zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] } zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] } @@ -97,6 +104,7 @@ getrandom = { version = "0.2", default-features = false, features = ["std"] } hashbrown = { version = "0.14", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } +lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] } memchr = { version = "2" }