diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 6585d6f539..6896c7e943 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -5,6 +5,7 @@ use arc_swap::ArcSwapOption; use async_trait::async_trait; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; +use signature::Verifier; use tokio::time::Instant; use crate::intern::EndpointIdInt; @@ -281,16 +282,18 @@ fn verify_rsa_signature( alg: &Option, ) -> anyhow::Result<()> { use jose_jwa::{Algorithm, Signing}; - use rsa::{Pkcs1v15Sign, RsaPublicKey}; - use sha2::Digest; + use rsa::{ + pkcs1v15::{Signature, VerifyingKey}, + RsaPublicKey, + }; let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?; match alg { Some(Algorithm::Signing(Signing::Rs256)) => { - let hashed = sha2::Sha256::digest(data); - let scheme = Pkcs1v15Sign::new::(); - key.verify(scheme, &hashed, sig)?; + let key = VerifyingKey::::new(key); + let sig = Signature::try_from(sig)?; + key.verify(data, &sig)?; } _ => bail!("invalid RSA signing algorithm"), }; @@ -299,7 +302,7 @@ fn verify_rsa_signature( } /// -#[derive(serde::Deserialize)] +#[derive(serde::Deserialize, serde::Serialize)] struct JWTHeader<'a> { /// must be "JWT" typ: &'a str, @@ -330,16 +333,18 @@ impl Drop for AttachedPermit<'_> { mod tests { use super::*; - use std::{future::IntoFuture, net::SocketAddr}; + use std::{future::IntoFuture, net::SocketAddr, time::SystemTime}; use anyhow::Error; use async_trait::async_trait; + use base64::URL_SAFE_NO_PAD; use bytes::Bytes; use http::Response; use http_body_util::Full; use hyper1::service::service_fn; use hyper_util::rt::TokioIo; use rand::rngs::OsRng; + use signature::Signer; use tokio::net::TcpListener; fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) { @@ -349,55 +354,77 @@ mod tests { key: jose_jwk::Key::Ec(pk), prm: jose_jwk::Parameters { kid: Some(kid), + alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)), ..Default::default() }, }; (sk, jwk) } + fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) { - let sk = rsa::RsaPrivateKey::new(&mut OsRng, 1024).unwrap(); + let sk = rsa::RsaPrivateKey::new(&mut OsRng, 2048).unwrap(); let pk = sk.to_public_key().into(); let jwk = jose_jwk::Jwk { key: jose_jwk::Key::Rsa(pk), prm: jose_jwk::Parameters { kid: Some(kid), + alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)), ..Default::default() }, }; (sk, jwk) } + fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String { + let header = JWTHeader { + typ: "JWT", + alg: jose_jwa::Algorithm::Signing(sig), + kid: Some(&kid), + }; + let body = typed_json::json! {{ + "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600, + }}; + + let header = + base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD); + let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD); + + format!("{header}.{body}") + } + + fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String { + use p256::ecdsa::{Signature, SigningKey}; + + let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256); + let sig: Signature = SigningKey::from(key).sign(payload.as_bytes()); + let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + + format!("{payload}.{sig}") + } + + fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String { + use rsa::pkcs1v15::SigningKey; + use rsa::signature::SignatureEncoding; + + let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256); + let sig = SigningKey::::new(key).sign(payload.as_bytes()); + let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + + format!("{payload}.{sig}") + } + #[tokio::test] async fn renew() { - let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); - - // let reports = Arc::new(Mutex::new(vec![])); - // let reports2 = reports.clone(); - - let server = hyper1::server::conn::http1::Builder::new(); - // let server = hyper1::server::Server::from_tcp(listener) - // .unwrap() - // .serve(make_service_fn(move |_| { - // // let reports = reports.clone(); - // async move { - // Ok::<_, Error>(service_fn(move |req| { - // // let reports = reports.clone(); - // async move { - // // let bytes = hyper::body::to_bytes(req.into_body()).await?; - // // let events: EventChunk<'static, Event> = - // // serde_json::from_slice(&bytes)?; - // // reports.lock().unwrap().push(events); - // Ok::<_, Error>(Response::new(Body::from(vec![]))) - // } - // })) - // } - // })); - let (rs1, jwk1) = new_rsa_jwk("1".into()); let (rs2, jwk2) = new_rsa_jwk("2".into()); let (ec1, jwk3) = new_ec_jwk("3".into()); let (ec2, jwk4) = new_ec_jwk("4".into()); + let jwt1 = new_rsa_jwt("1".into(), rs1); + let jwt2 = new_rsa_jwt("2".into(), rs2); + let jwt3 = new_ec_jwt("3".into(), ec1); + let jwt4 = new_ec_jwt("4".into(), ec2); + let foo = jose_jwk::JwkSet { keys: vec![jwk1, jwk3], }; @@ -405,6 +432,8 @@ mod tests { keys: vec![jwk2, jwk4], }; + let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); + let server = hyper1::server::conn::http1::Builder::new(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { loop { @@ -461,9 +490,21 @@ mod tests { } let jwk_cache = Arc::new(JWKCacheEntryLock::default()); - let permit = jwk_cache.acquire_permit().await; - let entry = jwk_cache - .renew_jwks(permit, &client, &Fetch(addr)) + + jwk_cache + .check_jwt(jwt1, &client, &Fetch(addr)) + .await + .unwrap(); + jwk_cache + .check_jwt(jwt2, &client, &Fetch(addr)) + .await + .unwrap(); + jwk_cache + .check_jwt(jwt3, &client, &Fetch(addr)) + .await + .unwrap(); + jwk_cache + .check_jwt(jwt4, &client, &Fetch(addr)) .await .unwrap(); }