From f3f7d0d3f19c181f43ce5e846af3b29c6735a9fe Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 30 Sep 2024 12:47:07 +0100 Subject: [PATCH] zero-copy jwt claim validation --- proxy/src/auth/backend/jwt.rs | 381 +++++++++++++++++++++++++++++----- 1 file changed, 324 insertions(+), 57 deletions(-) diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 3a9fcc8992..8f8f949cb2 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, future::Future, sync::Arc, time::{Duration, SystemTime}, @@ -8,7 +9,10 @@ use anyhow::{bail, ensure, Context}; use arc_swap::ArcSwapOption; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; -use serde::{Deserialize, Deserializer}; +use serde::{ + de::{DeserializeSeed, IgnoredAny, Visitor}, + Deserializer, +}; use signature::Verifier; use tokio::time::Instant; @@ -304,35 +308,21 @@ impl JwkCacheEntryLock { } key => bail!("unsupported key type {key:?}"), }; + tracing::debug!("JWT signature valid"); let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD) .context("Provided authentication token is not a valid JWT encoding")?; - let payload = serde_json::from_slice::>(&payload) - .context("Provided authentication token is not a valid JWT encoding")?; - tracing::debug!(?payload, "JWT signature valid with claims"); + let validator = JwtValidator { + expected_audience, + current_time: SystemTime::now(), + clock_skew_leeway: CLOCK_SKEW_LEEWAY, + }; - match (expected_audience, payload.audience) { - // check the audience matches - (Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"), - // the audience is expected but is missing - (Some(_), None) => bail!("invalid JWT token audience"), - // we don't care for the audience field - (None, _) => {} - } + let payload = validator + .deserialize(&mut serde_json::Deserializer::from_slice(&payload))?; - let now = SystemTime::now(); - - if let Some(exp) = payload.expiration { - ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired"); - } - - if let Some(nbf) = payload.not_before { - ensure!( - nbf < now + CLOCK_SKEW_LEEWAY, - "JWT token is not yet ready to use" - ); - } + tracing::debug!(?payload, "JWT claims valid"); Ok(()) } @@ -420,37 +410,184 @@ struct JwtHeader<'a> { key_id: Option<&'a str>, } -/// -#[derive(serde::Deserialize, serde::Serialize, Debug)] -struct JwtPayload<'a> { - /// Audience - Recipient for which the JWT is intended - #[serde(rename = "aud")] - audience: Option<&'a str>, - /// Expiration - Time after which the JWT expires - #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)] - expiration: Option, - /// Not before - Time after which the JWT expires - #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)] - not_before: Option, - - // the following entries are only extracted for the sake of debug logging. - /// Issuer of the JWT - #[serde(rename = "iss")] - issuer: Option<&'a str>, - /// Subject of the JWT (the user) - #[serde(rename = "sub")] - subject: Option<&'a str>, - /// Unique token identifier - #[serde(rename = "jti")] - jwt_id: Option<&'a str>, - /// Unique session identifier - #[serde(rename = "sid")] - session_id: Option<&'a str>, +struct JwtValidator<'a> { + expected_audience: Option<&'a str>, + current_time: SystemTime, + clock_skew_leeway: Duration, } -fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { - let d = >::deserialize(d)?; - Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n))) +impl<'de> DeserializeSeed<'de> for JwtValidator<'_> { + type Value = JwtPayload<'de>; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + impl<'de> Visitor<'de> for JwtValidator<'_> { + type Value = JwtPayload<'de>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a JWT payload") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut payload = JwtPayload { + issuer: None, + subject: None, + jwt_id: None, + session_id: None, + }; + + let mut aud = false; + + while let Some(key) = map.next_key()? { + match key { + "iss" if payload.issuer.is_none() => { + payload.issuer = Some(map.next_value()?); + } + "sub" if payload.subject.is_none() => { + payload.subject = Some(map.next_value()?); + } + "jit" if payload.jwt_id.is_none() => { + payload.jwt_id = Some(map.next_value()?); + } + "sid" if payload.session_id.is_none() => { + payload.session_id = Some(map.next_value()?); + } + "exp" => { + let exp = map.next_value::()?; + let exp = SystemTime::UNIX_EPOCH + Duration::from_secs(exp); + + if self.current_time > exp + self.clock_skew_leeway { + return Err(serde::de::Error::custom("JWT token has expired")); + } + } + "nbf" => { + let nbf = map.next_value::()?; + let nbf = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf); + + if self.current_time + self.clock_skew_leeway < nbf { + return Err(serde::de::Error::custom( + "JWT token is not yet ready to use", + )); + } + } + "aud" => { + if let Some(expected_audience) = self.expected_audience { + map.next_value_seed(AudienceValidator { expected_audience })?; + aud = true; + } else { + map.next_value::()?; + } + } + _ => map.next_value::().map(|IgnoredAny| ())?, + } + } + + if self.expected_audience.is_some() && !aud { + return Err(serde::de::Error::custom("invalid JWT token audience")); + } + + Ok(payload) + } + } + + deserializer.deserialize_map(self) + } +} + +struct AudienceValidator<'a> { + expected_audience: &'a str, +} + +impl<'de> DeserializeSeed<'de> for AudienceValidator<'_> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + impl<'de> Visitor<'de> for AudienceValidator<'_> { + type Value = (); + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a single string or an array of strings") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + if self.expected_audience == v { + Ok(()) + } else { + Err(E::custom("invalid JWT token audience")) + } + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + while let Some(v) = seq.next_element_seed(SingleAudienceValidator { + expected_audience: self.expected_audience, + })? { + if v { + return Ok(()); + } + } + Err(serde::de::Error::custom("invalid JWT token audience")) + } + } + deserializer.deserialize_any(self) + } +} + +struct SingleAudienceValidator<'a> { + expected_audience: &'a str, +} + +impl<'de> DeserializeSeed<'de> for SingleAudienceValidator<'_> { + type Value = bool; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + impl<'de> Visitor<'de> for SingleAudienceValidator<'_> { + type Value = bool; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a single audience string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(self.expected_audience == v) + } + } + deserializer.deserialize_any(self) + } +} + +/// +// the following entries are only extracted for the sake of debug logging. +#[derive(Debug)] +#[allow(dead_code)] +struct JwtPayload<'a> { + /// Issuer of the JWT + issuer: Option>, + /// Subject of the JWT (the user) + subject: Option>, + /// Unique token identifier + jwt_id: Option>, + /// Unique session identifier + session_id: Option>, } struct JwkRenewalPermit<'a> { @@ -531,6 +668,8 @@ mod tests { use hyper_util::rt::TokioIo; use rand::rngs::OsRng; use rsa::pkcs8::DecodePrivateKey; + use serde::Serialize; + use serde_json::json; use signature::Signer; use tokio::net::TcpListener; @@ -563,18 +702,36 @@ mod tests { } fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + let body = typed_json::json! {{ + "exp": now + 3600, + "nbf": now, + "aud": ["audience1", "neon", "audience2"], + "sub": "user1", + "sid": "session1", + "jti": "token1", + "iss": "neon-testing", + }}; + build_custom_jwt_payload(kid, body, sig) + } + + fn build_custom_jwt_payload( + kid: String, + body: impl Serialize, + sig: jose_jwa::Signing, + ) -> String { let header = JwtHeader { typ: "JWT", algorithm: jose_jwa::Algorithm::Signing(sig), key_id: Some(&kid), }; - let body = typed_json::json! {{ - "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600, - }}; let header = base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD); - let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD); + let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD); format!("{header}.{body}") } @@ -589,6 +746,16 @@ mod tests { format!("{payload}.{sig}") } + fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String { + use p256::ecdsa::{Signature, SigningKey}; + + let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256); + let sig: Signature = SigningKey::from(key).sign(payload.as_bytes()); + let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + + format!("{payload}.{sig}") + } + fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String { use rsa::pkcs1v15::SigningKey; use rsa::signature::SignatureEncoding; @@ -855,4 +1022,104 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL "expected \"jwk not found\", got {err:?}" ); } + + #[tokio::test] + async fn check_jwt_invalid_claims() { + let (key, jwk) = new_ec_jwk("1".into()); + + let jwks = jose_jwk::JwkSet { keys: vec![jwk] }; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + + struct Test { + body: serde_json::Value, + error: &'static str, + } + + let table = vec![ + Test { + body: json! {{ + "nbf": now + 60, + "aud": "neon", + }}, + error: "JWT token is not yet ready to use", + }, + Test { + body: json! {{ + "exp": now - 60, + "aud": ["neon"], + }}, + error: "JWT token has expired", + }, + Test { + body: json! {{ + }}, + error: "invalid JWT token audience", + }, + Test { + body: json! {{ + "aud": [], + }}, + error: "invalid JWT token audience", + }, + Test { + body: json! {{ + "aud": "foo", + }}, + error: "invalid JWT token audience", + }, + Test { + body: json! {{ + "aud": ["foo"], + }}, + error: "invalid JWT token audience", + }, + Test { + body: json! {{ + "aud": ["foo", "bar"], + }}, + error: "invalid JWT token audience", + }, + ]; + + let role = RoleName::from("authenticated"); + + let rules = vec![AuthRule { + id: String::new(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: Some("neon".to_string()), + role_names: vec![RoleNameInt::from(&role)], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let ep = EndpointId::from("ep"); + + let ctx = RequestMonitoring::test(); + for test in table { + let jwt = new_custom_ec_jwt("1".into(), &key, test.body); + + match jwk_cache + .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt) + .await + { + Err(err) if err.to_string().contains(test.error) => {} + Err(err) => { + panic!("expected {:?}, got {err:?}", test.error) + } + Ok(()) => { + panic!("expected {:?}, got ok", test.error) + } + } + } + } }