fixup: bang it into shape

This commit is contained in:
Vlad Lazar
2025-07-23 15:58:43 +01:00
parent 3c5fad0184
commit bcecb03d2d
13 changed files with 159 additions and 74 deletions

3
Cargo.lock generated
View File

@@ -1570,6 +1570,7 @@ dependencies = [
"postgres_connection",
"regex",
"reqwest",
"rsa",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
@@ -1589,6 +1590,7 @@ dependencies = [
"utils",
"whoami",
"workspace_hack",
"x509-parser",
]
[[package]]
@@ -2207,6 +2209,7 @@ dependencies = [
"http-body-util",
"itertools 0.10.5",
"jsonwebtoken",
"postgres_backend",
"prometheus",
"rand 0.8.5",
"remote_storage",

View File

@@ -23,6 +23,18 @@ pub(in crate::http) struct Authorize {
impl Authorize {
pub fn new(compute_id: String, jwks: JwkSet) -> Self {
let mut validation = Validation::new(Algorithm::EdDSA);
// BEGIN HADRON
let use_rsa = jwks.keys.iter().any(|jwk| {
jwk.common
.key_algorithm
.is_some_and(|alg| alg == jsonwebtoken::jwk::KeyAlgorithm::RS256)
});
if use_rsa {
validation = Validation::new(Algorithm::RS256);
}
// END HADRON
validation.validate_exp = true;
// Unused by the control plane
validation.validate_nbf = false;

View File

@@ -46,3 +46,5 @@ endpoint_storage.workspace = true
compute_api.workspace = true
workspace_hack.workspace = true
tracing.workspace = true
x509-parser.workspace = true
rsa = "0.9"

View File

@@ -37,18 +37,8 @@
//! <other PostgreSQL files>
//! ```
//!
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
@@ -62,21 +52,31 @@ use compute_api::spec::{
};
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, RSAKeyParameters, RSAKeyType,
};
use nix::sys::signal::{Signal, kill};
use pem::Pem;
use reqwest::header::CONTENT_TYPE;
use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey, traits::PublicKeyParts};
use safekeeper_api::PgMajorVersion;
use safekeeper_api::membership::SafekeeperGeneration;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use spki::der::Decode;
use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::debug;
use url::Host;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::shard::ShardStripeSize;
use x509_parser::parse_x509_certificate;
use crate::local_env::LocalEnv;
use crate::postgresql_conf::PostgresConf;
@@ -155,23 +155,76 @@ impl ComputeControlPlane {
.unwrap_or(self.base_port)
}
// BEGIN HADRON
/// Extract SubjectPublicKeyInfo from a PEM that can be either a X509 certificate or a public key
fn extract_spki_from_pem(pem: &Pem) -> Result<Vec<u8>> {
if pem.tag() == "CERTIFICATE" {
// Handle X509 certificate
let (_, cert) = parse_x509_certificate(pem.contents())?;
let public_key = cert.public_key();
Ok(public_key.subject_public_key.data.to_vec())
} else {
// Handle public key directly
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
Ok(spki.subject_public_key.raw_bytes().to_vec())
}
}
/// Create RSA JWK from certificate PEM
fn create_rsa_jwk_from_cert(pem: &Pem, key_hash: &[u8]) -> Result<Jwk> {
let public_key = Self::extract_spki_from_pem(pem)?;
// Extract RSA parameters (n, e) from RSA public key DER data
let rsa_key = RsaPublicKey::from_pkcs1_der(&public_key)?;
let n = rsa_key.n().to_bytes_be();
let e = rsa_key.e().to_bytes_be();
Ok(Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::RS256),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
x509_sha256_fingerprint: None::<String>,
},
algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
key_type: RSAKeyType::RSA,
n: URL_SAFE_NO_PAD.encode(n),
e: URL_SAFE_NO_PAD.encode(e),
}),
})
}
// END HADRON
/// Create a JSON Web Key Set. This ideally matches the way we create a JWKS
/// from the production control plane.
fn create_jwks_from_pem(pem: &Pem) -> Result<JwkSet> {
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
let public_key = spki.subject_public_key.raw_bytes();
let public_key = Self::extract_spki_from_pem(pem)?;
let mut hasher = Sha256::new();
hasher.update(public_key);
hasher.update(&public_key);
let key_hash = hasher.finalize();
// BEGIN HADRON
if pem.tag() == "CERTIFICATE" {
// Assume RSA if we are parsing keys from a certificate.
let jwk = Self::create_rsa_jwk_from_cert(pem, &key_hash)?;
return Ok(JwkSet { keys: vec![jwk] });
}
// END HADRON
Ok(JwkSet {
keys: vec![Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::EdDSA),
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
@@ -180,7 +233,7 @@ impl ComputeControlPlane {
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
key_type: OctetKeyPairType::OctetKeyPair,
curve: EllipticCurve::Ed25519,
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
x: URL_SAFE_NO_PAD.encode(public_key),
}),
}],
})

View File

@@ -2,6 +2,7 @@ use crate::background_process::{self, start_process, stop_process};
use crate::local_env::LocalEnv;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
@@ -16,15 +17,22 @@ pub struct EndpointStorage {
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub addr: SocketAddr,
pub auth_type: AuthType,
}
impl EndpointStorage {
pub fn from_env(env: &LocalEnv) -> EndpointStorage {
let auth_type = match env.token_auth_type {
AuthType::HadronJWT => AuthType::HadronJWT,
AuthType::NeonJWT | AuthType::Trust => AuthType::NeonJWT,
};
EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
addr: env.endpoint_storage.listen_addr,
auth_type,
}
}
@@ -46,12 +54,14 @@ impl EndpointStorage {
pemfile: Utf8PathBuf,
local_path: Utf8PathBuf,
r#type: String,
auth_type: AuthType,
}
let cfg = Cfg {
listen: self.listen_addr(),
pemfile: parent.join(self.pemfile.clone()),
local_path: parent.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR),
r#type: "LocalFs".to_string(),
auth_type: self.auth_type,
};
std::fs::create_dir_all(self.config_path().parent().unwrap())?;
std::fs::write(self.config_path(), serde_json::to_string(&cfg)?)

View File

@@ -120,7 +120,7 @@ impl PageServerNode {
// Note: In Hadron the "control plane" is HCC. HCC does not require a token on the trusted port PS connects
// to, so we do not need to set any tokens when using HadronJWT. In the future we may consider using mTLS
// instead of JWT for HTTP auth.
if matches!(conf.http_auth_type, AuthType::NeonJWT) {
if matches!(conf.http_auth_type, AuthType::NeonJWT | AuthType::HadronJWT) {
let jwt_token = self
.env
.generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
@@ -135,7 +135,8 @@ impl PageServerNode {
}
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.contains(&AuthType::NeonJWT)
.iter()
.any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT)
{
// Keys are generated in the toplevel repo dir, pageservers' workdirs
// are one level below that, so refer to keys with ../

View File

@@ -618,24 +618,24 @@ impl StorageController {
if let StorageControllerPrivateKey::HadronPrivateKey(key_path, _) = private_key {
args.push(format!("--private-key-path={key_path}"));
}
// We are setting --jwt-token for Hadron as well in this test to avoid bifurcation between Neon and
// We are setting all JWT tokens for Hadron as well in this test to avoid bifurcation between Neon and
// Hadron test cases. In production we do not need to set this as HTTP auth is not enabled on the
// pageserver. We use network segmentation to ensure that only trusted components can talk to
// pageserver's http port
let jwt_token = private_key.encode_token(&claims)?;
args.push(format!("--jwt-token={jwt_token}"));
if let StorageControllerPrivateKey::EdPrivateKey(key) = private_key {
let peer_claims = Claims::new(None, Scope::Admin);
let peer_jwt_token =
encode_from_key_file(&peer_claims, key).expect("failed to generate jwt token");
args.push(format!("--peer-jwt-token={peer_jwt_token}"));
let peer_claims = Claims::new(None, Scope::Admin);
let peer_jwt_token = private_key
.encode_token(&peer_claims)
.expect("failed to generate jwt token");
args.push(format!("--peer-jwt-token={peer_jwt_token}"));
let claims = Claims::new(None, Scope::SafekeeperData);
let jwt_token =
encode_from_key_file(&claims, key).expect("failed to generate jwt token");
args.push(format!("--safekeeper-jwt-token={jwt_token}"));
}
let claims = Claims::new(None, Scope::SafekeeperData);
let jwt_token = private_key
.encode_token(&claims)
.expect("failed to generate jwt token");
args.push(format!("--safekeeper-jwt-token={jwt_token}"));
}
if let Some(public_key) = &self.public_key {
@@ -903,7 +903,7 @@ impl StorageController {
if let Some(private_key) = &self.private_key {
println!("Getting claims for path {path}");
if let Some(required_claims) = Self::get_claims_for_path(&path)? {
println!("Got claims {:?} for path {}", required_claims, path);
println!("Got claims {required_claims:?} for path {path}");
let jwt_token = private_key.encode_token(&required_claims)?;
builder = builder.header(
reqwest::header::AUTHORIZATION,

View File

@@ -20,6 +20,7 @@ tokio.workspace = true
tracing.workspace = true
utils = { path = "../libs/utils", default-features = false }
workspace_hack.workspace = true
postgres_backend.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true
http-body-util.workspace = true

View File

@@ -7,7 +7,6 @@ use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
use axum_extra::TypedHeader;
use axum_extra::headers::{Authorization, authorization::Bearer};
use camino::Utf8PathBuf;
use jsonwebtoken::{DecodingKey, Validation};
use remote_storage::{GenericRemoteStorage, RemotePath};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
@@ -15,28 +14,9 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use utils::auth::JwtAuth;
use utils::id::{EndpointId, TenantId, TimelineId};
// simplified version of utils::auth::JwtAuth
pub struct JwtAuth {
decoding_key: DecodingKey,
validation: Validation,
}
pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
impl JwtAuth {
pub fn new(key: &[u8]) -> Result<Self> {
Ok(Self {
decoding_key: DecodingKey::from_ed_pem(key)?,
validation: Validation::new(VALIDATION_ALGO),
})
}
pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
}
}
fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
let key = clean_utf8(&Utf8PathBuf::from(key));
if key.starts_with("..") || key == "." || key == "/" {
@@ -157,7 +137,8 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
let claims: EndpointStorageClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?;
.map_err(|e| bad_request(e, "decoding token"))?
.claims;
// Read paths may have different endpoint ids. For readonly -> readwrite replica
// prewarming, endpoint must read other endpoint's data.
@@ -224,7 +205,8 @@ impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
let claims: DeletePrefixClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "invalid token"))?;
.map_err(|e| bad_request(e, "invalid token"))?
.claims;
let route = DeletePrefixClaims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,

View File

@@ -5,8 +5,10 @@
mod app;
use anyhow::Context;
use clap::Parser;
use postgres_backend::AuthType;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::auth::JwtAuth;
use utils::logging;
//see set()
@@ -18,6 +20,10 @@ const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
const fn default_auth_type() -> AuthType {
AuthType::NeonJWT
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
@@ -39,6 +45,8 @@ struct Config {
storage_kind: remote_storage::TypedRemoteStorageKind,
#[serde(default = "max_upload_file_limit")]
max_upload_file_limit: usize,
#[serde(default = "default_auth_type")]
auth_type: AuthType,
}
#[tokio::main]
@@ -61,10 +69,15 @@ async fn main() -> anyhow::Result<()> {
anyhow::bail!("Supply either config file path or --config=inline-config");
};
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
let auth = endpoint_storage::JwtAuth::new(&pemfile)?;
if config.auth_type == AuthType::Trust {
anyhow::bail!("Trust based auth is not supported");
}
let auth = match config.auth_type {
AuthType::NeonJWT => JwtAuth::from_key_path(&config.pemfile)?,
AuthType::HadronJWT => JwtAuth::from_cert_path(&config.pemfile)?,
AuthType::Trust => unreachable!(),
};
let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap();
info!("listening on {}", listener.local_addr().unwrap());

View File

@@ -467,8 +467,9 @@ fn start_pageserver(
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!("Loading public key(s) for verifying JWT tokens from {key_path:?}");
let use_hadron_jwt =
conf.http_auth_type == AuthType::HadronJWT || conf.pg_auth_type == AuthType::HadronJWT;
let use_hadron_jwt = conf.http_auth_type == AuthType::HadronJWT
|| conf.pg_auth_type == AuthType::HadronJWT
|| conf.grpc_auth_type == AuthType::HadronJWT;
let jwt_auth = if use_hadron_jwt {
// To validate Hadron JWTs we need to extract decoding keys from X509 certificates.

View File

@@ -32,8 +32,12 @@ def assert_client_not_authorized(env: NeonEnv, http_client: PageserverHttpClient
assert_client_authorized(env, http_client)
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
ps = env.pageserver
@@ -71,9 +75,10 @@ def test_pageserver_auth(neon_env_builder: NeonEnvBuilder):
):
env.pageserver.tenant_create(TenantId.generate(), auth_token=tenant_token)
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
@@ -90,9 +95,10 @@ def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder):
cur.execute("SELECT sum(key) FROM t")
assert cur.fetchone() == (5000050000,)
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -144,9 +150,10 @@ def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder):
assert_client_not_authorized(env, pageserver_http_client_old)
assert_client_authorized(env, pageserver_http_client_new)
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -183,7 +190,10 @@ def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("auth_enabled", [False, True])
def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = auth_enabled
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
neon_env_builder.auth_enabled = auth_enabled
env = neon_env_builder.init_start()

View File

@@ -1403,15 +1403,12 @@ def test_storage_controller_s3_time_travel_recovery(
env.storage_controller.consistency_check()
@pytest.mark.skip(
reason="""
[BRC-1269, BRC-1270] Hadron currently uses network segmentation to prevent all storage controller (non-HCC) HTTP APIs from being
accessed from untrusted networks, so auth is currently permenantly disabled for all of these APIs in storage controller code.
"""
)
def test_storage_controller_auth(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
env = neon_env_builder.init_start()
assert env.auth_token_type == "NeonJWT"
svc = env.storage_controller
api = env.storage_controller_api