add basic tls test

This commit is contained in:
Conrad Ludgate
2025-06-18 15:39:58 +01:00
parent 52be0146d3
commit 86fe3150f0
6 changed files with 111 additions and 15 deletions

View File

@@ -3,7 +3,6 @@ use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
use anyhow::{Context, Result, bail};
use compute_api::responses::TlsConfig;
use ring::digest;
use x509_cert::Certificate;
#[derive(Clone, Copy)]
pub struct CertDigest(digest::Digest);
@@ -105,7 +104,10 @@ pub fn update_key_path_blocking(pg_data: &Path, key_pair: &KeyPair) -> Result<()
}
fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
use x509_cert::Certificate;
use x509_cert::der::oid::db::rfc5912::ECDSA_WITH_SHA_256;
use x509_cert::der::oid::db::rfc8410::ID_ED_25519;
use x509_cert::der::pem;
let certs = Certificate::load_pem_chain(cert.as_bytes())
.context("decoding PEM encoded certificates")?;
@@ -116,22 +118,30 @@ fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
bail!("no certificates found");
};
let pubkey = cert
.tbs_certificate
.subject_public_key_info
.subject_public_key
.raw_bytes();
match cert.signature_algorithm.oid {
ECDSA_WITH_SHA_256 => {
let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?;
let a = key.public_key().to_sec1_bytes();
let b = cert
.tbs_certificate
.subject_public_key_info
.subject_public_key
.raw_bytes();
if *a != *b {
if *key.public_key().to_sec1_bytes() != *pubkey {
bail!("private key file does not match certificate")
}
}
_ => bail!("unknown TLS key type"),
ID_ED_25519 => {
use ring::signature::{Ed25519KeyPair, KeyPair};
let (_, bytes) = pem::decode_vec(key.as_bytes())
.map_err(|_| anyhow::anyhow!("invalid key encoding"))?;
let key = Ed25519KeyPair::from_pkcs8_maybe_unchecked(&bytes).context("parse key")?;
if *key.public_key().as_ref() != *pubkey {
bail!("private key file does not match certificate")
}
}
oid => bail!("unknown TLS key type: {oid}"),
}
Ok(())

View File

@@ -1090,6 +1090,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
storage_controller: None,
control_plane_hooks_api: None,
generate_local_ssl_certs: false,
generate_compute_ssl_certs: false,
}
};

View File

@@ -54,7 +54,6 @@ use compute_api::requests::{
};
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TerminateResponse,
TlsConfig,
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol,
@@ -213,8 +212,13 @@ impl ComputeControlPlane {
let internal_http_port = internal_http_port.unwrap_or_else(|| external_http_port + 1);
let compute_ctl_config = ComputeCtlConfig {
jwks: Self::create_jwks_from_pem(&self.env.read_public_key()?)?,
tls: None::<TlsConfig>,
tls: self.env.get_tls_config()?,
};
let mut features = vec![];
if compute_ctl_config.tls.is_some() {
features.push(ComputeFeature::TlsExperimental);
}
let ep = Arc::new(Endpoint {
endpoint_id: endpoint_id.to_owned(),
pg_address: SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), pg_port),
@@ -241,7 +245,7 @@ impl ComputeControlPlane {
drop_subscriptions_before_start,
grpc,
reconfigure_concurrency: 1,
features: vec![],
features: features.clone(),
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
@@ -263,7 +267,7 @@ impl ComputeControlPlane {
skip_pg_catalog_updates,
drop_subscriptions_before_start,
reconfigure_concurrency: 1,
features: vec![],
features,
cluster: None,
compute_ctl_config,
privileged_role_name,

View File

@@ -12,6 +12,7 @@ use std::{env, fs};
use anyhow::{Context, bail};
use clap::ValueEnum;
use compute_api::responses::TlsConfig;
use pageserver_api::config::PostHogConfig;
use pem::Pem;
use postgres_backend::AuthType;
@@ -96,6 +97,9 @@ pub struct LocalEnv {
/// Flag to generate SSL certificates for components that need it.
/// Also generates root CA certificate that is used to sign all other certificates.
pub generate_local_ssl_certs: bool,
/// Flag to generate SSL certificates for compute.
pub generate_compute_ssl_certs: bool,
}
/// On-disk state stored in `.neon/config`.
@@ -124,6 +128,10 @@ pub struct OnDiskConfig {
// to load new config file. May be removed after this field is in release branch.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub generate_local_ssl_certs: bool,
// Note: skip serializing because in compat tests old storage controller fails
// to load new config file. May be removed after this field is in release branch.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub generate_compute_ssl_certs: bool,
}
fn fail_if_pageservers_field_specified<'de, D>(_: D) -> Result<Vec<PageServerConf>, D::Error>
@@ -153,6 +161,7 @@ pub struct NeonLocalInitConf {
pub control_plane_api: Option<Url>,
pub control_plane_hooks_api: Option<Url>,
pub generate_local_ssl_certs: bool,
pub generate_compute_ssl_certs: bool,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -545,6 +554,33 @@ impl LocalEnv {
)
}
fn compute_ssl_paths(&self) -> Option<(PathBuf, PathBuf)> {
if self.generate_compute_ssl_certs {
Some((
self.base_data_dir.join("compute_server.crt"),
self.base_data_dir.join("compute_server.key"),
))
} else {
None
}
}
pub fn generate_compute_ssl_cert(&self) -> anyhow::Result<()> {
self.generate_ssl_ca_cert()?;
let (cert_path, key_path) = self.compute_ssl_paths().unwrap();
if !fs::exists(&cert_path)? {
generate_ssl_cert(
&cert_path,
&key_path,
self.ssl_ca_cert_path().unwrap().as_path(),
self.ssl_ca_key_path().unwrap().as_path(),
)?;
}
Ok(())
}
/// Creates HTTP client with local SSL CA certificates.
pub fn create_http_client(&self) -> reqwest::Client {
let ssl_ca_certs = self.ssl_ca_cert_path().map(|ssl_ca_file| {
@@ -674,6 +710,7 @@ impl LocalEnv {
control_plane_compute_hook_api: _,
branch_name_mappings,
generate_local_ssl_certs,
generate_compute_ssl_certs,
endpoint_storage,
} = on_disk_config;
LocalEnv {
@@ -691,6 +728,7 @@ impl LocalEnv {
control_plane_hooks_api,
branch_name_mappings,
generate_local_ssl_certs,
generate_compute_ssl_certs,
endpoint_storage,
}
};
@@ -807,6 +845,7 @@ impl LocalEnv {
control_plane_compute_hook_api: None,
branch_name_mappings: self.branch_name_mappings.clone(),
generate_local_ssl_certs: self.generate_local_ssl_certs,
generate_compute_ssl_certs: self.generate_compute_ssl_certs,
endpoint_storage: self.endpoint_storage.clone(),
},
)
@@ -861,6 +900,21 @@ impl LocalEnv {
Ok(pem)
}
/// Get the TLS config if set.
pub fn get_tls_config(&self) -> anyhow::Result<Option<TlsConfig>> {
match self.compute_ssl_paths() {
Some((cert_path, key_path)) => {
self.generate_compute_ssl_cert()?;
Ok(Some(TlsConfig {
key_path: key_path.to_str().context("utf8")?.to_string(),
cert_path: cert_path.to_str().context("utf8")?.to_string(),
}))
}
None => Ok(None),
}
}
/// Materialize the [`NeonLocalInitConf`] to disk. Called during [`neon_local init`].
pub fn init(conf: NeonLocalInitConf, force: &InitForceMode) -> anyhow::Result<()> {
let base_path = base_path();
@@ -913,6 +967,7 @@ impl LocalEnv {
safekeepers,
control_plane_api,
generate_local_ssl_certs,
generate_compute_ssl_certs,
control_plane_hooks_api,
endpoint_storage,
} = conf;
@@ -966,12 +1021,16 @@ impl LocalEnv {
control_plane_hooks_api,
branch_name_mappings: Default::default(),
generate_local_ssl_certs,
generate_compute_ssl_certs,
endpoint_storage,
};
if generate_local_ssl_certs {
env.generate_ssl_ca_cert()?;
}
if generate_compute_ssl_certs {
env.generate_compute_ssl_cert()?;
}
// create endpoints dir
fs::create_dir_all(env.endpoints_path())?;

View File

@@ -506,6 +506,8 @@ class NeonEnvBuilder:
# Flag to use https listener in storage broker, generate local ssl certs,
# and force pageservers and safekeepers to use https for storage broker api.
self.use_https_storage_broker_api: bool = False
# Flag to enable TLS for computes
self.use_compute_tls: bool = False
self.pageserver_virtual_file_io_engine: str | None = pageserver_virtual_file_io_engine
self.pageserver_get_vectored_concurrent_io: str | None = (
@@ -1112,11 +1114,13 @@ class NeonEnv:
self.initial_tenant = config.initial_tenant
self.initial_timeline = config.initial_timeline
self.generate_compute_ssl_certs = config.use_compute_tls
self.generate_local_ssl_certs = (
config.use_https_pageserver_api
or config.use_https_safekeeper_api
or config.use_https_storage_controller_api
or config.use_https_storage_broker_api
or config.use_compute_tls
)
self.ssl_ca_file = (
self.repo_dir.joinpath("rootCA.crt") if self.generate_local_ssl_certs else None
@@ -1199,6 +1203,7 @@ class NeonEnv:
"listen_addr": f"127.0.0.1:{self.port_distributor.get_port()}",
},
"generate_local_ssl_certs": self.generate_local_ssl_certs,
"generate_compute_ssl_certs": self.generate_compute_ssl_certs,
}
if config.use_https_storage_broker_api:

View File

@@ -234,3 +234,20 @@ def test_storage_broker_https_api(neon_env_builder: NeonEnvBuilder):
workload.init()
workload.write_rows(10)
workload.validate()
def test_compute_tls(
neon_env_builder: NeonEnvBuilder,
):
neon_env_builder.use_compute_tls = True
env = neon_env_builder.init_start()
env.create_branch("test_compute_tls")
with env.endpoints.create_start("test_compute_tls") as endpoint:
res = endpoint.safe_psql(
"select ssl from pg_stat_ssl where pid = pg_backend_pid();",
sslmode="verify-full",
sslrootcert=env.ssl_ca_file,
)
assert res == [(True,)]