Compare commits

...

22 Commits

Author SHA1 Message Date
Vlad Lazar
a1cc1f33dc Merge remote-tracking branch 'origin/main' into vlad/hadron-jwt 2025-07-31 11:29:07 +01:00
Vlad Lazar
94dc55f405 chore: hakari 2025-07-29 10:07:09 +01:00
Vlad Lazar
50ed144689 fixup: don't create unused token for safeekeepers 2025-07-29 09:45:47 +01:00
Vlad Lazar
7de0e326a3 sq 2025-07-29 09:45:40 +01:00
Vlad Lazar
88b260bfc7 Merge remote-tracking branch 'origin' into vlad/hadron-jwt 2025-07-29 09:43:09 +01:00
Vlad Lazar
3bf55c8e93 review: bail out instead of panicking 2025-07-25 12:22:25 +01:00
Vlad Lazar
688d0771d3 review: validate that neon and hadron tokens aren't mixed 2025-07-25 12:15:08 +01:00
Vlad Lazar
8f7314c429 fixup: add OpenSSL license back to the allow list 2025-07-24 15:34:03 +01:00
Vlad Lazar
9d8a3c518b fixup: format doc comment 2025-07-24 15:33:18 +01:00
Vlad Lazar
c63b6c5bd3 chore: cargo hakari 2025-07-24 13:38:37 +01:00
Vlad Lazar
00699d86a2 fixup: bring back the SK peer jwt token 2025-07-24 13:31:07 +01:00
Vlad Lazar
10da740e65 fixup: pylints 2025-07-23 19:41:30 +01:00
Vlad Lazar
84dcfa26bb fixup: endpoint storage tests 2025-07-23 19:39:30 +01:00
Vlad Lazar
382ab511a6 Merge remote-tracking branch 'origin' into vlad/hadron-jwt 2025-07-23 19:15:31 +01:00
Vlad Lazar
2e8eeb3b50 fixup: put pg versions back 2025-07-23 19:14:04 +01:00
Vlad Lazar
bcecb03d2d fixup: bang it into shape 2025-07-23 15:58:43 +01:00
Vlad Lazar
3c5fad0184 sq 2025-07-22 18:00:14 +01:00
Vlad Lazar
9ab3203776 sq 2025-07-22 18:00:03 +01:00
Vlad Lazar
b762de56ff fixup: make it build 2025-07-22 15:49:55 +01:00
William Huang
2ddf8f64ce Augment the JwtAuth utility to support RS256 signatures and extracting decoding keys from X509 certificates (#165) 2025-07-22 12:29:45 +01:00
Vlad Lazar
f0ac89ff6f sq 2025-07-22 12:24:59 +01:00
William Huang
9661022e34 Enable JWT auth in Hadron API endpoints accepting untrusted connections (#179) 2025-07-22 12:23:57 +01:00
38 changed files with 1060 additions and 190 deletions

215
Cargo.lock generated
View File

@@ -173,6 +173,45 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048"
dependencies = [
"asn1-rs-derive",
"asn1-rs-impl",
"displaydoc",
"nom",
"num-traits",
"rusticata-macros",
"thiserror 1.0.69",
"time",
]
[[package]]
name = "asn1-rs-derive"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
"synstructure",
]
[[package]]
name = "asn1-rs-impl"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "assert-json-diff"
version = "2.0.2"
@@ -307,6 +346,30 @@ dependencies = [
"zeroize",
]
[[package]]
name = "aws-lc-rs"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08b5d4e069cbc868041a64bd68dc8cb39a0d79585cd6c5a24caa8c2d622121be"
dependencies = [
"aws-lc-sys",
"untrusted 0.7.1",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]]
name = "aws-runtime"
version = "1.4.4"
@@ -968,6 +1031,29 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.100",
"which",
]
[[package]]
name = "bindgen"
version = "0.71.1"
@@ -1260,6 +1346,15 @@ dependencies = [
"replace_with",
]
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.0"
@@ -1492,6 +1587,7 @@ dependencies = [
"postgres_connection",
"regex",
"reqwest",
"rsa",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
@@ -1511,6 +1607,7 @@ dependencies = [
"utils",
"whoami",
"workspace_hack",
"x509-parser",
]
[[package]]
@@ -1836,6 +1933,20 @@ dependencies = [
"zeroize",
]
[[package]]
name = "der-parser"
version = "9.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553"
dependencies = [
"asn1-rs",
"displaydoc",
"nom",
"num-bigint",
"num-traits",
"rusticata-macros",
]
[[package]]
name = "der_derive"
version = "0.7.3"
@@ -1992,6 +2103,12 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "dyn-clone"
version = "1.0.14"
@@ -2109,6 +2226,7 @@ dependencies = [
"http-body-util",
"itertools 0.10.5",
"jsonwebtoken",
"postgres_backend",
"prometheus",
"rand 0.9.1",
"remote_storage",
@@ -2391,6 +2509,12 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@@ -2840,6 +2964,15 @@ dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "hostname"
version = "0.4.0"
@@ -3614,6 +3747,12 @@ dependencies = [
"spin",
]
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.172"
@@ -4189,6 +4328,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "oid-registry"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9"
dependencies = [
"asn1-rs",
]
[[package]]
name = "once_cell"
version = "1.20.2"
@@ -5072,7 +5220,7 @@ name = "postgres_ffi"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.71.1",
"bytes",
"crc32c",
"criterion",
@@ -5737,6 +5885,7 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"aws-lc-rs",
"pem",
"ring",
"rustls-pki-types",
@@ -6052,7 +6201,7 @@ dependencies = [
"cfg-if",
"getrandom 0.2.11",
"libc",
"untrusted",
"untrusted 0.9.0",
"windows-sys 0.52.0",
]
@@ -6173,6 +6322,15 @@ dependencies = [
"semver",
]
[[package]]
name = "rusticata-macros"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
dependencies = [
"nom",
]
[[package]]
name = "rustix"
version = "0.38.41"
@@ -6300,7 +6458,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -6311,7 +6469,7 @@ checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -6322,7 +6480,7 @@ checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -6484,7 +6642,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
"ring",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -7067,6 +7225,7 @@ dependencies = [
"hyper 0.14.30",
"itertools 0.10.5",
"json-structural-diff",
"jsonwebtoken",
"lasso",
"measured",
"metrics",
@@ -8241,6 +8400,12 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "untrusted"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -8340,6 +8505,7 @@ dependencies = [
"jsonwebtoken",
"metrics",
"nix 0.30.1",
"oid-registry",
"once_cell",
"pem",
"pin-project-lite",
@@ -8347,7 +8513,10 @@ dependencies = [
"pprof",
"pq_proto",
"rand 0.9.1",
"rcgen",
"regex",
"rustls-pemfile 2.1.1",
"rustls-pki-types",
"scopeguard",
"sentry",
"serde",
@@ -8368,6 +8537,7 @@ dependencies = [
"tracing-utils",
"uuid",
"walkdir",
"x509-parser",
]
[[package]]
@@ -8482,7 +8652,7 @@ name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.71.1",
"postgres_ffi",
"utils",
]
@@ -8647,6 +8817,18 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "whoami"
version = "1.5.1"
@@ -9019,6 +9201,7 @@ dependencies = [
"der 0.7.8",
"deranged",
"digest",
"displaydoc",
"ecdsa 0.16.9",
"either",
"elliptic-curve 0.13.8",
@@ -9066,6 +9249,7 @@ dependencies = [
"prost 0.13.5",
"quote",
"rand 0.9.1",
"rcgen",
"regex",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
@@ -9151,6 +9335,23 @@ dependencies = [
"zeroize",
]
[[package]]
name = "x509-parser"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69"
dependencies = [
"asn1-rs",
"data-encoding",
"der-parser",
"lazy_static",
"nom",
"oid-registry",
"rusticata-macros",
"thiserror 1.0.69",
"time",
]
[[package]]
name = "xattr"
version = "1.0.0"

View File

@@ -142,6 +142,7 @@ nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket"
notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.19"
oid-registry = "0.7.1"
once_cell = "1.13"
opentelemetry = "0.30"
opentelemetry_sdk = "0.30"
@@ -173,6 +174,7 @@ rustc-hash = "2.1.1"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"
rustls-split = "0.3"
scopeguard = "1.1"
sysinfo = "0.29.2"
sd-notify = "0.4.1"
@@ -235,6 +237,7 @@ rustls-native-certs = "0.8"
whoami = "1.5.1"
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
x509-parser = "0.16"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"

View File

@@ -46,3 +46,5 @@ endpoint_storage.workspace = true
compute_api.workspace = true
workspace_hack.workspace = true
tracing.workspace = true
x509-parser.workspace = true
rsa = "0.9"

View File

@@ -1049,6 +1049,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
// User (likely interactive) did not provide a description of the environment, give them the default
NeonLocalInitConf {
control_plane_api: Some(DEFAULT_PAGESERVER_CONTROL_PLANE_API.parse().unwrap()),
auth_token_type: AuthType::NeonJWT,
broker: NeonBroker {
listen_addr: Some(DEFAULT_BROKER_ADDR.parse().unwrap()),
listen_https_addr: None,
@@ -1584,7 +1585,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
pageserver_conninfo.prefer_protocol = prefer_protocol;
let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
let auth_token = if matches!(
ps_conf.pg_auth_type,
AuthType::NeonJWT | AuthType::HadronJWT
) {
let claims = Claims::new(Some(endpoint.tenant_id), Scope::Tenant);
Some(env.generate_auth_token(&claims)?)

View File

@@ -37,18 +37,8 @@
//! <other PostgreSQL files>
//! ```
//!
use std::collections::{BTreeMap, HashMap};
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
@@ -66,20 +56,30 @@ pub use compute_api::spec::{PageserverConnectionInfo, PageserverShardConnectionI
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, RSAKeyParameters, RSAKeyType,
};
use nix::sys::signal::{Signal, kill};
use pem::Pem;
use reqwest::header::CONTENT_TYPE;
use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey, traits::PublicKeyParts};
use safekeeper_api::PgMajorVersion;
use safekeeper_api::membership::SafekeeperGeneration;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use spki::der::Decode;
use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
use std::collections::{BTreeMap, HashMap};
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::debug;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use x509_parser::parse_x509_certificate;
use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT;
use postgres_connection::parse_host_port;
@@ -161,23 +161,76 @@ impl ComputeControlPlane {
.unwrap_or(self.base_port)
}
// BEGIN HADRON
/// Extract SubjectPublicKeyInfo from a PEM that can be either a X509 certificate or a public key
fn extract_spki_from_pem(pem: &Pem) -> Result<Vec<u8>> {
if pem.tag() == "CERTIFICATE" {
// Handle X509 certificate
let (_, cert) = parse_x509_certificate(pem.contents())?;
let public_key = cert.public_key();
Ok(public_key.subject_public_key.data.to_vec())
} else {
// Handle public key directly
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
Ok(spki.subject_public_key.raw_bytes().to_vec())
}
}
/// Create RSA JWK from certificate PEM
fn create_rsa_jwk_from_cert(pem: &Pem, key_hash: &[u8]) -> Result<Jwk> {
let public_key = Self::extract_spki_from_pem(pem)?;
// Extract RSA parameters (n, e) from RSA public key DER data
let rsa_key = RsaPublicKey::from_pkcs1_der(&public_key)?;
let n = rsa_key.n().to_bytes_be();
let e = rsa_key.e().to_bytes_be();
Ok(Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::RS256),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
x509_sha256_fingerprint: None::<String>,
},
algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
key_type: RSAKeyType::RSA,
n: URL_SAFE_NO_PAD.encode(n),
e: URL_SAFE_NO_PAD.encode(e),
}),
})
}
// END HADRON
/// Create a JSON Web Key Set. This ideally matches the way we create a JWKS
/// from the production control plane.
fn create_jwks_from_pem(pem: &Pem) -> Result<JwkSet> {
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
let public_key = spki.subject_public_key.raw_bytes();
let public_key = Self::extract_spki_from_pem(pem)?;
let mut hasher = Sha256::new();
hasher.update(public_key);
hasher.update(&public_key);
let key_hash = hasher.finalize();
// BEGIN HADRON
if pem.tag() == "CERTIFICATE" {
// Assume RSA if we are parsing keys from a certificate.
let jwk = Self::create_rsa_jwk_from_cert(pem, &key_hash)?;
return Ok(JwkSet { keys: vec![jwk] });
}
// END HADRON
Ok(JwkSet {
keys: vec![Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::EdDSA),
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
@@ -186,7 +239,7 @@ impl ComputeControlPlane {
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
key_type: OctetKeyPairType::OctetKeyPair,
curve: EllipticCurve::Ed25519,
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
x: URL_SAFE_NO_PAD.encode(public_key),
}),
}],
})

View File

@@ -2,6 +2,7 @@ use crate::background_process::{self, start_process, stop_process};
use crate::local_env::LocalEnv;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
@@ -16,15 +17,22 @@ pub struct EndpointStorage {
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub addr: SocketAddr,
pub auth_type: AuthType,
}
impl EndpointStorage {
pub fn from_env(env: &LocalEnv) -> EndpointStorage {
let auth_type = match env.token_auth_type {
AuthType::HadronJWT => AuthType::HadronJWT,
AuthType::NeonJWT | AuthType::Trust => AuthType::NeonJWT,
};
EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
addr: env.endpoint_storage.listen_addr,
auth_type,
}
}
@@ -46,12 +54,14 @@ impl EndpointStorage {
pemfile: Utf8PathBuf,
local_path: Utf8PathBuf,
r#type: String,
auth_type: AuthType,
}
let cfg = Cfg {
listen: self.listen_addr(),
pemfile: parent.join(self.pemfile.clone()),
local_path: parent.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR),
r#type: "LocalFs".to_string(),
auth_type: self.auth_type,
};
std::fs::create_dir_all(self.config_path().parent().unwrap())?;
std::fs::write(self.config_path(), serde_json::to_string(&cfg)?)

View File

@@ -18,7 +18,7 @@ use postgres_backend::AuthType;
use reqwest::{Certificate, Url};
use safekeeper_api::PgMajorVersion;
use serde::{Deserialize, Serialize};
use utils::auth::encode_from_key_file;
use utils::auth::{encode_from_key_file, encode_hadron_token};
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use crate::broker::StorageBroker;
@@ -60,6 +60,9 @@ pub struct LocalEnv {
// --tenant_id is not explicitly specified.
pub default_tenant_id: Option<TenantId>,
// The type of tokens to use for authentication in the test environment. Determines
// the type of key pairs and tokens generated in the test.
pub token_auth_type: AuthType,
// used to issue tokens during e.g pg start
pub private_key_path: PathBuf,
/// Path to environment's public key
@@ -105,6 +108,7 @@ pub struct OnDiskConfig {
pub pg_distrib_dir: PathBuf,
pub neon_distrib_dir: PathBuf,
pub default_tenant_id: Option<TenantId>,
pub token_auth_type: Option<AuthType>,
pub private_key_path: PathBuf,
pub public_key_path: PathBuf,
pub broker: NeonBroker,
@@ -153,6 +157,7 @@ pub struct NeonLocalInitConf {
pub control_plane_api: Option<Url>,
pub control_plane_hooks_api: Option<Url>,
pub generate_local_ssl_certs: bool,
pub auth_token_type: AuthType,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -374,7 +379,7 @@ pub struct SafekeeperConf {
pub sync: bool,
pub remote_storage: Option<String>,
pub backup_threads: Option<u32>,
pub auth_enabled: bool,
pub auth_type: AuthType,
pub listen_addr: Option<String>,
}
@@ -389,7 +394,7 @@ impl Default for SafekeeperConf {
sync: true,
remote_storage: None,
backup_threads: None,
auth_enabled: false,
auth_type: AuthType::Trust,
listen_addr: None,
}
}
@@ -663,6 +668,7 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id,
token_auth_type,
private_key_path,
public_key_path,
broker,
@@ -681,6 +687,7 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id,
token_auth_type: token_auth_type.unwrap_or(AuthType::NeonJWT),
private_key_path,
public_key_path,
broker,
@@ -796,6 +803,7 @@ impl LocalEnv {
pg_distrib_dir: self.pg_distrib_dir.clone(),
neon_distrib_dir: self.neon_distrib_dir.clone(),
default_tenant_id: self.default_tenant_id,
token_auth_type: Some(self.token_auth_type),
private_key_path: self.private_key_path.clone(),
public_key_path: self.public_key_path.clone(),
broker: self.broker.clone(),
@@ -825,8 +833,18 @@ impl LocalEnv {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn generate_auth_token<S: Serialize>(&self, claims: &S) -> anyhow::Result<String> {
let key = self.read_private_key()?;
encode_from_key_file(claims, &key)
match self.token_auth_type {
AuthType::NeonJWT => {
let key_data = self.read_private_key()?;
encode_from_key_file(claims, &key_data)
}
AuthType::HadronJWT => {
let private_key_path = self.get_private_key_path();
let key_data = fs::read(private_key_path)?;
encode_hadron_token(claims, &key_data)
}
_ => panic!("unsupported token auth type {:?}", self.token_auth_type),
}
}
/// Get the path to the private key.
@@ -915,6 +933,7 @@ impl LocalEnv {
generate_local_ssl_certs,
control_plane_hooks_api,
endpoint_storage,
auth_token_type,
} = conf;
// Find postgres binaries.
@@ -943,6 +962,7 @@ impl LocalEnv {
generate_auth_keys(
base_path.join("auth_private_key.pem").as_path(),
base_path.join("auth_public_key.pem").as_path(),
auth_token_type,
)
.context("generate auth keys")?;
let private_key_path = PathBuf::from("auth_private_key.pem");
@@ -956,6 +976,7 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id: Some(default_tenant_id),
token_auth_type: auth_token_type,
private_key_path,
public_key_path,
broker,
@@ -1035,39 +1056,63 @@ pub fn base_path() -> PathBuf {
}
/// Generate a public/private key pair for JWT authentication
fn generate_auth_keys(private_key_path: &Path, public_key_path: &Path) -> anyhow::Result<()> {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
fn generate_auth_keys(
private_key_path: &Path,
public_key_path: &Path,
auth_type: AuthType,
) -> anyhow::Result<()> {
if auth_type == AuthType::NeonJWT {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
} else if auth_type == AuthType::HadronJWT {
// Generate the RSA key pair. Note that the public key is embedded in an X509 certificate.
//
// openssl req -x509 -newkey rsa:4096 -keyout auth_private_key.pem -out auth_public_key.pem -nodes -subj "/CN=eng-brickstore@databricks.com"
let keygen_output = Command::new("openssl")
.arg("req")
.args(["-x509", "-newkey", "rsa:4096", "-sha256"])
.args(["-keyout", private_key_path.to_str().unwrap()])
.args(["-out", public_key_path.to_str().unwrap()])
.args(["-nodes"])
.args(["-subj", "/CN=eng-brickstore@databricks.com"])
.output()
.context("Failed to generate RSA key pair for Hadron token auth")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
}
Ok(())

View File

@@ -73,7 +73,7 @@ impl PageServerNode {
{
match conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(
AuthType::NeonJWT | AuthType::HadronJWT => Some(
env.generate_auth_token(&Claims::new(None, Scope::PageServerApi))
.unwrap(),
),
@@ -117,7 +117,10 @@ impl PageServerNode {
// Storage controller uses the same auth as pageserver: if JWT is enabled
// for us, we will also need it to talk to them.
if matches!(conf.http_auth_type, AuthType::NeonJWT) {
// Note: In Hadron the "control plane" is HCC. HCC does not require a token on the trusted port PS connects
// to, so we do not need to set any tokens when using HadronJWT. In the future we may consider using mTLS
// instead of JWT for HTTP auth.
if matches!(conf.http_auth_type, AuthType::NeonJWT | AuthType::HadronJWT) {
let jwt_token = self
.env
.generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
@@ -132,7 +135,8 @@ impl PageServerNode {
}
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.contains(&AuthType::NeonJWT)
.iter()
.any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT)
{
// Keys are generated in the toplevel repo dir, pageservers' workdirs
// are one level below that, so refer to keys with ../

View File

@@ -13,6 +13,7 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use postgres_connection::PgConnectionConfig;
use safekeeper_api::models::TimelineCreateRequest;
use safekeeper_client::mgmt_api;
@@ -110,7 +111,7 @@ impl SafekeeperNode {
}
// Generate a token file for authentication with other safekeepers
if self.conf.auth_enabled {
if self.conf.auth_type != AuthType::Trust {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::SafekeeperData))?;
@@ -156,7 +157,7 @@ impl SafekeeperNode {
"--id".to_owned(),
id_string,
"--listen-pg".to_owned(),
listen_pg,
listen_pg.clone(),
"--listen-http".to_owned(),
listen_http,
"--availability-zone".to_owned(),
@@ -186,7 +187,11 @@ impl SafekeeperNode {
}
let key_path = self.env.base_data_dir.join("auth_public_key.pem");
if self.conf.auth_enabled {
if self.conf.auth_type != AuthType::Trust {
args.extend([
"--token-auth-type".to_owned(),
self.conf.auth_type.to_string(),
]);
let key_path_string = key_path
.to_str()
.with_context(|| {
@@ -205,6 +210,15 @@ impl SafekeeperNode {
"--http-auth-public-key-path".to_owned(),
key_path_string.clone(),
]);
let token_path = self.datadir_path().join("peer_jwt_token");
let token_path_str = token_path
.to_str()
.with_context(|| {
format!("Token path {token_path:?} cannot be represented as a unicode string")
})?
.to_owned();
args.extend(["--auth-token-path".to_owned(), token_path_str]);
}
if let Some(https_port) = self.conf.https_port {
@@ -217,26 +231,14 @@ impl SafekeeperNode {
args.push(format!("--ssl-ca-file={}", ssl_ca_file.to_str().unwrap()));
}
if self.conf.auth_enabled {
let token_path = self.datadir_path().join("peer_jwt_token");
let token_path_str = token_path
.to_str()
.with_context(|| {
format!("Token path {token_path:?} cannot be represented as a unicode string")
})?
.to_owned();
args.extend(["--auth-token-path".to_owned(), token_path_str]);
}
args.extend_from_slice(extra_opts);
let env_variables = Vec::new();
background_process::start_process(
&format!("safekeeper-{id}"),
&datadir,
&self.env.safekeeper_bin(),
&args,
env_variables,
self.safekeeper_env_variables()?,
background_process::InitialPidFile::Expect(self.pid_file()),
retry_timeout,
|| async {
@@ -250,6 +252,11 @@ impl SafekeeperNode {
.await
}
fn safekeeper_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// TODO: remove me
Ok(vec![])
}
///
/// Stop the server.
///

View File

@@ -30,14 +30,14 @@ use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tracing::instrument;
use url::Url;
use utils::auth::{Claims, Scope, encode_from_key_file};
use utils::auth::{Claims, Scope, encode_from_key_file, encode_hadron_token};
use utils::id::{NodeId, TenantId};
use whoami::username;
pub struct StorageController {
env: LocalEnv,
private_key: Option<Pem>,
public_key: Option<Pem>,
private_key: Option<StorageControllerPrivateKey>,
public_key: Option<StorageControllerPublicKey>,
client: reqwest::Client,
config: NeonStorageControllerConf,
@@ -108,6 +108,25 @@ pub struct InspectResponse {
pub attachment: Option<(u32, NodeId)>,
}
enum StorageControllerPublicKey {
RawPublicKey(Pem),
PublicKeyCertPath(Utf8PathBuf),
}
enum StorageControllerPrivateKey {
EdPrivateKey(Pem),
HadronPrivateKey(Utf8PathBuf, Vec<u8>),
}
impl StorageControllerPrivateKey {
pub fn encode_token(&self, claims: &Claims) -> anyhow::Result<String> {
match self {
Self::EdPrivateKey(key_data) => encode_from_key_file(claims, key_data),
Self::HadronPrivateKey(_, key_data) => encode_hadron_token(claims, key_data),
}
}
}
impl StorageController {
pub fn from_env(env: &LocalEnv) -> Self {
// Assume all pageservers have symmetric auth configuration: this service
@@ -152,7 +171,30 @@ impl StorageController {
)
.expect("Failed to parse PEM file")
};
(Some(private_key), Some(public_key))
(
Some(StorageControllerPrivateKey::EdPrivateKey(private_key)),
Some(StorageControllerPublicKey::RawPublicKey(public_key)),
)
}
AuthType::HadronJWT => {
let private_key_path = env.get_private_key_path();
let private_key =
fs::read(private_key_path.clone()).expect("failed to read private key");
// If pageserver auth is enabled, this implicitly enables auth for this service,
// using the same credentials.
let public_key_path =
camino::Utf8PathBuf::try_from(env.base_data_dir.join("auth_public_key.pem"))
.unwrap();
(
Some(StorageControllerPrivateKey::HadronPrivateKey(
camino::Utf8PathBuf::try_from(private_key_path).unwrap(),
private_key,
)),
Some(StorageControllerPublicKey::PublicKeyCertPath(
public_key_path,
)),
)
}
};
@@ -575,23 +617,38 @@ impl StorageController {
if let Some(private_key) = &self.private_key {
let claims = Claims::new(None, Scope::PageServerApi);
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
if let StorageControllerPrivateKey::HadronPrivateKey(key_path, _) = private_key {
args.push(format!("--private-key-path={key_path}"));
}
// We are setting all JWT tokens for Hadron as well in this test to avoid bifurcation between Neon and
// Hadron test cases. In production we do not need to set this as HTTP auth is not enabled on the
// pageserver. We use network segmentation to ensure that only trusted components can talk to
// pageserver's http port
let jwt_token = private_key.encode_token(&claims)?;
args.push(format!("--jwt-token={jwt_token}"));
let peer_claims = Claims::new(None, Scope::Admin);
let peer_jwt_token = encode_from_key_file(&peer_claims, private_key)
let peer_jwt_token = private_key
.encode_token(&peer_claims)
.expect("failed to generate jwt token");
args.push(format!("--peer-jwt-token={peer_jwt_token}"));
let claims = Claims::new(None, Scope::SafekeeperData);
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
let jwt_token = private_key
.encode_token(&claims)
.expect("failed to generate jwt token");
args.push(format!("--safekeeper-jwt-token={jwt_token}"));
}
if let Some(public_key) = &self.public_key {
args.push(format!("--public-key=\"{public_key}\""));
match public_key {
StorageControllerPublicKey::RawPublicKey(public_key) => {
args.push(format!("--public-key=\"{public_key}\""));
}
StorageControllerPublicKey::PublicKeyCertPath(public_key_path) => {
args.push(format!("--public-key-cert-path={public_key_path}"));
}
}
}
if let Some(control_plane_hooks_api) = &self.env.control_plane_hooks_api {
@@ -632,7 +689,13 @@ impl StorageController {
self.env.base_data_dir.display()
));
if self.env.safekeepers.iter().any(|sk| sk.auth_enabled) && self.private_key.is_none() {
if self
.env
.safekeepers
.iter()
.any(|sk| sk.auth_type != AuthType::Trust)
&& self.private_key.is_none()
{
anyhow::bail!("Safekeeper set up for auth but no private key specified");
}
@@ -847,7 +910,7 @@ impl StorageController {
println!("Getting claims for path {path}");
if let Some(required_claims) = Self::get_claims_for_path(&path)? {
println!("Got claims {required_claims:?} for path {path}");
let jwt_token = encode_from_key_file(&required_claims, private_key)?;
let jwt_token = private_key.encode_token(&required_claims)?;
builder = builder.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {jwt_token}"),

View File

@@ -46,6 +46,7 @@ allow = [
"ISC",
"MIT",
"MPL-2.0",
"OpenSSL",
"Unicode-3.0",
]
confidence-threshold = 0.8

View File

@@ -20,6 +20,7 @@ tokio.workspace = true
tracing.workspace = true
utils = { path = "../libs/utils", default-features = false }
workspace_hack.workspace = true
postgres_backend.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true
http-body-util.workspace = true

View File

@@ -206,12 +206,16 @@ mod tests {
use axum::{body::Body, extract::Request, response::Response};
use http_body_util::BodyExt;
use itertools::iproduct;
use jsonwebtoken::DecodingKey;
use std::env::var;
use std::sync::Arc;
use std::time::Duration;
use test_log::test as testlog;
use tower::{Service, util::ServiceExt};
use utils::id::{TenantId, TimelineId};
use utils::{
auth::JwtAuth,
id::{TenantId, TimelineId},
};
// see libs/remote_storage/tests/test_real_s3.rs
const REAL_S3_ENV: &str = "ENABLE_REAL_S3_REMOTE_STORAGE";
@@ -251,7 +255,9 @@ mod tests {
};
let proxy = Storage {
auth: endpoint_storage::JwtAuth::new(TEST_PUB_KEY_ED25519).unwrap(),
auth: JwtAuth::new(vec![
DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap(),
]),
storage,
cancel: cancel.clone(),
max_upload_file_limit: usize::MAX,
@@ -352,7 +358,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}
@@ -501,7 +507,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}

View File

@@ -7,7 +7,6 @@ use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
use axum_extra::TypedHeader;
use axum_extra::headers::{Authorization, authorization::Bearer};
use camino::Utf8PathBuf;
use jsonwebtoken::{DecodingKey, Validation};
use remote_storage::{GenericRemoteStorage, RemotePath};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
@@ -15,28 +14,9 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use utils::auth::JwtAuth;
use utils::id::{EndpointId, TenantId, TimelineId};
// simplified version of utils::auth::JwtAuth
pub struct JwtAuth {
decoding_key: DecodingKey,
validation: Validation,
}
pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
impl JwtAuth {
pub fn new(key: &[u8]) -> Result<Self> {
Ok(Self {
decoding_key: DecodingKey::from_ed_pem(key)?,
validation: Validation::new(VALIDATION_ALGO),
})
}
pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
}
}
fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
let key = clean_utf8(&Utf8PathBuf::from(key));
if key.starts_with("..") || key == "." || key == "/" {
@@ -157,7 +137,8 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
let claims: EndpointStorageClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?;
.map_err(|e| bad_request(e, "decoding token"))?
.claims;
// Read paths may have different endpoint ids. For readonly -> readwrite replica
// prewarming, endpoint must read other endpoint's data.
@@ -224,7 +205,8 @@ impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
let claims: DeletePrefixClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "invalid token"))?;
.map_err(|e| bad_request(e, "invalid token"))?
.claims;
let route = DeletePrefixClaims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,

View File

@@ -5,8 +5,10 @@
mod app;
use anyhow::Context;
use clap::Parser;
use postgres_backend::AuthType;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::auth::JwtAuth;
use utils::logging;
//see set()
@@ -18,6 +20,10 @@ const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
const fn default_auth_type() -> AuthType {
AuthType::NeonJWT
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
@@ -39,6 +45,8 @@ struct Config {
storage_kind: remote_storage::TypedRemoteStorageKind,
#[serde(default = "max_upload_file_limit")]
max_upload_file_limit: usize,
#[serde(default = "default_auth_type")]
auth_type: AuthType,
}
#[tokio::main]
@@ -61,10 +69,15 @@ async fn main() -> anyhow::Result<()> {
anyhow::bail!("Supply either config file path or --config=inline-config");
};
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
let auth = endpoint_storage::JwtAuth::new(&pemfile)?;
if config.auth_type == AuthType::Trust {
anyhow::bail!("Trust based auth is not supported");
}
let auth = match config.auth_type {
AuthType::NeonJWT => JwtAuth::from_key_path(&config.pemfile)?,
AuthType::HadronJWT => JwtAuth::from_cert_path(&config.pemfile)?,
AuthType::Trust => unreachable!(),
};
let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap();
info!("listening on {}", listener.local_addr().unwrap());

View File

@@ -705,8 +705,10 @@ pub fn check_permission_with(
check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
) -> Result<(), ApiError> {
match req.context::<Claims>() {
Some(claims) => Ok(check_permission(&claims)
.map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
Some(claims) => Ok(check_permission(&claims).map_err(|err| {
tracing::info!("Authorization error: {err}");
ApiError::Forbidden("JWT authentication error".to_string())
})?),
None => Ok(()), // claims is None because auth is disabled
}
}

View File

@@ -194,6 +194,10 @@ pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
// Similar to above but uses Hadron JWT. Hadron JWTs are slightly different in that:
// 1. Decoding keys are loaded from PEM-encoded X509 certificates instead of plain key files.
// 2. Signature algorithm is RSA-based (may change in the future).
HadronJWT,
}
impl FromStr for AuthType {
@@ -203,6 +207,7 @@ impl FromStr for AuthType {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
"HadronJWT" => Ok(Self::HadronJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
@@ -213,6 +218,7 @@ impl fmt::Display for AuthType {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
AuthType::HadronJWT => "HadronJWT",
})
}
}
@@ -613,7 +619,10 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
assert!(matches!(
self.auth_type,
AuthType::NeonJWT | AuthType::HadronJWT
));
let (_, jwt_response) = m.split_last().context("protocol violation")?;
@@ -712,7 +721,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
AuthType::NeonJWT | AuthType::HadronJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;

View File

@@ -19,6 +19,7 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
camino.workspace = true
camino-tempfile.workspace = true
chrono.workspace = true
diatomic-waker.workspace = true
git-version.workspace = true
@@ -28,6 +29,7 @@ fail.workspace = true
futures = { workspace = true }
jsonwebtoken.workspace = true
nix = { workspace = true, features = ["ioctl"] }
oid-registry.workspace = true
once_cell.workspace = true
pem.workspace = true
pin-project-lite.workspace = true
@@ -48,9 +50,12 @@ tracing-utils.workspace = true
rand.workspace = true
scopeguard.workspace = true
uuid.workspace = true
rustls-pemfile.workspace = true
rustls-pki-types.workspace = true
strum.workspace = true
strum_macros.workspace = true
walkdir.workspace = true
x509-parser.workspace = true
pq_proto.workspace = true
postgres_connection.workspace = true
@@ -67,6 +72,7 @@ camino-tempfile.workspace = true
pprof.workspace = true
serde_assert.workspace = true
tokio = { workspace = true, features = ["test-util"] }
rcgen = { version = "=0.13.1", features = ["crypto", "aws_lc_rs"] }
[[bench]]
name = "benchmarks"

View File

@@ -1,9 +1,9 @@
// For details about authentication see docs/authentication.md
use std::borrow::Cow;
use std::fmt::Display;
use std::fs;
use std::sync::Arc;
use std::{borrow::Cow, io, path::Path};
use anyhow::Result;
use arc_swap::ArcSwap;
@@ -11,14 +11,17 @@ use camino::Utf8Path;
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
};
use oid_registry::OID_PKCS1_RSAENCRYPTION;
use pem::Pem;
use rustls_pki_types::CertificateDer;
use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
use uuid::Uuid;
use crate::id::TenantId;
/// Algorithm to use. We require EdDSA.
/// Signature algorithms to use. We allow EdDSA and RSA/SHA-256.
const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
const HADRON_STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
#[serde(rename_all = "lowercase")]
@@ -95,6 +98,14 @@ impl Claims {
endpoint_id: None,
}
}
pub fn new_for_endpoint(endpoint_id: Uuid) -> Self {
Self {
tenant_id: None,
endpoint_id: Some(endpoint_id),
scope: Scope::TenantEndpoint,
}
}
}
pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
@@ -175,6 +186,96 @@ impl JwtAuth {
Ok(Self::new(decoding_keys))
}
// Helper function to parse a X509 certificate file and extract the RSA public keys from it as `DecodingKey`s.
// - `ceritificate_file_path`: the path to the certificate file. It must be a file, not a directory or anything else.
// Returns the successfully extracted decoding keys. Non-RSA keys and non-X509-parsable certificates are skipped.
// Multuple keys may be returned because a single file can contain multiple certificates.
fn extract_rsa_decoding_keys_from_certificate<P: AsRef<Path>>(
certificate_file_path: P,
) -> Result<Vec<DecodingKey>> {
let certs: io::Result<Vec<CertificateDer<'static>>> = rustls_pemfile::certs(
&mut io::BufReader::new(fs::File::open(certificate_file_path)?),
)
.collect();
Ok(certs?
.iter()
.filter_map(
|cert| match x509_parser::parse_x509_certificate(cert) {
Ok((_, cert)) => {
let public_key = cert.public_key();
// Note that we are just extracting the public key from the certificate, not the signature.
// So the algorithm is just the asymmetric crypto such as RSA, no hashes of or anything like
// that.
if *public_key.algorithm.oid() == OID_PKCS1_RSAENCRYPTION {
Some(DecodingKey::from_rsa_der(&public_key.subject_public_key.data))
} else {
tracing::warn!(
"Unsupported public key algorithm: {:?} found in certificate. Skipping.",
public_key.algorithm
);
None
}
}
Err(e) => {
tracing::warn!("Error parsing certificate: {}. Skipping.", e);
None
}
},
)
.collect())
}
/// Create a `JwtAuth` that can decode tokens using RSA public keys in X509 certificates from the given path.
/// - `cert_path`: the path to a directory or a file containing X509 certificates. If it is a directory, all files
/// under the first level of the directory will be inspected for certificates.
/// Returns the `JwtAuth` with the decoding keys extracted from the certificates, or error.
/// Used by Hadron.
pub fn from_cert_path(cert_path: &Utf8Path) -> Result<Self> {
tracing::info!(
"Loading public keys in certificates from path: {}",
cert_path
);
let mut decoding_keys = Vec::new();
let metadata = cert_path.metadata()?;
if metadata.is_dir() {
for entry in fs::read_dir(cert_path)? {
let path = entry?.path();
if !path.is_file() {
// Ignore directories (don't recurse)
continue;
}
decoding_keys.extend(
Self::extract_rsa_decoding_keys_from_certificate(path).unwrap_or_default(),
);
}
} else if metadata.is_file() {
decoding_keys.extend(
Self::extract_rsa_decoding_keys_from_certificate(cert_path).unwrap_or_default(),
);
} else {
anyhow::bail!("{cert_path} is neither a directory or a file")
}
if decoding_keys.is_empty() {
anyhow::bail!(
"Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected."
);
}
// Note that we need to create a `JwtAuth` with a different `validation` from the default one created by `new()` in this case
// because the `jsonwebtoken` crate requires that all algorithms in `validation.algorithms` belong to the same algorithm family
// (all RSA or all EdDSA).
let mut validation = Validation::default();
validation.algorithms = vec![HADRON_STORAGE_TOKEN_ALGORITHM];
validation.required_spec_claims = [].into();
Ok(Self {
validation,
decoding_keys,
})
}
pub fn from_key(key: String) -> Result<Self> {
Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?]))
}
@@ -217,8 +318,28 @@ pub fn encode_from_key_file<S: Serialize>(claims: &S, pem: &Pem) -> Result<Strin
Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
}
/// Encode (i.e., sign) a Hadron auth token with the given claims and RSA private key. This is used
/// by HCC to sign tokens when deploying compute or returning the compute spec. The resulting token
/// is used by the compute node to authenticate with HCC and PS/SK.
pub fn encode_hadron_token<S: Serialize>(claims: &S, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;
encode_hadron_token_with_encoding_key(claims, &key)
}
pub fn encode_hadron_token_with_encoding_key<S: Serialize>(
claims: &S,
encoding_key: &EncodingKey,
) -> Result<String> {
Ok(encode(
&Header::new(HADRON_STORAGE_TOKEN_ALGORITHM),
claims,
encoding_key,
)?)
}
#[cfg(test)]
mod tests {
use io::Write;
use std::str::FromStr;
use super::*;
@@ -243,8 +364,8 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
fn test_decode() {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
scope: Scope::Tenant,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
@@ -272,8 +393,8 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
fn test_encode() {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
scope: Scope::Tenant,
};
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
@@ -287,4 +408,72 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
assert_eq!(decoded.claims, claims);
}
#[test]
fn test_decode_with_key_from_certificate() {
// Tests that we can sign (encode) a token with a RSA private key and verify (decode) it with the
// corresponding public key extracted from a certificate.
// Generate two RSA key pairs and create self-signed certificates with it.
let key_pair_1 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
let key_pair_2 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
let mut params = rcgen::CertificateParams::default();
params
.distinguished_name
.push(rcgen::DnType::CommonName, "eng-brickstore@databricks.com");
let cert_1 = params.clone().self_signed(&key_pair_1).unwrap();
let cert_2 = params.self_signed(&key_pair_2).unwrap();
// Write the certificates and keys to a temporary dir.
let dir = camino_tempfile::tempdir().unwrap();
{
fs::File::create(dir.path().join("cert_1.pem"))
.unwrap()
.write_all(cert_1.pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("key_1.pem"))
.unwrap()
.write_all(key_pair_1.serialize_pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("cert_2.pem"))
.unwrap()
.write_all(cert_2.pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("key_2.pem"))
.unwrap()
.write_all(key_pair_2.serialize_pem().as_bytes())
.unwrap();
}
// Instantiate a `JwtAuth` with the certificate path. The resulting `JwtAuth` should extract the RSA public
// keys out of the X509 certificates and use them as the decoding keys. Since we specified a directory, both
// X509 certificates will be loaded, but the private key files are skipped.
let auth = JwtAuth::from_cert_path(dir.path()).unwrap();
assert_eq!(auth.decoding_keys.len(), 2);
// Also create a `JwtAuth`, specifying a single certificate file for it to get the decoding key from.
let auth_cert_1 = JwtAuth::from_cert_path(&dir.path().join("cert_1.pem")).unwrap();
assert_eq!(auth_cert_1.decoding_keys.len(), 1);
// Encode tokens with some claims.
let claims = Claims {
tenant_id: Some(TenantId::generate()),
endpoint_id: None,
scope: Scope::Tenant,
};
let encoded_1 =
encode_hadron_token(&claims, key_pair_1.serialize_pem().as_bytes()).unwrap();
let encoded_2 =
encode_hadron_token(&claims, key_pair_2.serialize_pem().as_bytes()).unwrap();
// Verify that we can decode the token with matching decoding keys (decoding also verifies the signature).
assert_eq!(auth.decode::<Claims>(&encoded_1).unwrap().claims, claims);
assert_eq!(auth.decode::<Claims>(&encoded_2).unwrap().claims, claims);
assert_eq!(
auth_cert_1.decode::<Claims>(&encoded_1).unwrap().claims,
claims
);
// Verify that the token cannot be decoded with a mismatched decode key.
assert!(auth_cert_1.decode::<Claims>(&encoded_2).is_err());
}
}

View File

@@ -458,25 +458,37 @@ fn start_pageserver(
let http_auth;
let pg_auth;
let grpc_auth;
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type].contains(&AuthType::NeonJWT) {
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.iter()
.any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT)
{
// unwrap is ok because check is performed when creating config, so path is set and exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!("Loading public key(s) for verifying JWT tokens from {key_path:?}");
let jwt_auth = JwtAuth::from_key_path(key_path)?;
let use_hadron_jwt = conf.http_auth_type == AuthType::HadronJWT
|| conf.pg_auth_type == AuthType::HadronJWT
|| conf.grpc_auth_type == AuthType::HadronJWT;
let jwt_auth = if use_hadron_jwt {
// To validate Hadron JWTs we need to extract decoding keys from X509 certificates.
JwtAuth::from_cert_path(key_path)?
} else {
JwtAuth::from_key_path(key_path)?
};
let auth: Arc<SwappableJwtAuth> = Arc::new(SwappableJwtAuth::new(jwt_auth));
http_auth = match conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth.clone()),
};
pg_auth = match conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth.clone()),
};
grpc_auth = match conf.grpc_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth),
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth),
};
} else {
http_auth = None;

View File

@@ -629,6 +629,13 @@ impl PageServerConf {
}
};
let auth_types = [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type];
if auth_types.contains(&AuthType::NeonJWT) && auth_types.contains(&AuthType::HadronJWT) {
return Err(anyhow::anyhow!(
"Mixing neon and hadron style JWT tokens is not supported"
));
}
Ok(conf)
}

View File

@@ -44,6 +44,7 @@ use pageserver_api::models::{
TopTenantShardItem, TopTenantShardsRequest, TopTenantShardsResponse,
};
use pageserver_api::shard::{ShardCount, TenantShardId};
use postgres_backend::AuthType;
use postgres_ffi::PgMajorVersion;
use remote_storage::{DownloadError, GenericRemoteStorage, TimeTravelError};
use scopeguard::defer;
@@ -55,6 +56,7 @@ use tokio::time::Instant;
use tokio_util::io::StreamReader;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::auth::JwtAuth;
use utils::auth::SwappableJwtAuth;
use utils::generation::Generation;
use utils::id::{TenantId, TimelineId};
@@ -560,6 +562,10 @@ async fn reload_auth_validation_keys_handler(
request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
// Note to Bricksters: This API returns 400 if HTTP auth is not enabled. This is because `state.auth` is only
// determined by HTTP auth.
// TODO(william.huang): In practice both HTTP and PG auth point to the same SwappableJwtAuth object. Refactor
// this code so that we can swap out the underlying shared auth object even if HTTP auth is None.
check_permission(&request, None)?;
let config = get_config(&request);
let state = get_state(&request);
@@ -570,7 +576,12 @@ async fn reload_auth_validation_keys_handler(
let key_path = config.auth_validation_public_key_path.as_ref().unwrap();
info!("Reloading public key(s) for verifying JWT tokens from {key_path:?}");
match utils::auth::JwtAuth::from_key_path(key_path) {
let new_jwt_auth = if config.http_auth_type == AuthType::HadronJWT {
JwtAuth::from_cert_path(key_path)
} else {
JwtAuth::from_key_path(key_path)
};
match new_jwt_auth {
Ok(new_auth) => {
shared_auth.swap(new_auth);
json_response(StatusCode::OK, ())

View File

@@ -15,6 +15,7 @@ use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use http_utils::tls_certs::ReloadingCertificateResolver;
use metrics::set_build_info_metric;
use postgres_backend::AuthType;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT,
@@ -109,10 +110,15 @@ struct Args {
/// Listen https endpoint for management and metrics in the form host:port.
#[arg(long, default_value = None)]
listen_https: Option<String>,
/// Advertised endpoint for receiving/sending WAL in the form host:port. If not
/// Advertised endpoint to PS for receiving/sending WAL in the form host:port. If not
/// specified, listen_pg is used to advertise instead.
#[arg(long, default_value = None)]
advertise_pg: Option<String>,
/// Advertised endpoint to compute for receiving/sending WAL in the form host:port.
/// Required if --hcc-base-url is specified.
// TODO(vlad): pull in hcc-base-url too
#[arg(long, default_value = None)]
advertise_pg_tenant_only: Option<String>,
/// Availability zone of the safekeeper.
#[arg(long)]
availability_zone: Option<String>,
@@ -164,6 +170,12 @@ struct Args {
/// WAL backup horizon.
#[arg(long)]
disable_wal_backup: bool,
/// Token authentication type. Allowed values are "NeonJWT" and "HadronJWT". Any specified value only takes effect if
/// --pg-auth-public-key-path, --pg-tenant-only-auth-public-key-path, or --http-auth-public-key-path is specified.
/// NeonJWT: Decoding keys are loaded from plain public key files in the specified key path.
/// HadronJWT: Decoding keys are loaded from X509 certificates in the specified key path.
#[arg(long, verbatim_doc_comment, default_value = "NeonJWT")]
token_auth_type: AuthType,
/// If given, enables auth on incoming connections to WAL service endpoint
/// (--listen-pg). Value specifies path to a .pem public key used for
/// validations of JWT tokens. Empty string is allowed and means disabling
@@ -361,9 +373,19 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading pg auth JWT key from {path}");
Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
))
match args.token_auth_type {
AuthType::NeonJWT => Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
)),
AuthType::HadronJWT => Some(Arc::new(
JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
)),
_ => panic!(
"AuthType {auth_type} is not allowed when --pg-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
}
}
};
let pg_tenant_only_auth = match args.pg_tenant_only_auth_public_key_path.as_ref() {
@@ -373,9 +395,19 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading pg tenant only auth JWT key from {path}");
Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
))
match args.token_auth_type {
AuthType::NeonJWT => Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
)),
AuthType::HadronJWT => Some(Arc::new(
JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
)),
_ => panic!(
"AuthType {auth_type} is not allowed when --pg-tenant-only-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
}
}
};
let http_auth = match args.http_auth_public_key_path.as_ref() {
@@ -385,7 +417,17 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading http auth JWT key(s) from {path}");
let jwt_auth = JwtAuth::from_key_path(path).context("failed to load the auth key")?;
let jwt_auth = match args.token_auth_type {
AuthType::NeonJWT => {
JwtAuth::from_key_path(path).context("failed to load the auth key")?
}
AuthType::HadronJWT => JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
_ => panic!(
"AuthType {auth_type} is not allowed when --http-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
};
Some(Arc::new(SwappableJwtAuth::new(jwt_auth)))
}
};
@@ -434,6 +476,7 @@ async fn main() -> anyhow::Result<()> {
/* END_HADRON */
wal_backup_enabled: !args.disable_wal_backup,
backup_parallel_jobs: args.wal_backup_parallel_jobs,
auth_type: args.token_auth_type,
pg_auth,
pg_tenant_only_auth,
http_auth,
@@ -457,7 +500,7 @@ async fn main() -> anyhow::Result<()> {
enable_tls_wal_service_api: args.enable_tls_wal_service_api,
force_metric_collection_on_scrape: args.force_metric_collection_on_scrape,
/* BEGIN_HADRON */
advertise_pg_addr_tenant_only: None,
advertise_pg_addr_tenant_only: args.advertise_pg_tenant_only,
enable_pull_timeline_on_startup: args.enable_pull_timeline_on_startup,
hcc_base_url: None,
global_disk_check_interval: args.global_disk_check_interval,

View File

@@ -1,6 +1,7 @@
#![deny(clippy::undocumented_unsafe_blocks)]
extern crate hyper0 as hyper;
use postgres_backend::AuthType;
use std::time::Duration;
@@ -128,6 +129,7 @@ pub struct SafeKeeperConf {
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
pub auth_type: AuthType,
pub pg_auth: Option<Arc<JwtAuth>>,
pub pg_tenant_only_auth: Option<Arc<JwtAuth>>,
pub http_auth: Option<Arc<SwappableJwtAuth>>,
@@ -173,6 +175,7 @@ impl SafeKeeperConf {
peer_recovery_enabled: true,
wal_backup_enabled: true,
backup_parallel_jobs: 1,
auth_type: AuthType::HadronJWT,
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,

View File

@@ -103,7 +103,7 @@ async fn handle_socket(
};
let auth_type = match auth_key {
None => AuthType::Trust,
Some(_) => AuthType::NeonJWT,
Some(_) => conf.auth_type,
};
let auth_pair = auth_key.map(|key| (allowed_auth_scope, key));
let mut conn_handler = SafekeeperPostgresHandler::new(

View File

@@ -14,6 +14,7 @@ use desim::network::TCP;
use desim::node_os::NodeOs;
use desim::proto::{AnyMessage, NetEvent, NodeEvent};
use http::Uri;
use postgres_backend::AuthType;
use safekeeper::SafeKeeperConf;
use safekeeper::safekeeper::{
ProposerAcceptorMessage, SK_PROTO_VERSION_3, SafeKeeper, UNKNOWN_SERVER_VERSION,
@@ -169,6 +170,7 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
availability_zone: None,
peer_recovery_enabled: false,
backup_parallel_jobs: 0,
auth_type: AuthType::NeonJWT,
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,

View File

@@ -31,6 +31,7 @@ humantime.workspace = true
humantime-serde.workspace = true
itertools.workspace = true
json-structural-diff.workspace = true
jsonwebtoken.workspace = true
lasso.workspace = true
once_cell.workspace = true
pageserver_api.workspace = true
@@ -74,4 +75,4 @@ http-utils = { path = "../libs/http-utils/" }
utils = { path = "../libs/utils/" }
metrics = { path = "../libs/metrics/" }
control_plane = { path = "../control_plane" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -9,7 +9,6 @@ pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), Au
Ok(())
}
#[allow(dead_code)]
pub fn check_endpoint_permission(claims: &Claims, endpoint_id: Uuid) -> Result<(), AuthError> {
if claims.scope != Scope::TenantEndpoint {
return Err(AuthError("Scope mismatch. Permission denied".into()));

View File

@@ -0,0 +1,52 @@
use anyhow::{Result, bail};
use camino::Utf8Path;
use jsonwebtoken::EncodingKey;
use std::fs;
use utils::{
auth::{Claims, Scope, encode_hadron_token_with_encoding_key},
id::TenantId,
};
use uuid::Uuid;
pub struct HadronTokenGenerator {
encoding_key: EncodingKey,
}
impl HadronTokenGenerator {
pub fn new(path: &Utf8Path) -> anyhow::Result<Self> {
let key_data = match fs::read(path) {
Ok(ok) => ok,
Err(e) => bail!("Error reading private key file {path:?}. Error: {e}"),
};
let encoding_key = match EncodingKey::from_rsa_pem(&key_data) {
Ok(ok) => ok,
Err(e) => {
bail!("Error reading private key file {path:?} as RSA private key. Error: {e}")
}
};
Ok(Self { encoding_key })
}
pub fn generate_tenant_scope_token(&self, tenant_id: TenantId) -> Result<String> {
let claims = Claims::new(Some(tenant_id), Scope::Tenant);
self.internal_encode_token(&claims)
}
pub fn generate_tenant_endpoint_scope_token(&self, endpoint_id: Uuid) -> Result<String> {
let claims = Claims::new_for_endpoint(endpoint_id);
self.internal_encode_token(&claims)
}
pub fn generate_ps_sk_auth_token(&self) -> Result<String> {
let claims = Claims {
tenant_id: None,
endpoint_id: None,
scope: Scope::SafekeeperData,
};
self.internal_encode_token(&claims)
}
fn internal_encode_token(&self, claims: &Claims) -> Result<String> {
encode_hadron_token_with_encoding_key(claims, &self.encoding_key)
}
}

View File

@@ -40,6 +40,7 @@ use tokio_util::sync::CancellationToken;
use tracing::warn;
use utils::auth::{Scope, SwappableJwtAuth};
use utils::id::{NodeId, TenantId, TimelineId};
use uuid::Uuid;
use crate::http;
use crate::metrics::{
@@ -1801,6 +1802,23 @@ fn check_permissions(request: &Request<Body>, required_scope: Scope) -> Result<(
}
})
}
/// Similar to `check_permissions()` above, but checks for TenantEndpoint scope specifically. Used by the compute spec-fetch API.
/// Access by Admin-scope tokens is also permitted.
/// TODO(william.huang): Merge with the previous function by refactoring `Scope` to make it carry the dependent arguments.
/// E.g., `Scope::TenantEndpoint(EndpointId)`, `Scope::Tenant(TenantId)`, etc.
#[allow(unused)]
fn check_endpoint_permission(request: &Request<Body>, endpoint_id: Uuid) -> Result<(), ApiError> {
check_permission_with(
request,
|claims| match crate::auth::check_endpoint_permission(claims, endpoint_id) {
Err(e) => match crate::auth::check_permission(claims, Scope::Admin) {
Ok(()) => Ok(()),
Err(_) => Err(e),
},
Ok(()) => Ok(()),
},
)
}
#[derive(Clone, Debug)]
struct RequestMeta {

View File

@@ -6,6 +6,7 @@ extern crate hyper0 as hyper;
mod auth;
mod background_node_operations;
mod compute_hook;
pub mod hadron_token;
pub mod hadron_utils;
mod heartbeater;
pub mod http;

View File

@@ -14,6 +14,7 @@ use metrics::BuildInfo;
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::config::PostHogConfig;
use reqwest::Certificate;
use storage_controller::hadron_token::HadronTokenGenerator;
use storage_controller::http::make_router;
use storage_controller::metrics::preinitialize_metrics;
use storage_controller::persistence::Persistence;
@@ -70,10 +71,26 @@ struct Cli {
#[arg(long)]
listen_https: Option<std::net::SocketAddr>,
/// Public key for JWT authentication of clients
/// PEM-encoded public key string for JWT authentication of clients.
#[arg(long)]
public_key: Option<String>,
/// Path to public key certificates used for JWT authentiation of clients.
/// Only one of `public_key` and `public_key_cert_path` should be set.
/// `public_key` or `public_key_cert_path` can point to either a file or a directory.
/// When pointed to a directory, public keys in all files in the first level of
/// the directory (i.e., no subdirectories) will be loaded.
#[arg(long)]
public_key_cert_path: Option<Utf8PathBuf>,
/// Path to the file containing the private key used to generate JWTs for client
/// authentication. The file should contain a single PEM-encoded private key.
/// The HCC uses this key to sign JWTs handed out to other components.
/// Note that unlike the `public_key` and `public_key_cert_path` args above,
/// `private_key_path` must specify a file path, not a directory.
#[arg(long)]
private_key_path: Option<Utf8PathBuf>,
/// Token for authenticating this service with the pageservers it controls
#[arg(long)]
jwt_token: Option<String>,
@@ -256,6 +273,7 @@ struct Secrets {
safekeeper_jwt_token: Option<String>,
control_plane_jwt_token: Option<String>,
peer_jwt_token: Option<String>,
token_generator: Option<HadronTokenGenerator>,
}
const POSTHOG_CONFIG_ENV: &str = "POSTHOG_CONFIG";
@@ -281,7 +299,16 @@ impl Secrets {
let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV) {
Some(v) => Some(JwtAuth::from_key(v).context("Loading public key")?),
None => None,
None => {
if let Some(path) = args.public_key_cert_path.as_ref() {
Some(
JwtAuth::from_cert_path(path)
.context("Loading public key from certificates")?,
)
} else {
None
}
}
};
let this = Self {
@@ -300,6 +327,11 @@ impl Secrets {
Self::CONTROL_PLANE_JWT_TOKEN_ENV,
),
peer_jwt_token: Self::load_secret(&args.peer_jwt_token, Self::PEER_JWT_TOKEN_ENV),
token_generator: args
.private_key_path
.as_ref()
.map(|path| HadronTokenGenerator::new(path))
.transpose()?,
};
Ok(this)
@@ -489,12 +521,12 @@ async fn async_main() -> anyhow::Result<()> {
let persistence = Arc::new(Persistence::new(secrets.database_url).await);
let service = Service::spawn(config, persistence.clone()).await?;
let service = Service::spawn(config, persistence.clone(), secrets.token_generator).await?;
let auth = secrets
let jwt_auth = secrets
.public_key
.map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth)));
let router = make_router(service.clone(), auth, build_info)
let router = make_router(service.clone(), jwt_auth, build_info)
.build()
.map_err(|err| anyhow!(err))?;
let http_service =

View File

@@ -4,6 +4,7 @@ pub(crate) mod safekeeper_reconciler;
mod safekeeper_service;
mod tenant_shard_iterator;
use crate::hadron_token::HadronTokenGenerator;
use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap, HashSet};
@@ -518,6 +519,11 @@ pub struct Service {
inner: Arc<std::sync::RwLock<ServiceState>>,
config: Config,
persistence: Arc<Persistence>,
// HadronTokenGenerator to generate (sign) JWTs during compute deployment and compute-spec generation.
#[allow(unused)]
token_generator: Option<HadronTokenGenerator>,
compute_hook: Arc<ComputeHook>,
result_tx: tokio::sync::mpsc::UnboundedSender<ReconcileResultRequest>,
@@ -1668,7 +1674,11 @@ impl Service {
}
}
pub async fn spawn(config: Config, persistence: Arc<Persistence>) -> anyhow::Result<Arc<Self>> {
pub async fn spawn(
config: Config,
persistence: Arc<Persistence>,
token_generator: Option<HadronTokenGenerator>,
) -> anyhow::Result<Arc<Self>> {
let (result_tx, result_rx) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, abort_rx) = tokio::sync::mpsc::unbounded_channel();
@@ -1925,6 +1935,7 @@ impl Service {
))),
config: config.clone(),
persistence,
token_generator,
compute_hook: Arc::new(ComputeHook::new(config.clone())?),
result_tx,
heartbeater_ps,

View File

@@ -13,10 +13,11 @@ if TYPE_CHECKING:
@dataclass
class AuthKeys:
priv: str
algorithm: str
def generate_token(self, *, scope: TokenScope, **token_data: Any) -> str:
token_data = {key: str(val) for key, val in token_data.items()}
token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm="EdDSA")
token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm=self.algorithm)
# cast(Any, self.priv)
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
@@ -46,3 +47,4 @@ class TokenScope(StrEnum):
TENANT = "tenant"
SCRUBBER = "scrubber"
INFRA = "infra"
TENANT_ENDPOINT = "tenantendpoint"

View File

@@ -28,11 +28,15 @@ import asyncpg
import backoff
import boto3
import httpx
import jwt
import psycopg2
import psycopg2.sql
import pytest
import requests
import toml
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from jwcrypto import jwk
# Type-related stuff
@@ -402,6 +406,15 @@ class PageserverImportConfig:
return ("timeline_import_config", value)
@dataclass
class HadronTokenDecoder:
public_key: str
algorithm: str
def decode_token(self, token: str) -> dict[str, Any]:
return jwt.decode(token, self.public_key, algorithms=[self.algorithm])
class NeonEnvBuilder:
"""
Builder object to create a Neon runtime environment
@@ -472,6 +485,7 @@ class NeonEnvBuilder:
self.safekeepers_id_start = safekeepers_id_start
self.safekeepers_enable_fsync = safekeepers_enable_fsync
self.auth_enabled = auth_enabled
self.use_hadron_auth_tokens = False
self.default_branch_name = default_branch_name
self.env: NeonEnv | None = None
self.keep_remote_storage_contents: bool = True
@@ -1121,6 +1135,11 @@ class NeonEnv:
self.repo_dir.joinpath("rootCA.crt") if self.generate_local_ssl_certs else None
)
# The auth token type used in the test environment. neon_local is instruted to generate key pairs
# according to the auth token type. The keys are always generated but are only used if
# config.auth_enabled == True.
self.auth_token_type: str = "HadronJWT" if config.use_hadron_auth_tokens else "NeonJWT"
neon_local_env_vars = {}
if self.rust_log_override is not None:
neon_local_env_vars["RUST_LOG"] = self.rust_log_override
@@ -1198,6 +1217,7 @@ class NeonEnv:
"listen_addr": f"127.0.0.1:{self.port_distributor.get_port()}",
},
"generate_local_ssl_certs": self.generate_local_ssl_certs,
"auth_token_type": self.auth_token_type,
}
if config.use_https_storage_broker_api:
@@ -1245,9 +1265,9 @@ class NeonEnv:
)
# Create config for pageserver
http_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
grpc_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
http_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
pg_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
grpc_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
for ps_id in range(
self.BASE_PAGESERVER_ID, self.BASE_PAGESERVER_ID + config.num_pageservers
):
@@ -1385,9 +1405,8 @@ class NeonEnv:
"https_port": port.https,
"sync": config.safekeepers_enable_fsync,
"use_https_safekeeper_api": config.use_https_safekeeper_api,
"auth_type": self.auth_token_type if config.auth_enabled else "Trust",
}
if config.auth_enabled:
sk_cfg["auth_enabled"] = True
if self.safekeepers_remote_storage is not None:
sk_cfg["remote_storage"] = (
self.safekeepers_remote_storage.to_toml_inline_table().strip()
@@ -1578,29 +1597,66 @@ class NeonEnv:
@cached_property
def auth_keys(self) -> AuthKeys:
priv = (Path(self.repo_dir) / "auth_private_key.pem").read_text()
return AuthKeys(priv=priv)
algorithm = "EdDSA" if self.auth_token_type == "NeonJWT" else "RS256"
return AuthKeys(priv=priv, algorithm=algorithm)
@cached_property
def hadron_token_decoder(self) -> HadronTokenDecoder:
cert = (Path(self.repo_dir) / "auth_public_key.pem").read_text()
x509_cert = x509.load_pem_x509_certificate(cert.encode(), default_backend())
pem_public_key = (
x509_cert.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode()
)
return HadronTokenDecoder(public_key=pem_public_key, algorithm="RS256")
def regenerate_keys_at(self, privkey_path: Path, pubkey_path: Path):
# compare generate_auth_keys() in local_env.rs
subprocess.run(
["openssl", "genpkey", "-algorithm", "ed25519", "-out", privkey_path],
cwd=self.repo_dir,
check=True,
)
if self.auth_token_type == "NeonJWT":
# compare generate_auth_keys() in local_env.rs
subprocess.run(
["openssl", "genpkey", "-algorithm", "ed25519", "-out", privkey_path],
cwd=self.repo_dir,
check=True,
)
subprocess.run(
[
"openssl",
"pkey",
"-in",
privkey_path,
"-pubout",
"-out",
pubkey_path,
],
cwd=self.repo_dir,
check=True,
)
subprocess.run(
[
"openssl",
"pkey",
"-in",
privkey_path,
"-pubout",
"-out",
pubkey_path,
],
cwd=self.repo_dir,
check=True,
)
elif self.auth_token_type == "HadronJWT":
# compare generate_auth_keys() in local_env.rs
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:4096",
"-sha256",
"-keyout",
privkey_path,
"-out",
pubkey_path,
"-nodes",
"-subj",
"/CN=eng-brickstore@databricks.com",
],
cwd=self.repo_dir,
check=True,
)
del self.auth_keys
def generate_endpoint_id(self) -> str:
@@ -2021,10 +2077,10 @@ class NeonStorageController(MetricsGetter, LogUtils):
return resp
def headers(self, scope: TokenScope | None) -> dict[str, str]:
def headers(self, scope: TokenScope | None, **token_data: Any) -> dict[str, str]:
headers = {}
if self.auth_enabled and scope is not None:
jwt_token = self.env.auth_keys.generate_token(scope=scope)
jwt_token = self.env.auth_keys.generate_token(scope=scope, **token_data)
headers["Authorization"] = f"Bearer {jwt_token}"
return headers

View File

@@ -32,8 +32,11 @@ def assert_client_not_authorized(env: NeonEnv, http_client: PageserverHttpClient
assert_client_authorized(env, http_client)
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
ps = env.pageserver
@@ -72,8 +75,10 @@ def test_pageserver_auth(neon_env_builder: NeonEnvBuilder):
env.pageserver.tenant_create(TenantId.generate(), auth_token=tenant_token)
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
@@ -91,8 +96,10 @@ def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder):
assert cur.fetchone() == (5000050000,)
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -145,8 +152,10 @@ def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder):
assert_client_authorized(env, pageserver_http_client_new)
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -183,7 +192,12 @@ def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("auth_enabled", [False, True])
def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_auth_failures(
neon_env_builder: NeonEnvBuilder, auth_enabled: bool, use_hadron_auth_tokens: bool
):
neon_env_builder.auth_enabled = auth_enabled
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
neon_env_builder.auth_enabled = auth_enabled
env = neon_env_builder.init_start()

View File

@@ -1406,6 +1406,9 @@ def test_storage_controller_s3_time_travel_recovery(
def test_storage_controller_auth(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
env = neon_env_builder.init_start()
assert env.auth_token_type == "NeonJWT"
svc = env.storage_controller
api = env.storage_controller_api

View File

@@ -78,6 +78,7 @@ parquet = { version = "53", default-features = false, features = ["zstd"] }
portable-atomic = { version = "1", features = ["require-cas"] }
prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] }
rand = { version = "0.9" }
rcgen = { version = "0.13", features = ["aws_lc_rs"] }
regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
regex-syntax = { version = "0.8" }
@@ -126,6 +127,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
clap = { version = "4", features = ["derive", "env", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
displaydoc = { version = "0.2" }
either = { version = "1" }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
half = { version = "2", default-features = false, features = ["num-traits"] }