From a7028d92b7560228e6cf63a3cf102c076bad3aa6 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 14 Aug 2024 13:35:29 +0100 Subject: [PATCH] proxy: start of jwk cache (#8690) basic JWT implementation that caches JWKs and verifies signatures. this code is currently not reachable from proxy, I just wanted to get something merged in. --- Cargo.lock | 273 ++++++++++++++++- deny.toml | 5 +- proxy/Cargo.toml | 11 +- proxy/src/auth/backend.rs | 1 + proxy/src/auth/backend/jwt.rs | 556 ++++++++++++++++++++++++++++++++++ proxy/src/http.rs | 33 ++ workspace_hack/Cargo.toml | 12 +- 7 files changed, 872 insertions(+), 19 deletions(-) create mode 100644 proxy/src/auth/backend/jwt.rs diff --git a/Cargo.lock b/Cargo.lock index 031fae0f37..dee15b6aa7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,7 +484,7 @@ dependencies = [ "http 0.2.9", "http 1.1.0", "once_cell", - "p256", + "p256 0.11.1", "percent-encoding", "ring 0.17.6", "sha2", @@ -848,6 +848,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.13.1" @@ -971,9 +977,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" [[package]] name = "bytemuck" -version = "1.16.0" +version = "1.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" +checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" [[package]] name = "byteorder" @@ -1526,8 +1532,10 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" dependencies = [ + "generic-array", "rand_core 0.6.4", "subtle", + "zeroize", ] [[package]] @@ -1621,6 +1629,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" dependencies = [ "const-oid", + "pem-rfc7468", "zeroize", ] @@ -1720,6 +1729,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -1771,11 +1781,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "413301934810f597c1d19ca71c8710e99a3f1ba28a0d2ebc01551a2daeea3c5c" dependencies = [ "der 0.6.1", - "elliptic-curve", - "rfc6979", + "elliptic-curve 0.12.3", + "rfc6979 0.3.1", "signature 1.6.4", ] +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der 0.7.8", + "digest", + "elliptic-curve 0.13.8", + "rfc6979 0.4.0", + "signature 2.2.0", + "spki 0.7.3", +] + [[package]] name = "either" version = "1.8.1" @@ -1788,16 +1812,36 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7bb888ab5300a19b8e5bceef25ac745ad065f3c9f7efc6de1b91958110891d3" dependencies = [ - "base16ct", + "base16ct 0.1.1", "crypto-bigint 0.4.9", "der 0.6.1", "digest", - "ff", + "ff 0.12.1", "generic-array", - "group", - "pkcs8", + "group 0.12.1", + "pkcs8 0.9.0", "rand_core 0.6.4", - "sec1", + "sec1 0.3.0", + "subtle", + "zeroize", +] + +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct 0.2.0", + "crypto-bigint 0.5.5", + "digest", + "ff 0.13.0", + "generic-array", + "group 0.13.0", + "pem-rfc7468", + "pkcs8 0.10.2", + "rand_core 0.6.4", + "sec1 0.7.3", "subtle", "zeroize", ] @@ -1951,6 +1995,16 @@ dependencies = [ "subtle", ] +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "filetime" version = "0.2.22" @@ -2148,6 +2202,7 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", ] [[package]] @@ -2214,7 +2269,18 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" dependencies = [ - "ff", + "ff 0.12.1", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff 0.13.0", "rand_core 0.6.4", "subtle", ] @@ -2776,6 +2842,42 @@ dependencies = [ "libc", ] +[[package]] +name = "jose-b64" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec69375368709666b21c76965ce67549f2d2db7605f1f8707d17c9656801b56" +dependencies = [ + "base64ct", + "serde", + "subtle", + "zeroize", +] + +[[package]] +name = "jose-jwa" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab78e053fe886a351d67cf0d194c000f9d0dcb92906eb34d853d7e758a4b3a7" +dependencies = [ + "serde", +] + +[[package]] +name = "jose-jwk" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "280fa263807fe0782ecb6f2baadc28dffc04e00558a58e33bfdb801d11fd58e7" +dependencies = [ + "jose-b64", + "jose-jwa", + "p256 0.13.2", + "p384", + "rsa", + "serde", + "zeroize", +] + [[package]] name = "js-sys" version = "0.3.69" @@ -2835,6 +2937,9 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin 0.5.2", +] [[package]] name = "lazycell" @@ -3204,6 +3309,23 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-complex" version = "0.4.4" @@ -3481,11 +3603,33 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" dependencies = [ - "ecdsa", - "elliptic-curve", + "ecdsa 0.14.8", + "elliptic-curve 0.12.3", "sha2", ] +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa 0.16.9", + "elliptic-curve 0.13.8", + "primeorder", + "sha2", +] + +[[package]] +name = "p384" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70786f51bcc69f6a4c0360e063a4cac5419ef7c5cd5b3c99ad70f3be5ba79209" +dependencies = [ + "elliptic-curve 0.13.8", + "primeorder", +] + [[package]] name = "pagebench" version = "0.1.0" @@ -3847,6 +3991,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.2.0" @@ -3913,6 +4066,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der 0.7.8", + "pkcs8 0.10.2", + "spki 0.7.3", +] + [[package]] name = "pkcs8" version = "0.9.0" @@ -3923,6 +4087,16 @@ dependencies = [ "spki 0.6.0", ] +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der 0.7.8", + "spki 0.7.3", +] + [[package]] name = "pkg-config" version = "0.3.27" @@ -4116,6 +4290,15 @@ dependencies = [ "syn 2.0.52", ] +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve 0.13.8", +] + [[package]] name = "proc-macro-hack" version = "0.5.20+deprecated" @@ -4233,6 +4416,7 @@ version = "0.1.0" dependencies = [ "ahash", "anyhow", + "arc-swap", "async-compression", "async-trait", "atomic-take", @@ -4250,6 +4434,7 @@ dependencies = [ "consumption_metrics", "crossbeam-deque", "dashmap", + "ecdsa 0.16.9", "env_logger", "fallible-iterator", "framed-websockets", @@ -4270,12 +4455,15 @@ dependencies = [ "indexmap 2.0.1", "ipnet", "itertools 0.10.5", + "jose-jwa", + "jose-jwk", "lasso", "md5", "measured", "metrics", "once_cell", "opentelemetry", + "p256 0.13.2", "parking_lot 0.12.1", "parquet", "parquet_derive", @@ -4296,6 +4484,7 @@ dependencies = [ "reqwest-retry", "reqwest-tracing", "routerify", + "rsa", "rstest", "rustc-hash", "rustls 0.22.4", @@ -4305,6 +4494,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "signature 2.2.0", "smallvec", "smol_str", "socket2 0.5.5", @@ -4807,6 +4997,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + [[package]] name = "ring" version = "0.16.20" @@ -4867,6 +5067,26 @@ dependencies = [ "archery", ] +[[package]] +name = "rsa" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8 0.10.2", + "rand_core 0.6.4", + "signature 2.2.0", + "spki 0.7.3", + "subtle", + "zeroize", +] + [[package]] name = "rstest" version = "0.18.2" @@ -5195,10 +5415,24 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be24c1842290c45df0a7bf069e0c268a747ad05a192f2fd7dcfdbc1cba40928" dependencies = [ - "base16ct", + "base16ct 0.1.1", "der 0.6.1", "generic-array", - "pkcs8", + "pkcs8 0.9.0", + "subtle", + "zeroize", +] + +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct 0.2.0", + "der 0.7.8", + "generic-array", + "pkcs8 0.10.2", "subtle", "zeroize", ] @@ -5545,6 +5779,7 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ + "digest", "rand_core 0.6.4", ] @@ -7379,13 +7614,17 @@ dependencies = [ "clap", "clap_builder", "crossbeam-utils", + "crypto-bigint 0.5.5", + "der 0.7.8", "deranged", + "digest", "either", "fail", "futures-channel", "futures-executor", "futures-io", "futures-util", + "generic-array", "getrandom 0.2.11", "hashbrown 0.14.5", "hex", @@ -7393,6 +7632,7 @@ dependencies = [ "hyper 0.14.26", "indexmap 1.9.3", "itertools 0.10.5", + "lazy_static", "libc", "log", "memchr", @@ -7416,7 +7656,9 @@ dependencies = [ "serde", "serde_json", "sha2", + "signature 2.2.0", "smallvec", + "spki 0.7.3", "subtle", "syn 1.0.109", "syn 2.0.52", @@ -7527,6 +7769,7 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" dependencies = [ + "serde", "zeroize_derive", ] diff --git a/deny.toml b/deny.toml index dc985138e6..327ac58db7 100644 --- a/deny.toml +++ b/deny.toml @@ -22,7 +22,10 @@ feature-depth = 1 [advisories] db-urls = ["https://github.com/rustsec/advisory-db"] yanked = "warn" -ignore = [] + +[[advisories.ignore]] +id = "RUSTSEC-2023-0071" +reason = "the marvin attack only affects private key decryption, not public key signature verification" # This section is considered when running `cargo deny check licenses` # More documentation for the licenses section can be found here: diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index b316c53034..21d92abb20 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -11,6 +11,7 @@ testing = [] [dependencies] ahash.workspace = true anyhow.workspace = true +arc-swap.workspace = true async-compression.workspace = true async-trait.workspace = true atomic-take.workspace = true @@ -73,7 +74,7 @@ rustls.workspace = true scopeguard.workspace = true serde.workspace = true serde_json.workspace = true -sha2 = { workspace = true, features = ["asm"] } +sha2 = { workspace = true, features = ["asm", "oid"] } smol_str.workspace = true smallvec.workspace = true socket2.workspace = true @@ -103,6 +104,14 @@ x509-parser.workspace = true postgres-protocol.workspace = true redis.workspace = true +# jwt stuff +jose-jwa = "0.1.2" +jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] } +signature = "2" +ecdsa = "0.16" +p256 = "0.13" +rsa = "0.9" + workspace_hack.workspace = true [dev-dependencies] diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 90dea01bf3..c6a0b2af5a 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -1,5 +1,6 @@ mod classic; mod hacks; +pub mod jwt; mod link; use std::net::IpAddr; diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs new file mode 100644 index 0000000000..0c2ca8fb97 --- /dev/null +++ b/proxy/src/auth/backend/jwt.rs @@ -0,0 +1,556 @@ +use std::{future::Future, sync::Arc, time::Duration}; + +use anyhow::{bail, ensure, Context}; +use arc_swap::ArcSwapOption; +use dashmap::DashMap; +use jose_jwk::crypto::KeyInfo; +use signature::Verifier; +use tokio::time::Instant; + +use crate::{http::parse_json_body_with_limit, intern::EndpointIdInt}; + +// TODO(conrad): make these configurable. +const MIN_RENEW: Duration = Duration::from_secs(30); +const AUTO_RENEW: Duration = Duration::from_secs(300); +const MAX_RENEW: Duration = Duration::from_secs(3600); +const MAX_JWK_BODY_SIZE: usize = 64 * 1024; + +/// How to get the JWT auth rules +pub trait FetchAuthRules: Clone + Send + Sync + 'static { + fn fetch_auth_rules(&self) -> impl Future> + Send; +} + +#[derive(Clone)] +struct FetchAuthRulesFromCplane { + #[allow(dead_code)] + endpoint: EndpointIdInt, +} + +impl FetchAuthRules for FetchAuthRulesFromCplane { + async fn fetch_auth_rules(&self) -> anyhow::Result { + Err(anyhow::anyhow!("not yet implemented")) + } +} + +pub struct AuthRules { + jwks_urls: Vec, +} + +#[derive(Default)] +pub struct JwkCache { + client: reqwest::Client, + + map: DashMap>, +} + +pub struct JwkCacheEntryLock { + cached: ArcSwapOption, + lookup: tokio::sync::Semaphore, +} + +impl Default for JwkCacheEntryLock { + fn default() -> Self { + JwkCacheEntryLock { + cached: ArcSwapOption::empty(), + lookup: tokio::sync::Semaphore::new(1), + } + } +} + +pub struct JwkCacheEntry { + /// Should refetch at least every hour to verify when old keys have been removed. + /// Should refetch when new key IDs are seen only every 5 minutes or so + last_retrieved: Instant, + + /// cplane will return multiple JWKs urls that we need to scrape. + key_sets: ahash::HashMap, +} + +impl JwkCacheEntryLock { + async fn acquire_permit<'a>(self: &'a Arc) -> JwkRenewalPermit<'a> { + JwkRenewalPermit::acquire_permit(self).await + } + + fn try_acquire_permit<'a>(self: &'a Arc) -> Option> { + JwkRenewalPermit::try_acquire_permit(self) + } + + async fn renew_jwks( + &self, + _permit: JwkRenewalPermit<'_>, + client: &reqwest::Client, + auth_rules: &F, + ) -> anyhow::Result> { + // double check that no one beat us to updating the cache. + let now = Instant::now(); + let guard = self.cached.load_full(); + if let Some(cached) = guard { + let last_update = now.duration_since(cached.last_retrieved); + if last_update < Duration::from_secs(300) { + return Ok(cached); + } + } + + let rules = auth_rules.fetch_auth_rules().await?; + let mut key_sets = ahash::HashMap::with_capacity_and_hasher( + rules.jwks_urls.len(), + ahash::RandomState::new(), + ); + // TODO(conrad): run concurrently + // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284) + for url in rules.jwks_urls { + let req = client.get(url.clone()); + // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`. + match req.send().await.and_then(|r| r.error_for_status()) { + // todo: should we re-insert JWKs if we want to keep this JWKs URL? + // I expect these failures would be quite sparse. + Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"), + Ok(r) => { + let resp: http::Response = r.into(); + match parse_json_body_with_limit::( + resp.into_body(), + MAX_JWK_BODY_SIZE, + ) + .await + { + Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"), + Ok(jwks) => { + key_sets.insert(url, jwks); + } + } + } + } + } + + let entry = Arc::new(JwkCacheEntry { + last_retrieved: now, + key_sets, + }); + self.cached.swap(Some(Arc::clone(&entry))); + + Ok(entry) + } + + async fn get_or_update_jwk_cache( + self: &Arc, + client: &reqwest::Client, + fetch: &F, + ) -> Result, anyhow::Error> { + let now = Instant::now(); + let guard = self.cached.load_full(); + + // if we have no cached JWKs, try and get some + let Some(cached) = guard else { + let permit = self.acquire_permit().await; + return self.renew_jwks(permit, client, fetch).await; + }; + + let last_update = now.duration_since(cached.last_retrieved); + + // check if the cached JWKs need updating. + if last_update > MAX_RENEW { + let permit = self.acquire_permit().await; + + // it's been too long since we checked the keys. wait for them to update. + return self.renew_jwks(permit, client, fetch).await; + } + + // every 5 minutes we should spawn a job to eagerly update the token. + if last_update > AUTO_RENEW { + if let Some(permit) = self.try_acquire_permit() { + tracing::debug!("JWKs should be renewed. Renewal permit acquired"); + let permit = permit.into_owned(); + let entry = self.clone(); + let client = client.clone(); + let fetch = fetch.clone(); + tokio::spawn(async move { + if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await { + tracing::warn!(error=?e, "could not fetch JWKs in background job"); + } + }); + } else { + tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping"); + } + } + + Ok(cached) + } + + async fn check_jwt( + self: &Arc, + jwt: String, + client: &reqwest::Client, + fetch: &F, + ) -> Result<(), anyhow::Error> { + // JWT compact form is defined to be + // || . || || . || + // where Signature = alg( || . || ); + + let (header_payload, signature) = jwt + .rsplit_once(".") + .context("Provided authentication token is not a valid JWT encoding")?; + let (header, _payload) = header_payload + .split_once(".") + .context("Provided authentication token is not a valid JWT encoding")?; + + let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD) + .context("Provided authentication token is not a valid JWT encoding")?; + let header = serde_json::from_slice::(&header) + .context("Provided authentication token is not a valid JWT encoding")?; + + let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD) + .context("Provided authentication token is not a valid JWT encoding")?; + + ensure!(header.typ == "JWT"); + let kid = header.kid.context("missing key id")?; + + let mut guard = self.get_or_update_jwk_cache(client, fetch).await?; + + // get the key from the JWKs if possible. If not, wait for the keys to update. + let jwk = loop { + let jwk = guard + .key_sets + .values() + .flat_map(|jwks| &jwks.keys) + .find(|jwk| jwk.prm.kid.as_deref() == Some(kid)); + + match jwk { + Some(jwk) => break jwk, + None if guard.last_retrieved.elapsed() > MIN_RENEW => { + let permit = self.acquire_permit().await; + guard = self.renew_jwks(permit, client, fetch).await?; + } + _ => { + bail!("jwk not found"); + } + } + }; + + ensure!( + jwk.is_supported(&header.alg), + "signature algorithm not supported" + ); + + match &jwk.key { + jose_jwk::Key::Ec(key) => { + verify_ec_signature(header_payload.as_bytes(), &sig, key)?; + } + jose_jwk::Key::Rsa(key) => { + verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?; + } + key => bail!("unsupported key type {key:?}"), + }; + + // TODO(conrad): verify iss, exp, nbf, etc... + + Ok(()) + } +} + +impl JwkCache { + pub async fn check_jwt( + &self, + endpoint: EndpointIdInt, + jwt: String, + ) -> Result<(), anyhow::Error> { + // try with just a read lock first + let entry = self.map.get(&endpoint).as_deref().map(Arc::clone); + let entry = match entry { + Some(entry) => entry, + None => { + // acquire a write lock after to insert. + let entry = self.map.entry(endpoint).or_default(); + Arc::clone(&*entry) + } + }; + + let fetch = FetchAuthRulesFromCplane { endpoint }; + entry.check_jwt(jwt, &self.client, &fetch).await + } +} + +fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> { + use ecdsa::Signature; + use signature::Verifier; + + match key.crv { + jose_jwk::EcCurves::P256 => { + let pk = + p256::PublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid P256 key"))?; + let key = p256::ecdsa::VerifyingKey::from(&pk); + let sig = Signature::from_slice(sig)?; + key.verify(data, &sig)?; + } + key => bail!("unsupported ec key type {key:?}"), + } + + Ok(()) +} + +fn verify_rsa_signature( + data: &[u8], + sig: &[u8], + key: &jose_jwk::Rsa, + alg: &Option, +) -> anyhow::Result<()> { + use jose_jwa::{Algorithm, Signing}; + use rsa::{ + pkcs1v15::{Signature, VerifyingKey}, + RsaPublicKey, + }; + + let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?; + + match alg { + Some(Algorithm::Signing(Signing::Rs256)) => { + let key = VerifyingKey::::new(key); + let sig = Signature::try_from(sig)?; + key.verify(data, &sig)?; + } + _ => bail!("invalid RSA signing algorithm"), + }; + + Ok(()) +} + +/// +#[derive(serde::Deserialize, serde::Serialize)] +struct JWTHeader<'a> { + /// must be "JWT" + typ: &'a str, + /// must be a supported alg + alg: jose_jwa::Algorithm, + /// key id, must be provided for our usecase + kid: Option<&'a str>, +} + +struct JwkRenewalPermit<'a> { + inner: Option>, +} + +enum JwkRenewalPermitInner<'a> { + Owned(Arc), + Borrowed(&'a Arc), +} + +impl JwkRenewalPermit<'_> { + fn into_owned(mut self) -> JwkRenewalPermit<'static> { + JwkRenewalPermit { + inner: self.inner.take().map(JwkRenewalPermitInner::into_owned), + } + } + + async fn acquire_permit(from: &Arc) -> JwkRenewalPermit { + match from.lookup.acquire().await { + Ok(permit) => { + permit.forget(); + JwkRenewalPermit { + inner: Some(JwkRenewalPermitInner::Borrowed(from)), + } + } + Err(_) => panic!("semaphore should not be closed"), + } + } + + fn try_acquire_permit(from: &Arc) -> Option { + match from.lookup.try_acquire() { + Ok(permit) => { + permit.forget(); + Some(JwkRenewalPermit { + inner: Some(JwkRenewalPermitInner::Borrowed(from)), + }) + } + Err(tokio::sync::TryAcquireError::NoPermits) => None, + Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"), + } + } +} + +impl JwkRenewalPermitInner<'_> { + fn into_owned(self) -> JwkRenewalPermitInner<'static> { + match self { + JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p), + JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)), + } + } +} + +impl Drop for JwkRenewalPermit<'_> { + fn drop(&mut self) { + let entry = match &self.inner { + None => return, + Some(JwkRenewalPermitInner::Owned(p)) => p, + Some(JwkRenewalPermitInner::Borrowed(p)) => *p, + }; + entry.lookup.add_permits(1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::{future::IntoFuture, net::SocketAddr, time::SystemTime}; + + use base64::URL_SAFE_NO_PAD; + use bytes::Bytes; + use http::Response; + use http_body_util::Full; + use hyper1::service::service_fn; + use hyper_util::rt::TokioIo; + use rand::rngs::OsRng; + use signature::Signer; + use tokio::net::TcpListener; + + fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) { + let sk = p256::SecretKey::random(&mut OsRng); + let pk = sk.public_key().into(); + let jwk = jose_jwk::Jwk { + key: jose_jwk::Key::Ec(pk), + prm: jose_jwk::Parameters { + kid: Some(kid), + alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)), + ..Default::default() + }, + }; + (sk, jwk) + } + + fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) { + let sk = rsa::RsaPrivateKey::new(&mut OsRng, 2048).unwrap(); + let pk = sk.to_public_key().into(); + let jwk = jose_jwk::Jwk { + key: jose_jwk::Key::Rsa(pk), + prm: jose_jwk::Parameters { + kid: Some(kid), + alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)), + ..Default::default() + }, + }; + (sk, jwk) + } + + fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String { + let header = JWTHeader { + typ: "JWT", + alg: jose_jwa::Algorithm::Signing(sig), + kid: Some(&kid), + }; + let body = typed_json::json! {{ + "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600, + }}; + + let header = + base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD); + let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD); + + format!("{header}.{body}") + } + + fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String { + use p256::ecdsa::{Signature, SigningKey}; + + let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256); + let sig: Signature = SigningKey::from(key).sign(payload.as_bytes()); + let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + + format!("{payload}.{sig}") + } + + fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String { + use rsa::pkcs1v15::SigningKey; + use rsa::signature::SignatureEncoding; + + let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256); + let sig = SigningKey::::new(key).sign(payload.as_bytes()); + let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + + format!("{payload}.{sig}") + } + + #[tokio::test] + async fn renew() { + let (rs1, jwk1) = new_rsa_jwk("1".into()); + let (rs2, jwk2) = new_rsa_jwk("2".into()); + let (ec1, jwk3) = new_ec_jwk("3".into()); + let (ec2, jwk4) = new_ec_jwk("4".into()); + + let jwt1 = new_rsa_jwt("1".into(), rs1); + let jwt2 = new_rsa_jwt("2".into(), rs2); + let jwt3 = new_ec_jwt("3".into(), ec1); + let jwt4 = new_ec_jwt("4".into(), ec2); + + let foo_jwks = jose_jwk::JwkSet { + keys: vec![jwk1, jwk3], + }; + let bar_jwks = jose_jwk::JwkSet { + keys: vec![jwk2, jwk4], + }; + + let service = service_fn(move |req| { + let foo_jwks = foo_jwks.clone(); + let bar_jwks = bar_jwks.clone(); + async move { + let jwks = match req.uri().path() { + "/foo" => &foo_jwks, + "/bar" => &bar_jwks, + _ => { + return Response::builder() + .status(404) + .body(Full::new(Bytes::new())); + } + }; + let body = serde_json::to_vec(jwks).unwrap(); + Response::builder() + .status(200) + .body(Full::new(Bytes::from(body))) + } + }); + + let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); + let server = hyper1::server::conn::http1::Builder::new(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + loop { + let (s, _) = listener.accept().await.unwrap(); + let serve = server.serve_connection(TokioIo::new(s), service.clone()); + tokio::spawn(serve.into_future()); + } + }); + + let client = reqwest::Client::new(); + + #[derive(Clone)] + struct Fetch(SocketAddr); + + impl FetchAuthRules for Fetch { + async fn fetch_auth_rules(&self) -> anyhow::Result { + Ok(AuthRules { + jwks_urls: vec![ + format!("http://{}/foo", self.0).parse().unwrap(), + format!("http://{}/bar", self.0).parse().unwrap(), + ], + }) + } + } + + let jwk_cache = Arc::new(JwkCacheEntryLock::default()); + + jwk_cache + .check_jwt(jwt1, &client, &Fetch(addr)) + .await + .unwrap(); + jwk_cache + .check_jwt(jwt2, &client, &Fetch(addr)) + .await + .unwrap(); + jwk_cache + .check_jwt(jwt3, &client, &Fetch(addr)) + .await + .unwrap(); + jwk_cache + .check_jwt(jwt4, &client, &Fetch(addr)) + .await + .unwrap(); + } +} diff --git a/proxy/src/http.rs b/proxy/src/http.rs index dd7164181d..1f1dd8c415 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -6,6 +6,12 @@ pub mod health_server; use std::time::Duration; +use anyhow::bail; +use bytes::Bytes; +use http_body_util::BodyExt; +use hyper1::body::Body; +use serde::de::DeserializeOwned; + pub use reqwest::{Request, Response, StatusCode}; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; @@ -96,6 +102,33 @@ impl Endpoint { } } +pub async fn parse_json_body_with_limit( + mut b: impl Body + Unpin, + limit: usize, +) -> anyhow::Result { + // We could use `b.limited().collect().await.to_bytes()` here + // but this ends up being slightly more efficient as far as I can tell. + + // check the lower bound of the size hint. + // in reqwest, this value is influenced by the Content-Length header. + let lower_bound = match usize::try_from(b.size_hint().lower()) { + Ok(bound) if bound <= limit => bound, + _ => bail!("Content length exceeds limit of {limit} bytes"), + }; + let mut bytes = Vec::with_capacity(lower_bound); + + while let Some(frame) = b.frame().await.transpose()? { + if let Ok(data) = frame.into_data() { + if bytes.len() + data.len() > limit { + bail!("Content length exceeds limit of {limit} bytes") + } + bytes.extend_from_slice(&data); + } + } + + Ok(serde_json::from_slice::(&bytes)?) +} + #[cfg(test)] mod tests { use super::*; diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 832fe06bf6..2d9b372654 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -30,13 +30,17 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd clap = { version = "4", features = ["derive", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] } crossbeam-utils = { version = "0.8" } +crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] } +der = { version = "0.7", default-features = false, features = ["oid", "pem", "std"] } deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] } +digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures-channel = { version = "0.3", features = ["sink"] } futures-executor = { version = "0.3" } futures-io = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } +generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } @@ -44,6 +48,7 @@ hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper = { version = "0.14", features = ["full"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } +lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] } memchr = { version = "2" } @@ -64,8 +69,10 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] } scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] } serde_json = { version = "1", features = ["raw_value"] } -sha2 = { version = "0.10", features = ["asm"] } +sha2 = { version = "0.10", features = ["asm", "oid"] } +signature = { version = "2", default-features = false, features = ["digest", "rand_core", "std"] } smallvec = { version = "1", default-features = false, features = ["const_new", "write"] } +spki = { version = "0.7", default-features = false, features = ["pem", "std"] } subtle = { version = "2" } sync_wrapper = { version = "0.1", default-features = false, features = ["futures"] } tikv-jemalloc-sys = { version = "0.5" } @@ -81,7 +88,7 @@ tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" } url = { version = "2", features = ["serde"] } uuid = { version = "1", features = ["serde", "v4", "v7"] } -zeroize = { version = "1", features = ["derive"] } +zeroize = { version = "1", features = ["derive", "serde"] } zstd = { version = "0.13" } zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] } zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] } @@ -97,6 +104,7 @@ getrandom = { version = "0.2", default-features = false, features = ["std"] } hashbrown = { version = "0.14", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } +lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] } memchr = { version = "2" }