From 47c35f67c392a9642a4f0ccaeb326a53913449e4 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 29 Oct 2024 11:01:09 +0000 Subject: [PATCH] [proxy]: fix JWT handling for AWS cognito. (#9536) In the base64 payload of an aws cognito jwt, I saw the following: ``` "iss":"https:\/\/cognito-idp.us-west-2.amazonaws.com\/us-west-2_redacted" ``` issuers are supposed to be URLs, and URLs are always valid un-escaped JSON. However, `\/` is a valid escape character so what AWS is doing is technically correct... sigh... This PR refactors the test suite and adds a new regression test for cognito. --- proxy/src/auth/backend/jwt.rs | 508 +++++++++++++++++++++++++--------- 1 file changed, 383 insertions(+), 125 deletions(-) diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 69ab4b8ccb..83c3617612 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::future::Future; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -45,6 +46,7 @@ pub(crate) enum FetchAuthRulesError { RoleJwksNotConfigured, } +#[derive(Clone)] pub(crate) struct AuthRule { pub(crate) id: String, pub(crate) jwks_url: url::Url, @@ -277,7 +279,7 @@ impl JwkCacheEntryLock { // get the key from the JWKs if possible. If not, wait for the keys to update. let (jwk, expected_audience) = loop { - match guard.find_jwk_and_audience(kid, role_name) { + match guard.find_jwk_and_audience(&kid, role_name) { Some(jwk) => break jwk, None if guard.last_retrieved.elapsed() > MIN_RENEW => { let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); @@ -312,7 +314,9 @@ impl JwkCacheEntryLock { if let Some(aud) = expected_audience { if payload.audience.0.iter().all(|s| s != aud) { - return Err(JwtError::InvalidJwtTokenAudience); + return Err(JwtError::InvalidClaims( + JwtClaimsError::InvalidJwtTokenAudience, + )); } } @@ -320,13 +324,15 @@ impl JwkCacheEntryLock { if let Some(exp) = payload.expiration { if now >= exp + CLOCK_SKEW_LEEWAY { - return Err(JwtError::JwtTokenHasExpired); + return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired)); } } if let Some(nbf) = payload.not_before { if nbf >= now + CLOCK_SKEW_LEEWAY { - return Err(JwtError::JwtTokenNotYetReadyToUse); + return Err(JwtError::InvalidClaims( + JwtClaimsError::JwtTokenNotYetReadyToUse, + )); } } @@ -420,8 +426,8 @@ struct JwtHeader<'a> { #[serde(rename = "alg")] algorithm: jose_jwa::Algorithm, /// key id, must be provided for our usecase - #[serde(rename = "kid")] - key_id: Option<&'a str>, + #[serde(rename = "kid", borrow)] + key_id: Option>, } /// @@ -440,17 +446,17 @@ struct JwtPayload<'a> { // the following entries are only extracted for the sake of debug logging. /// Issuer of the JWT - #[serde(rename = "iss")] - issuer: Option<&'a str>, + #[serde(rename = "iss", borrow)] + issuer: Option>, /// Subject of the JWT (the user) - #[serde(rename = "sub")] - subject: Option<&'a str>, + #[serde(rename = "sub", borrow)] + subject: Option>, /// Unique token identifier - #[serde(rename = "jti")] - jwt_id: Option<&'a str>, + #[serde(rename = "jti", borrow)] + jwt_id: Option>, /// Unique session identifier - #[serde(rename = "sid")] - session_id: Option<&'a str>, + #[serde(rename = "sid", borrow)] + session_id: Option>, } /// `OneOrMany` supports parsing either a single item or an array of items. @@ -585,14 +591,8 @@ pub(crate) enum JwtError { #[error("Provided authentication token is not a valid JWT encoding")] JwtEncoding(#[from] JwtEncodingError), - #[error("invalid JWT token audience")] - InvalidJwtTokenAudience, - - #[error("JWT token has expired")] - JwtTokenHasExpired, - - #[error("JWT token is not yet ready to use")] - JwtTokenNotYetReadyToUse, + #[error(transparent)] + InvalidClaims(#[from] JwtClaimsError), #[error("invalid P256 key")] InvalidP256Key(jose_jwk::crypto::Error), @@ -644,6 +644,19 @@ pub enum JwtEncodingError { InvalidCompactForm, } +#[derive(Error, Debug, PartialEq)] +#[non_exhaustive] +pub enum JwtClaimsError { + #[error("invalid JWT token audience")] + InvalidJwtTokenAudience, + + #[error("JWT token has expired")] + JwtTokenHasExpired, + + #[error("JWT token is not yet ready to use")] + JwtTokenNotYetReadyToUse, +} + #[allow(dead_code, reason = "Debug use only")] #[derive(Debug)] pub(crate) enum KeyType { @@ -680,6 +693,8 @@ mod tests { use hyper_util::rt::TokioIo; use rand::rngs::OsRng; use rsa::pkcs8::DecodePrivateKey; + use serde::Serialize; + use serde_json::json; use signature::Signer; use tokio::net::TcpListener; @@ -693,6 +708,7 @@ mod tests { key: jose_jwk::Key::Ec(pk), prm: jose_jwk::Parameters { kid: Some(kid), + alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)), ..Default::default() }, }; @@ -706,24 +722,47 @@ mod tests { key: jose_jwk::Key::Rsa(pk), prm: jose_jwk::Parameters { kid: Some(kid), + alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)), ..Default::default() }, }; (sk, jwk) } + fn now() -> u64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + } + fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String { + let now = now(); + let body = typed_json::json! {{ + "exp": now + 3600, + "nbf": now, + "aud": ["audience1", "neon", "audience2"], + "sub": "user1", + "sid": "session1", + "jti": "token1", + "iss": "neon-testing", + }}; + build_custom_jwt_payload(kid, body, sig) + } + + fn build_custom_jwt_payload( + kid: String, + body: impl Serialize, + sig: jose_jwa::Signing, + ) -> String { let header = JwtHeader { algorithm: jose_jwa::Algorithm::Signing(sig), - key_id: Some(&kid), + key_id: Some(Cow::Owned(kid)), }; - let body = typed_json::json! {{ - "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600, - }}; let header = base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD); - let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD); + let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD); format!("{header}.{body}") } @@ -738,6 +777,16 @@ mod tests { format!("{payload}.{sig}") } + fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String { + use p256::ecdsa::{Signature, SigningKey}; + + let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256); + let sig: Signature = SigningKey::from(key).sign(payload.as_bytes()); + let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + + format!("{payload}.{sig}") + } + fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String { use rsa::pkcs1v15::SigningKey; use rsa::signature::SignatureEncoding; @@ -809,37 +858,34 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL -----END PRIVATE KEY----- "; - #[tokio::test] - async fn renew() { - let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into()); - let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into()); - let (ec1, jwk3) = new_ec_jwk("3".into()); - let (ec2, jwk4) = new_ec_jwk("4".into()); + #[derive(Clone)] + struct Fetch(Vec); - let foo_jwks = jose_jwk::JwkSet { - keys: vec![jwk1, jwk3], - }; - let bar_jwks = jose_jwk::JwkSet { - keys: vec![jwk2, jwk4], - }; + impl FetchAuthRules for Fetch { + async fn fetch_auth_rules( + &self, + _ctx: &RequestMonitoring, + _endpoint: EndpointId, + ) -> Result, FetchAuthRulesError> { + Ok(self.0.clone()) + } + } + async fn jwks_server( + router: impl for<'a> Fn(&'a str) -> Option> + Send + Sync + 'static, + ) -> SocketAddr { + let router = Arc::new(router); let service = service_fn(move |req| { - let foo_jwks = foo_jwks.clone(); - let bar_jwks = bar_jwks.clone(); + let router = Arc::clone(&router); async move { - let jwks = match req.uri().path() { - "/foo" => &foo_jwks, - "/bar" => &bar_jwks, - _ => { - return Response::builder() - .status(404) - .body(Full::new(Bytes::new())); - } - }; - let body = serde_json::to_vec(jwks).unwrap(); - Response::builder() - .status(200) - .body(Full::new(Bytes::from(body))) + match router(req.uri().path()) { + Some(body) => Response::builder() + .status(200) + .body(Full::new(Bytes::from(body))), + None => Response::builder() + .status(404) + .body(Full::new(Bytes::new())), + } } }); @@ -854,84 +900,61 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL } }); - let client = reqwest::Client::new(); + addr + } - #[derive(Clone)] - struct Fetch(SocketAddr, Vec); + #[tokio::test] + async fn check_jwt_happy_path() { + let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into()); + let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into()); + let (ec1, jwk3) = new_ec_jwk("ec1".into()); + let (ec2, jwk4) = new_ec_jwk("ec2".into()); - impl FetchAuthRules for Fetch { - async fn fetch_auth_rules( - &self, - _ctx: &RequestMonitoring, - _endpoint: EndpointId, - ) -> Result, FetchAuthRulesError> { - Ok(vec![ - AuthRule { - id: "foo".to_owned(), - jwks_url: format!("http://{}/foo", self.0).parse().unwrap(), - audience: None, - role_names: self.1.clone(), - }, - AuthRule { - id: "bar".to_owned(), - jwks_url: format!("http://{}/bar", self.0).parse().unwrap(), - audience: None, - role_names: self.1.clone(), - }, - ]) - } - } + let foo_jwks = jose_jwk::JwkSet { + keys: vec![jwk1, jwk3], + }; + let bar_jwks = jose_jwk::JwkSet { + keys: vec![jwk2, jwk4], + }; + + let jwks_addr = jwks_server(move |path| match path { + "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()), + "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()), + _ => None, + }) + .await; let role_name1 = RoleName::from("anonymous"); let role_name2 = RoleName::from("authenticated"); - let fetch = Fetch( - addr, - vec![ - RoleNameInt::from(&role_name1), - RoleNameInt::from(&role_name2), - ], - ); + let roles = vec![ + RoleNameInt::from(&role_name1), + RoleNameInt::from(&role_name2), + ]; + let rules = vec![ + AuthRule { + id: "foo".to_owned(), + jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(), + audience: None, + role_names: roles.clone(), + }, + AuthRule { + id: "bar".to_owned(), + jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(), + audience: None, + role_names: roles.clone(), + }, + ]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); let endpoint = EndpointId::from("ep"); - let jwk_cache = Arc::new(JwkCacheEntryLock::default()); - - let jwt1 = new_rsa_jwt("1".into(), rs1); - let jwt2 = new_rsa_jwt("2".into(), rs2); - let jwt3 = new_ec_jwt("3".into(), &ec1); - let jwt4 = new_ec_jwt("4".into(), &ec2); - - // had the wrong kid, therefore will have the wrong ecdsa signature - let bad_jwt = new_ec_jwt("3".into(), &ec2); - // this role_name is not accepted - let bad_role_name = RoleName::from("cloud_admin"); - - let err = jwk_cache - .check_jwt( - &RequestMonitoring::test(), - &bad_jwt, - &client, - endpoint.clone(), - &role_name1, - &fetch, - ) - .await - .unwrap_err(); - assert!(err.to_string().contains("signature error")); - - let err = jwk_cache - .check_jwt( - &RequestMonitoring::test(), - &jwt1, - &client, - endpoint.clone(), - &bad_role_name, - &fetch, - ) - .await - .unwrap_err(); - assert!(err.to_string().contains("jwk not found")); + let jwt1 = new_rsa_jwt("rs1".into(), rs1); + let jwt2 = new_rsa_jwt("rs2".into(), rs2); + let jwt3 = new_ec_jwt("ec1".into(), &ec1); + let jwt4 = new_ec_jwt("ec2".into(), &ec2); let tokens = [jwt1, jwt2, jwt3, jwt4]; let role_names = [role_name1, role_name2]; @@ -940,15 +963,250 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL jwk_cache .check_jwt( &RequestMonitoring::test(), - token, - &client, endpoint.clone(), role, &fetch, + token, ) .await .unwrap(); } } } + + /// AWS Cognito escapes the `/` in the URL. + #[tokio::test] + async fn check_jwt_regression_cognito_issuer() { + let (key, jwk) = new_ec_jwk("key".into()); + + let now = now(); + let token = new_custom_ec_jwt( + "key".into(), + &key, + typed_json::json! {{ + "sub": "dd9a73fd-e785-4a13-aae1-e691ce43e89d", + // cognito uses `\/`. I cannot replicated that easily here as serde_json will refuse + // to write that escape character. instead I will make a bogus URL using `\` instead. + "iss": "https:\\\\cognito-idp.us-west-2.amazonaws.com\\us-west-2_abcdefgh", + "client_id": "abcdefghijklmnopqrstuvwxyz", + "origin_jti": "6759d132-3fe7-446e-9e90-2fe7e8017893", + "event_id": "ec9c36ab-b01d-46a0-94e4-87fde6767065", + "token_use": "access", + "scope": "aws.cognito.signin.user.admin", + "auth_time":now, + "exp":now + 60, + "iat":now, + "jti": "b241614b-0b93-4bdc-96db-0a3c7061d9c0", + "username": "dd9a73fd-e785-4a13-aae1-e691ce43e89d", + }}, + ); + + let jwks = jose_jwk::JwkSet { keys: vec![jwk] }; + + let jwks_addr = jwks_server(move |_path| Some(serde_json::to_vec(&jwks).unwrap())).await; + + let role_name = RoleName::from("anonymous"); + let rules = vec![AuthRule { + id: "aws-cognito".to_owned(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: None, + role_names: vec![RoleNameInt::from(&role_name)], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let endpoint = EndpointId::from("ep"); + + jwk_cache + .check_jwt( + &RequestMonitoring::test(), + endpoint.clone(), + &role_name, + &fetch, + &token, + ) + .await + .unwrap(); + } + + #[tokio::test] + async fn check_jwt_invalid_signature() { + let (_, jwk) = new_ec_jwk("1".into()); + let (key, _) = new_ec_jwk("1".into()); + + // has a matching kid, but signed by the wrong key + let bad_jwt = new_ec_jwt("1".into(), &key); + + let jwks = jose_jwk::JwkSet { keys: vec![jwk] }; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let role = RoleName::from("authenticated"); + + let rules = vec![AuthRule { + id: String::new(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: None, + role_names: vec![RoleNameInt::from(&role)], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let ep = EndpointId::from("ep"); + + let ctx = RequestMonitoring::test(); + let err = jwk_cache + .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt) + .await + .unwrap_err(); + assert!( + matches!(err, JwtError::Signature(_)), + "expected \"signature error\", got {err:?}" + ); + } + + #[tokio::test] + async fn check_jwt_unknown_role() { + let (key, jwk) = new_rsa_jwk(RS1, "1".into()); + let jwt = new_rsa_jwt("1".into(), key); + + let jwks = jose_jwk::JwkSet { keys: vec![jwk] }; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let role = RoleName::from("authenticated"); + let rules = vec![AuthRule { + id: String::new(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: None, + role_names: vec![RoleNameInt::from(&role)], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let ep = EndpointId::from("ep"); + + // this role_name is not accepted + let bad_role_name = RoleName::from("cloud_admin"); + + let ctx = RequestMonitoring::test(); + let err = jwk_cache + .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt) + .await + .unwrap_err(); + + assert!( + matches!(err, JwtError::JwkNotFound), + "expected \"jwk not found\", got {err:?}" + ); + } + + #[tokio::test] + async fn check_jwt_invalid_claims() { + let (key, jwk) = new_ec_jwk("1".into()); + + let jwks = jose_jwk::JwkSet { keys: vec![jwk] }; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + + struct Test { + body: serde_json::Value, + error: JwtClaimsError, + } + + let table = vec![ + Test { + body: json! {{ + "nbf": now + 60, + "aud": "neon", + }}, + error: JwtClaimsError::JwtTokenNotYetReadyToUse, + }, + Test { + body: json! {{ + "exp": now - 60, + "aud": ["neon"], + }}, + error: JwtClaimsError::JwtTokenHasExpired, + }, + Test { + body: json! {{ + }}, + error: JwtClaimsError::InvalidJwtTokenAudience, + }, + Test { + body: json! {{ + "aud": [], + }}, + error: JwtClaimsError::InvalidJwtTokenAudience, + }, + Test { + body: json! {{ + "aud": "foo", + }}, + error: JwtClaimsError::InvalidJwtTokenAudience, + }, + Test { + body: json! {{ + "aud": ["foo"], + }}, + error: JwtClaimsError::InvalidJwtTokenAudience, + }, + Test { + body: json! {{ + "aud": ["foo", "bar"], + }}, + error: JwtClaimsError::InvalidJwtTokenAudience, + }, + ]; + + let role = RoleName::from("authenticated"); + + let rules = vec![AuthRule { + id: String::new(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: Some("neon".to_string()), + role_names: vec![RoleNameInt::from(&role)], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let ep = EndpointId::from("ep"); + + let ctx = RequestMonitoring::test(); + for test in table { + let jwt = new_custom_ec_jwt("1".into(), &key, test.body); + + match jwk_cache + .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt) + .await + { + Err(JwtError::InvalidClaims(error)) if error == test.error => {} + Err(err) => { + panic!("expected {:?}, got {err:?}", test.error) + } + Ok(_payload) => { + panic!("expected {:?}, got ok", test.error) + } + } + } + } }