diff --git a/Cargo.lock b/Cargo.lock index 46668e6e6d..0f4b51aece 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1570,6 +1570,7 @@ dependencies = [ "postgres_connection", "regex", "reqwest", + "rsa", "safekeeper_api", "safekeeper_client", "scopeguard", @@ -1589,6 +1590,7 @@ dependencies = [ "utils", "whoami", "workspace_hack", + "x509-parser", ] [[package]] @@ -2207,6 +2209,7 @@ dependencies = [ "http-body-util", "itertools 0.10.5", "jsonwebtoken", + "postgres_backend", "prometheus", "rand 0.8.5", "remote_storage", diff --git a/compute_tools/src/http/middleware/authorize.rs b/compute_tools/src/http/middleware/authorize.rs index a82f46e062..1b0bf4d9c5 100644 --- a/compute_tools/src/http/middleware/authorize.rs +++ b/compute_tools/src/http/middleware/authorize.rs @@ -23,6 +23,18 @@ pub(in crate::http) struct Authorize { impl Authorize { pub fn new(compute_id: String, jwks: JwkSet) -> Self { let mut validation = Validation::new(Algorithm::EdDSA); + + // BEGIN HADRON + let use_rsa = jwks.keys.iter().any(|jwk| { + jwk.common + .key_algorithm + .is_some_and(|alg| alg == jsonwebtoken::jwk::KeyAlgorithm::RS256) + }); + if use_rsa { + validation = Validation::new(Algorithm::RS256); + } + // END HADRON + validation.validate_exp = true; // Unused by the control plane validation.validate_nbf = false; diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index bbaa3f12b9..76334ac8e0 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -46,3 +46,5 @@ endpoint_storage.workspace = true compute_api.workspace = true workspace_hack.workspace = true tracing.workspace = true +x509-parser.workspace = true +rsa = "0.9" diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 4c569d7005..892180a4dc 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -37,18 +37,8 @@ //! //! ``` //! -use std::collections::BTreeMap; -use std::fmt::Display; -use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; -use std::path::PathBuf; -use std::process::Command; -use std::str::FromStr; -use std::sync::Arc; -use std::time::{Duration, Instant}; - use anyhow::{Context, Result, anyhow, bail}; -use base64::Engine; -use base64::prelude::BASE64_URL_SAFE_NO_PAD; +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; use compute_api::requests::{ COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest, }; @@ -62,21 +52,31 @@ use compute_api::spec::{ }; use jsonwebtoken::jwk::{ AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations, - OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, + OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, RSAKeyParameters, RSAKeyType, }; use nix::sys::signal::{Signal, kill}; use pem::Pem; use reqwest::header::CONTENT_TYPE; +use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey, traits::PublicKeyParts}; use safekeeper_api::PgMajorVersion; use safekeeper_api::membership::SafekeeperGeneration; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use spki::der::Decode; use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef}; +use std::collections::BTreeMap; +use std::fmt::Display; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; +use std::path::PathBuf; +use std::process::Command; +use std::str::FromStr; +use std::sync::Arc; +use std::time::{Duration, Instant}; use tracing::debug; use url::Host; use utils::id::{NodeId, TenantId, TimelineId}; use utils::shard::ShardStripeSize; +use x509_parser::parse_x509_certificate; use crate::local_env::LocalEnv; use crate::postgresql_conf::PostgresConf; @@ -155,23 +155,76 @@ impl ComputeControlPlane { .unwrap_or(self.base_port) } + // BEGIN HADRON + + /// Extract SubjectPublicKeyInfo from a PEM that can be either a X509 certificate or a public key + fn extract_spki_from_pem(pem: &Pem) -> Result> { + if pem.tag() == "CERTIFICATE" { + // Handle X509 certificate + let (_, cert) = parse_x509_certificate(pem.contents())?; + let public_key = cert.public_key(); + Ok(public_key.subject_public_key.data.to_vec()) + } else { + // Handle public key directly + let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?; + Ok(spki.subject_public_key.raw_bytes().to_vec()) + } + } + + /// Create RSA JWK from certificate PEM + fn create_rsa_jwk_from_cert(pem: &Pem, key_hash: &[u8]) -> Result { + let public_key = Self::extract_spki_from_pem(pem)?; + + // Extract RSA parameters (n, e) from RSA public key DER data + let rsa_key = RsaPublicKey::from_pkcs1_der(&public_key)?; + let n = rsa_key.n().to_bytes_be(); + let e = rsa_key.e().to_bytes_be(); + + Ok(Jwk { + common: CommonParameters { + public_key_use: Some(PublicKeyUse::Signature), + key_operations: Some(vec![KeyOperations::Verify]), + key_algorithm: Some(KeyAlgorithm::RS256), + key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)), + x509_url: None::, + x509_chain: None::>, + x509_sha1_fingerprint: None::, + x509_sha256_fingerprint: None::, + }, + algorithm: AlgorithmParameters::RSA(RSAKeyParameters { + key_type: RSAKeyType::RSA, + n: URL_SAFE_NO_PAD.encode(n), + e: URL_SAFE_NO_PAD.encode(e), + }), + }) + } + + // END HADRON + /// Create a JSON Web Key Set. This ideally matches the way we create a JWKS /// from the production control plane. fn create_jwks_from_pem(pem: &Pem) -> Result { - let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?; - let public_key = spki.subject_public_key.raw_bytes(); + let public_key = Self::extract_spki_from_pem(pem)?; let mut hasher = Sha256::new(); - hasher.update(public_key); + hasher.update(&public_key); let key_hash = hasher.finalize(); + // BEGIN HADRON + if pem.tag() == "CERTIFICATE" { + // Assume RSA if we are parsing keys from a certificate. + let jwk = Self::create_rsa_jwk_from_cert(pem, &key_hash)?; + return Ok(JwkSet { keys: vec![jwk] }); + } + // END HADRON + Ok(JwkSet { keys: vec![Jwk { common: CommonParameters { public_key_use: Some(PublicKeyUse::Signature), key_operations: Some(vec![KeyOperations::Verify]), key_algorithm: Some(KeyAlgorithm::EdDSA), - key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)), + key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)), x509_url: None::, x509_chain: None::>, x509_sha1_fingerprint: None::, @@ -180,7 +233,7 @@ impl ComputeControlPlane { algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters { key_type: OctetKeyPairType::OctetKeyPair, curve: EllipticCurve::Ed25519, - x: BASE64_URL_SAFE_NO_PAD.encode(public_key), + x: URL_SAFE_NO_PAD.encode(public_key), }), }], }) diff --git a/control_plane/src/endpoint_storage.rs b/control_plane/src/endpoint_storage.rs index 171aaeddb4..70e8c59778 100644 --- a/control_plane/src/endpoint_storage.rs +++ b/control_plane/src/endpoint_storage.rs @@ -2,6 +2,7 @@ use crate::background_process::{self, start_process, stop_process}; use crate::local_env::LocalEnv; use anyhow::{Context, Result}; use camino::Utf8PathBuf; +use postgres_backend::AuthType; use std::io::Write; use std::net::SocketAddr; use std::time::Duration; @@ -16,15 +17,22 @@ pub struct EndpointStorage { pub data_dir: Utf8PathBuf, pub pemfile: Utf8PathBuf, pub addr: SocketAddr, + pub auth_type: AuthType, } impl EndpointStorage { pub fn from_env(env: &LocalEnv) -> EndpointStorage { + let auth_type = match env.token_auth_type { + AuthType::HadronJWT => AuthType::HadronJWT, + AuthType::NeonJWT | AuthType::Trust => AuthType::NeonJWT, + }; + EndpointStorage { bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(), data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(), pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(), addr: env.endpoint_storage.listen_addr, + auth_type, } } @@ -46,12 +54,14 @@ impl EndpointStorage { pemfile: Utf8PathBuf, local_path: Utf8PathBuf, r#type: String, + auth_type: AuthType, } let cfg = Cfg { listen: self.listen_addr(), pemfile: parent.join(self.pemfile.clone()), local_path: parent.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR), r#type: "LocalFs".to_string(), + auth_type: self.auth_type, }; std::fs::create_dir_all(self.config_path().parent().unwrap())?; std::fs::write(self.config_path(), serde_json::to_string(&cfg)?) diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index 29709f702a..d7e6e3f8f4 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -120,7 +120,7 @@ impl PageServerNode { // Note: In Hadron the "control plane" is HCC. HCC does not require a token on the trusted port PS connects // to, so we do not need to set any tokens when using HadronJWT. In the future we may consider using mTLS // instead of JWT for HTTP auth. - if matches!(conf.http_auth_type, AuthType::NeonJWT) { + if matches!(conf.http_auth_type, AuthType::NeonJWT | AuthType::HadronJWT) { let jwt_token = self .env .generate_auth_token(&Claims::new(None, Scope::GenerationsApi)) @@ -135,7 +135,8 @@ impl PageServerNode { } if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type] - .contains(&AuthType::NeonJWT) + .iter() + .any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT) { // Keys are generated in the toplevel repo dir, pageservers' workdirs // are one level below that, so refer to keys with ../ diff --git a/control_plane/src/storage_controller.rs b/control_plane/src/storage_controller.rs index e862f53c49..0026e5ec8a 100644 --- a/control_plane/src/storage_controller.rs +++ b/control_plane/src/storage_controller.rs @@ -618,24 +618,24 @@ impl StorageController { if let StorageControllerPrivateKey::HadronPrivateKey(key_path, _) = private_key { args.push(format!("--private-key-path={key_path}")); } - // We are setting --jwt-token for Hadron as well in this test to avoid bifurcation between Neon and + // We are setting all JWT tokens for Hadron as well in this test to avoid bifurcation between Neon and // Hadron test cases. In production we do not need to set this as HTTP auth is not enabled on the // pageserver. We use network segmentation to ensure that only trusted components can talk to // pageserver's http port let jwt_token = private_key.encode_token(&claims)?; args.push(format!("--jwt-token={jwt_token}")); - if let StorageControllerPrivateKey::EdPrivateKey(key) = private_key { - let peer_claims = Claims::new(None, Scope::Admin); - let peer_jwt_token = - encode_from_key_file(&peer_claims, key).expect("failed to generate jwt token"); - args.push(format!("--peer-jwt-token={peer_jwt_token}")); + let peer_claims = Claims::new(None, Scope::Admin); + let peer_jwt_token = private_key + .encode_token(&peer_claims) + .expect("failed to generate jwt token"); + args.push(format!("--peer-jwt-token={peer_jwt_token}")); - let claims = Claims::new(None, Scope::SafekeeperData); - let jwt_token = - encode_from_key_file(&claims, key).expect("failed to generate jwt token"); - args.push(format!("--safekeeper-jwt-token={jwt_token}")); - } + let claims = Claims::new(None, Scope::SafekeeperData); + let jwt_token = private_key + .encode_token(&claims) + .expect("failed to generate jwt token"); + args.push(format!("--safekeeper-jwt-token={jwt_token}")); } if let Some(public_key) = &self.public_key { @@ -903,7 +903,7 @@ impl StorageController { if let Some(private_key) = &self.private_key { println!("Getting claims for path {path}"); if let Some(required_claims) = Self::get_claims_for_path(&path)? { - println!("Got claims {:?} for path {}", required_claims, path); + println!("Got claims {required_claims:?} for path {path}"); let jwt_token = private_key.encode_token(&required_claims)?; builder = builder.header( reqwest::header::AUTHORIZATION, diff --git a/endpoint_storage/Cargo.toml b/endpoint_storage/Cargo.toml index c2e21d02e2..ecddefe1d9 100644 --- a/endpoint_storage/Cargo.toml +++ b/endpoint_storage/Cargo.toml @@ -20,6 +20,7 @@ tokio.workspace = true tracing.workspace = true utils = { path = "../libs/utils", default-features = false } workspace_hack.workspace = true +postgres_backend.workspace = true [dev-dependencies] camino-tempfile.workspace = true http-body-util.workspace = true diff --git a/endpoint_storage/src/lib.rs b/endpoint_storage/src/lib.rs index d1625dc843..f6e1128892 100644 --- a/endpoint_storage/src/lib.rs +++ b/endpoint_storage/src/lib.rs @@ -7,7 +7,6 @@ use axum::{RequestPartsExt, http::StatusCode, http::request::Parts}; use axum_extra::TypedHeader; use axum_extra::headers::{Authorization, authorization::Bearer}; use camino::Utf8PathBuf; -use jsonwebtoken::{DecodingKey, Validation}; use remote_storage::{GenericRemoteStorage, RemotePath}; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -15,28 +14,9 @@ use std::result::Result as StdResult; use std::sync::Arc; use tokio_util::sync::CancellationToken; use tracing::{debug, error}; +use utils::auth::JwtAuth; use utils::id::{EndpointId, TenantId, TimelineId}; -// simplified version of utils::auth::JwtAuth -pub struct JwtAuth { - decoding_key: DecodingKey, - validation: Validation, -} - -pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA; -impl JwtAuth { - pub fn new(key: &[u8]) -> Result { - Ok(Self { - decoding_key: DecodingKey::from_ed_pem(key)?, - validation: Validation::new(VALIDATION_ALGO), - }) - } - - pub fn decode(&self, token: &str) -> Result { - Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?) - } -} - fn normalize_key(key: &str) -> StdResult { let key = clean_utf8(&Utf8PathBuf::from(key)); if key.starts_with("..") || key == "." || key == "/" { @@ -157,7 +137,8 @@ impl FromRequestParts> for S3Path { let claims: EndpointStorageClaims = state .auth .decode(bearer.token()) - .map_err(|e| bad_request(e, "decoding token"))?; + .map_err(|e| bad_request(e, "decoding token"))? + .claims; // Read paths may have different endpoint ids. For readonly -> readwrite replica // prewarming, endpoint must read other endpoint's data. @@ -224,7 +205,8 @@ impl FromRequestParts> for PrefixS3Path { let claims: DeletePrefixClaims = state .auth .decode(bearer.token()) - .map_err(|e| bad_request(e, "invalid token"))?; + .map_err(|e| bad_request(e, "invalid token"))? + .claims; let route = DeletePrefixClaims { tenant_id: path.tenant_id, timeline_id: path.timeline_id, diff --git a/endpoint_storage/src/main.rs b/endpoint_storage/src/main.rs index c96cef2083..859c6390ef 100644 --- a/endpoint_storage/src/main.rs +++ b/endpoint_storage/src/main.rs @@ -5,8 +5,10 @@ mod app; use anyhow::Context; use clap::Parser; +use postgres_backend::AuthType; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use tracing::info; +use utils::auth::JwtAuth; use utils::logging; //see set() @@ -18,6 +20,10 @@ const fn listen() -> SocketAddr { SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243) } +const fn default_auth_type() -> AuthType { + AuthType::NeonJWT +} + #[derive(Parser)] struct Args { #[arg(exclusive = true)] @@ -39,6 +45,8 @@ struct Config { storage_kind: remote_storage::TypedRemoteStorageKind, #[serde(default = "max_upload_file_limit")] max_upload_file_limit: usize, + #[serde(default = "default_auth_type")] + auth_type: AuthType, } #[tokio::main] @@ -61,10 +69,15 @@ async fn main() -> anyhow::Result<()> { anyhow::bail!("Supply either config file path or --config=inline-config"); }; - info!("Reading pemfile from {}", config.pemfile.clone()); - let pemfile = std::fs::read(config.pemfile.clone())?; - info!("Loading public key from {}", config.pemfile.clone()); - let auth = endpoint_storage::JwtAuth::new(&pemfile)?; + if config.auth_type == AuthType::Trust { + anyhow::bail!("Trust based auth is not supported"); + } + + let auth = match config.auth_type { + AuthType::NeonJWT => JwtAuth::from_key_path(&config.pemfile)?, + AuthType::HadronJWT => JwtAuth::from_cert_path(&config.pemfile)?, + AuthType::Trust => unreachable!(), + }; let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap(); info!("listening on {}", listener.local_addr().unwrap()); diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 0efdaeaef0..9bf1400b26 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -467,8 +467,9 @@ fn start_pageserver( let key_path = conf.auth_validation_public_key_path.as_ref().unwrap(); info!("Loading public key(s) for verifying JWT tokens from {key_path:?}"); - let use_hadron_jwt = - conf.http_auth_type == AuthType::HadronJWT || conf.pg_auth_type == AuthType::HadronJWT; + let use_hadron_jwt = conf.http_auth_type == AuthType::HadronJWT + || conf.pg_auth_type == AuthType::HadronJWT + || conf.grpc_auth_type == AuthType::HadronJWT; let jwt_auth = if use_hadron_jwt { // To validate Hadron JWTs we need to extract decoding keys from X509 certificates. diff --git a/test_runner/regress/test_auth.py b/test_runner/regress/test_auth.py index eba8197116..60f8380bf4 100644 --- a/test_runner/regress/test_auth.py +++ b/test_runner/regress/test_auth.py @@ -32,8 +32,12 @@ def assert_client_not_authorized(env: NeonEnv, http_client: PageserverHttpClient assert_client_authorized(env, http_client) -def test_pageserver_auth(neon_env_builder: NeonEnvBuilder): + +@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False]) +def test_pageserver_auth(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool): neon_env_builder.auth_enabled = True + neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens + env = neon_env_builder.init_start() ps = env.pageserver @@ -71,9 +75,10 @@ def test_pageserver_auth(neon_env_builder: NeonEnvBuilder): ): env.pageserver.tenant_create(TenantId.generate(), auth_token=tenant_token) - -def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False]) +def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool): neon_env_builder.auth_enabled = True + neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens neon_env_builder.num_safekeepers = 3 env = neon_env_builder.init_start() @@ -90,9 +95,10 @@ def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder): cur.execute("SELECT sum(key) FROM t") assert cur.fetchone() == (5000050000,) - -def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False]) +def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool): neon_env_builder.auth_enabled = True + neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens env = neon_env_builder.init_start() env.pageserver.allowed_errors.extend( [".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"] @@ -144,9 +150,10 @@ def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder): assert_client_not_authorized(env, pageserver_http_client_old) assert_client_authorized(env, pageserver_http_client_new) - -def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False]) +def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool): neon_env_builder.auth_enabled = True + neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens env = neon_env_builder.init_start() env.pageserver.allowed_errors.extend( [".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"] @@ -183,7 +190,10 @@ def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder): @pytest.mark.parametrize("auth_enabled", [False, True]) -def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): +@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False]) +def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool, use_hadron_auth_tokens: bool): + neon_env_builder.auth_enabled = auth_enabled + neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens neon_env_builder.auth_enabled = auth_enabled env = neon_env_builder.init_start() diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 686f1c27ae..dbd0388034 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -1403,15 +1403,12 @@ def test_storage_controller_s3_time_travel_recovery( env.storage_controller.consistency_check() -@pytest.mark.skip( - reason=""" - [BRC-1269, BRC-1270] Hadron currently uses network segmentation to prevent all storage controller (non-HCC) HTTP APIs from being - accessed from untrusted networks, so auth is currently permenantly disabled for all of these APIs in storage controller code. - """ -) def test_storage_controller_auth(neon_env_builder: NeonEnvBuilder): neon_env_builder.auth_enabled = True env = neon_env_builder.init_start() + + assert env.auth_token_type == "NeonJWT" + svc = env.storage_controller api = env.storage_controller_api