diff --git a/Cargo.toml b/Cargo.toml index 6d91262882..4c90cc26e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,6 +141,7 @@ nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket" notify = "6.0.0" num_cpus = "1.15" num-traits = "0.2.19" +oid-registry = "0.6.1" once_cell = "1.13" opentelemetry = "0.27" opentelemetry_sdk = "0.27" @@ -170,6 +171,7 @@ rustc-hash = "2.1.1" rustls = { version = "0.23.16", default-features = false } rustls-pemfile = "2" rustls-pki-types = "1.11" +rustls-split = "0.3" scopeguard = "1.1" sysinfo = "0.29.2" sd-notify = "0.4.1" diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 4b326949d7..5e83cc1f7f 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -19,6 +19,7 @@ anyhow.workspace = true bincode.workspace = true bytes.workspace = true camino.workspace = true +camino-tempfile.workspace = true chrono.workspace = true diatomic-waker.workspace = true git-version.workspace = true @@ -28,6 +29,7 @@ fail.workspace = true futures = { workspace = true } jsonwebtoken.workspace = true nix = { workspace = true, features = ["ioctl"] } +oid-registry.workspace = true once_cell.workspace = true pem.workspace = true pin-project-lite.workspace = true @@ -48,9 +50,12 @@ tracing-utils.workspace = true rand.workspace = true scopeguard.workspace = true uuid.workspace = true +rustls-pemfile.workspace = true +rustls-pki-types.workspace = true strum.workspace = true strum_macros.workspace = true walkdir.workspace = true +x509-parser.workspace = true pq_proto.workspace = true postgres_connection.workspace = true @@ -67,6 +72,7 @@ camino-tempfile.workspace = true pprof.workspace = true serde_assert.workspace = true tokio = { workspace = true, features = ["test-util"] } +rcgen = { version = "=0.13.1", features = ["crypto", "aws_lc_rs"] } [[bench]] name = "benchmarks" diff --git a/libs/utils/src/auth.rs b/libs/utils/src/auth.rs index a5b7e7f190..644f79c993 100644 --- a/libs/utils/src/auth.rs +++ b/libs/utils/src/auth.rs @@ -14,11 +14,14 @@ use jsonwebtoken::{ use pem::Pem; use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned}; use uuid::Uuid; +use oid_registry::OID_PKCS1_RSAENCRYPTION; +use rustls_pki_types::CertificateDer; use crate::id::TenantId; -/// Algorithm to use. We require EdDSA. +/// Signature algorithms to use. We allow EdDSA and RSA/SHA-256. const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA; +const HADRON_STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::RS256; #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] #[serde(rename_all = "lowercase")] @@ -183,6 +186,94 @@ impl JwtAuth { Ok(Self::new(decoding_keys)) } + // Helper function to parse a X509 certificate file and extract the RSA public keys from it as `DecodingKey`s. + // - `ceritificate_file_path`: the path to the certificate file. It must be a file, not a directory or anything else. + // Returns the successfully extracted decoding keys. Non-RSA keys and non-X509-parsable certificates are skipped. + // Multuple keys may be returned because a single file can contain multiple certificates. + fn extract_rsa_decoding_keys_from_certificate>( + certificate_file_path: P, + ) -> Result> { + let certs: io::Result>> = rustls_pemfile::certs( + &mut io::BufReader::new(fs::File::open(certificate_file_path)?), + ) + .collect(); + + Ok(certs? + .iter() + .filter_map( + |cert| match x509_parser::parse_x509_certificate(cert) { + Ok((_, cert)) => { + let public_key = cert.public_key(); + // Note that we are just extracting the public key from the certificate, not the signature. + // So the algorithm is just the asymmetric crypto such as RSA, no hashes of or anything like + // that. + if *public_key.algorithm.oid() == OID_PKCS1_RSAENCRYPTION { + Some(DecodingKey::from_rsa_der(&public_key.subject_public_key.data)) + } else { + tracing::warn!( + "Unsupported public key algorithm: {:?} found in certificate. Skipping.", + public_key.algorithm + ); + None + } + } + Err(e) => { + tracing::warn!("Error parsing certificate: {}. Skipping.", e); + None + } + }, + ) + .collect()) + } + + /// Create a `JwtAuth` that can decode tokens using RSA public keys in X509 certificates from the given path. + /// - `cert_path`: the path to a directory or a file containing X509 certificates. If it is a directory, all files + /// under the first level of the directory will be inspected for certificates. + /// Returns the `JwtAuth` with the decoding keys extracted from the certificates, or error. + /// Used by Hadron. + pub fn from_cert_path(cert_path: &Utf8Path) -> Result { + tracing::info!( + "Loading public keys in certificates from path: {}", + cert_path + ); + + let mut decoding_keys = Vec::new(); + + let metadata = cert_path.metadata()?; + if metadata.is_dir() { + for entry in fs::read_dir(cert_path)? { + let path = entry?.path(); + if !path.is_file() { + // Ignore directories (don't recurse) + continue; + } + decoding_keys.extend( + Self::extract_rsa_decoding_keys_from_certificate(path).unwrap_or_default(), + ); + } + } else if metadata.is_file() { + decoding_keys.extend( + Self::extract_rsa_decoding_keys_from_certificate(cert_path).unwrap_or_default(), + ); + } else { + anyhow::bail!("{cert_path} is neither a directory or a file") + } + if decoding_keys.is_empty() { + anyhow::bail!("Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected."); + } + + // Note that we need to create a `JwtAuth` with a different `validation` from the default one created by `new()` in this case + // because the `jsonwebtoken` crate requires that all algorithms in `validation.algorithms` belong to the same algorithm family + // (all RSA or all EdDSA). + let mut validation = Validation::default(); + validation.algorithms = vec![HADRON_STORAGE_TOKEN_ALGORITHM]; + validation.required_spec_claims = [].into(); + Ok(Self { + validation, + decoding_keys, + }) + } + pub fn from_key(key: String) -> Result { Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?])) } @@ -246,6 +337,7 @@ pub fn encode_hadron_token_with_encoding_key( #[cfg(test)] mod tests { + use io::Write; use std::str::FromStr; use super::*;