diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index ab848551a9..3a9fcc8992 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -33,6 +33,7 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static { ) -> impl Future>> + Send; } +#[derive(Debug, Clone)] pub(crate) struct AuthRule { pub(crate) id: String, pub(crate) jwks_url: url::Url, @@ -659,37 +660,34 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL -----END PRIVATE KEY----- "; - #[tokio::test] - async fn renew() { - let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into()); - let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into()); - let (ec1, jwk3) = new_ec_jwk("3".into()); - let (ec2, jwk4) = new_ec_jwk("4".into()); + #[derive(Clone)] + struct Fetch(Vec); - let foo_jwks = jose_jwk::JwkSet { - keys: vec![jwk1, jwk3], - }; - let bar_jwks = jose_jwk::JwkSet { - keys: vec![jwk2, jwk4], - }; + impl FetchAuthRules for Fetch { + async fn fetch_auth_rules( + &self, + _ctx: &RequestMonitoring, + _endpoint: EndpointId, + ) -> anyhow::Result> { + Ok(self.0.clone()) + } + } + async fn jwks_server( + router: impl for<'a> Fn(&'a str) -> Option> + Send + Sync + 'static, + ) -> SocketAddr { + let router = Arc::new(router); let service = service_fn(move |req| { - let foo_jwks = foo_jwks.clone(); - let bar_jwks = bar_jwks.clone(); + let router = Arc::clone(&router); async move { - let jwks = match req.uri().path() { - "/foo" => &foo_jwks, - "/bar" => &bar_jwks, - _ => { - return Response::builder() - .status(404) - .body(Full::new(Bytes::new())); - } - }; - let body = serde_json::to_vec(jwks).unwrap(); - Response::builder() - .status(200) - .body(Full::new(Bytes::from(body))) + match router(req.uri().path()) { + Some(body) => Response::builder() + .status(200) + .body(Full::new(Bytes::from(body))), + None => Response::builder() + .status(404) + .body(Full::new(Bytes::new())), + } } }); @@ -704,84 +702,61 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL } }); - let client = reqwest::Client::new(); + addr + } - #[derive(Clone)] - struct Fetch(SocketAddr, Vec); + #[tokio::test] + async fn check_jwt_happy_path() { + let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into()); + let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into()); + let (ec1, jwk3) = new_ec_jwk("ec1".into()); + let (ec2, jwk4) = new_ec_jwk("ec2".into()); - impl FetchAuthRules for Fetch { - async fn fetch_auth_rules( - &self, - _ctx: &RequestMonitoring, - _endpoint: EndpointId, - ) -> anyhow::Result> { - Ok(vec![ - AuthRule { - id: "foo".to_owned(), - jwks_url: format!("http://{}/foo", self.0).parse().unwrap(), - audience: None, - role_names: self.1.clone(), - }, - AuthRule { - id: "bar".to_owned(), - jwks_url: format!("http://{}/bar", self.0).parse().unwrap(), - audience: None, - role_names: self.1.clone(), - }, - ]) - } - } + let foo_jwks = jose_jwk::JwkSet { + keys: vec![jwk1, jwk3], + }; + let bar_jwks = jose_jwk::JwkSet { + keys: vec![jwk2, jwk4], + }; + + let jwks_addr = jwks_server(move |path| match path { + "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()), + "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()), + _ => None, + }) + .await; let role_name1 = RoleName::from("anonymous"); let role_name2 = RoleName::from("authenticated"); - let fetch = Fetch( - addr, - vec![ - RoleNameInt::from(&role_name1), - RoleNameInt::from(&role_name2), - ], - ); + let roles = vec![ + RoleNameInt::from(&role_name1), + RoleNameInt::from(&role_name2), + ]; + let rules = vec![ + AuthRule { + id: "foo".to_owned(), + jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(), + audience: None, + role_names: roles.clone(), + }, + AuthRule { + id: "bar".to_owned(), + jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(), + audience: None, + role_names: roles.clone(), + }, + ]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); let endpoint = EndpointId::from("ep"); - let jwk_cache = Arc::new(JwkCacheEntryLock::default()); - - let jwt1 = new_rsa_jwt("1".into(), rs1); - let jwt2 = new_rsa_jwt("2".into(), rs2); - let jwt3 = new_ec_jwt("3".into(), &ec1); - let jwt4 = new_ec_jwt("4".into(), &ec2); - - // had the wrong kid, therefore will have the wrong ecdsa signature - let bad_jwt = new_ec_jwt("3".into(), &ec2); - // this role_name is not accepted - let bad_role_name = RoleName::from("cloud_admin"); - - let err = jwk_cache - .check_jwt( - &RequestMonitoring::test(), - &bad_jwt, - &client, - endpoint.clone(), - &role_name1, - &fetch, - ) - .await - .unwrap_err(); - assert!(err.to_string().contains("signature error")); - - let err = jwk_cache - .check_jwt( - &RequestMonitoring::test(), - &jwt1, - &client, - endpoint.clone(), - &bad_role_name, - &fetch, - ) - .await - .unwrap_err(); - assert!(err.to_string().contains("jwk not found")); + let jwt1 = new_rsa_jwt("rs1".into(), rs1); + let jwt2 = new_rsa_jwt("rs2".into(), rs2); + let jwt3 = new_ec_jwt("ec1".into(), &ec1); + let jwt4 = new_ec_jwt("ec2".into(), &ec2); let tokens = [jwt1, jwt2, jwt3, jwt4]; let role_names = [role_name1, role_name2]; @@ -790,15 +765,94 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL jwk_cache .check_jwt( &RequestMonitoring::test(), - token, - &client, endpoint.clone(), role, &fetch, + token, ) .await .unwrap(); } } } + + #[tokio::test] + async fn check_jwt_invalid_signature() { + let (_, jwk) = new_ec_jwk("1".into()); + let (key, _) = new_ec_jwk("1".into()); + + // has a matching kid, but signed by the wrong key + let bad_jwt = new_ec_jwt("1".into(), &key); + + let jwks = jose_jwk::JwkSet { keys: vec![jwk] }; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let role = RoleName::from("authenticated"); + + let rules = vec![AuthRule { + id: String::new(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: None, + role_names: vec![RoleNameInt::from(&role)], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let ep = EndpointId::from("ep"); + + let ctx = RequestMonitoring::test(); + let err = jwk_cache + .check_jwt(&ctx, ep, &role, &fetch, &bad_jwt) + .await + .unwrap_err(); + assert!( + err.to_string().contains("signature error"), + "expected \"signature error\", got {err:?}" + ); + } + + #[tokio::test] + async fn check_jwt_unknown_role() { + let (key, jwk) = new_rsa_jwk(RS1, "1".into()); + let jwt = new_rsa_jwt("1".into(), key); + + let jwks = jose_jwk::JwkSet { keys: vec![jwk] }; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let role = RoleName::from("authenticated"); + let rules = vec![AuthRule { + id: String::new(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: None, + role_names: vec![RoleNameInt::from(&role)], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let ep = EndpointId::from("ep"); + + // this role_name is not accepted + let bad_role_name = RoleName::from("cloud_admin"); + + let ctx = RequestMonitoring::test(); + let err = jwk_cache + .check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt) + .await + .unwrap_err(); + + assert!( + err.to_string().contains("jwk not found"), + "expected \"jwk not found\", got {err:?}" + ); + } } diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index 108420d7d7..b5d8f996ce 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -1,5 +1,6 @@ use std::{ - hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock, + any::type_name, hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, + sync::OnceLock, }; use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo}; @@ -16,12 +17,21 @@ pub struct StringInterner { _id: PhantomData, } -#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)] +#[derive(PartialEq, Clone, Copy, Eq, Hash)] pub struct InternedString { inner: Spur, _id: PhantomData, } +impl std::fmt::Debug for InternedString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("InternedString") + .field(&type_name::()) + .field(&self.as_str()) + .finish() + } +} + impl std::fmt::Display for InternedString { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.as_str().fmt(f)