diff --git a/compute_tools/src/tls.rs b/compute_tools/src/tls.rs index af4299b7aa..ed335e4bca 100644 --- a/compute_tools/src/tls.rs +++ b/compute_tools/src/tls.rs @@ -3,7 +3,6 @@ use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration}; use anyhow::{Context, Result, bail}; use compute_api::responses::TlsConfig; use ring::digest; -use x509_cert::Certificate; #[derive(Clone, Copy)] pub struct CertDigest(digest::Digest); @@ -105,7 +104,10 @@ pub fn update_key_path_blocking(pg_data: &Path, key_pair: &KeyPair) -> Result<() } fn verify_key_cert(key: &str, cert: &str) -> Result<()> { + use x509_cert::Certificate; use x509_cert::der::oid::db::rfc5912::ECDSA_WITH_SHA_256; + use x509_cert::der::oid::db::rfc8410::ID_ED_25519; + use x509_cert::der::pem; let certs = Certificate::load_pem_chain(cert.as_bytes()) .context("decoding PEM encoded certificates")?; @@ -116,22 +118,30 @@ fn verify_key_cert(key: &str, cert: &str) -> Result<()> { bail!("no certificates found"); }; + let pubkey = cert + .tbs_certificate + .subject_public_key_info + .subject_public_key + .raw_bytes(); + match cert.signature_algorithm.oid { ECDSA_WITH_SHA_256 => { let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?; - - let a = key.public_key().to_sec1_bytes(); - let b = cert - .tbs_certificate - .subject_public_key_info - .subject_public_key - .raw_bytes(); - - if *a != *b { + if *key.public_key().to_sec1_bytes() != *pubkey { bail!("private key file does not match certificate") } } - _ => bail!("unknown TLS key type"), + ID_ED_25519 => { + use ring::signature::{Ed25519KeyPair, KeyPair}; + + let (_, bytes) = pem::decode_vec(key.as_bytes()) + .map_err(|_| anyhow::anyhow!("invalid key encoding"))?; + let key = Ed25519KeyPair::from_pkcs8_maybe_unchecked(&bytes).context("parse key")?; + if *key.public_key().as_ref() != *pubkey { + bail!("private key file does not match certificate") + } + } + oid => bail!("unknown TLS key type: {oid}"), } Ok(()) diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 95500b0b18..d05fb4a7ce 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -1090,6 +1090,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result { storage_controller: None, control_plane_hooks_api: None, generate_local_ssl_certs: false, + generate_compute_ssl_certs: false, } }; diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 3065e202ab..1ae9483d50 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -54,7 +54,6 @@ use compute_api::requests::{ }; use compute_api::responses::{ ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TerminateResponse, - TlsConfig, }; use compute_api::spec::{ Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol, @@ -213,8 +212,13 @@ impl ComputeControlPlane { let internal_http_port = internal_http_port.unwrap_or_else(|| external_http_port + 1); let compute_ctl_config = ComputeCtlConfig { jwks: Self::create_jwks_from_pem(&self.env.read_public_key()?)?, - tls: None::, + tls: self.env.get_tls_config()?, }; + let mut features = vec![]; + if compute_ctl_config.tls.is_some() { + features.push(ComputeFeature::TlsExperimental); + } + let ep = Arc::new(Endpoint { endpoint_id: endpoint_id.to_owned(), pg_address: SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), pg_port), @@ -241,7 +245,7 @@ impl ComputeControlPlane { drop_subscriptions_before_start, grpc, reconfigure_concurrency: 1, - features: vec![], + features: features.clone(), cluster: None, compute_ctl_config: compute_ctl_config.clone(), privileged_role_name: privileged_role_name.clone(), @@ -263,7 +267,7 @@ impl ComputeControlPlane { skip_pg_catalog_updates, drop_subscriptions_before_start, reconfigure_concurrency: 1, - features: vec![], + features, cluster: None, compute_ctl_config, privileged_role_name, diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index d34dd39f61..3d2034d816 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -12,6 +12,7 @@ use std::{env, fs}; use anyhow::{Context, bail}; use clap::ValueEnum; +use compute_api::responses::TlsConfig; use pageserver_api::config::PostHogConfig; use pem::Pem; use postgres_backend::AuthType; @@ -96,6 +97,9 @@ pub struct LocalEnv { /// Flag to generate SSL certificates for components that need it. /// Also generates root CA certificate that is used to sign all other certificates. pub generate_local_ssl_certs: bool, + + /// Flag to generate SSL certificates for compute. + pub generate_compute_ssl_certs: bool, } /// On-disk state stored in `.neon/config`. @@ -124,6 +128,10 @@ pub struct OnDiskConfig { // to load new config file. May be removed after this field is in release branch. #[serde(skip_serializing_if = "std::ops::Not::not")] pub generate_local_ssl_certs: bool, + // Note: skip serializing because in compat tests old storage controller fails + // to load new config file. May be removed after this field is in release branch. + #[serde(skip_serializing_if = "std::ops::Not::not")] + pub generate_compute_ssl_certs: bool, } fn fail_if_pageservers_field_specified<'de, D>(_: D) -> Result, D::Error> @@ -153,6 +161,7 @@ pub struct NeonLocalInitConf { pub control_plane_api: Option, pub control_plane_hooks_api: Option, pub generate_local_ssl_certs: bool, + pub generate_compute_ssl_certs: bool, } #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] @@ -545,6 +554,33 @@ impl LocalEnv { ) } + fn compute_ssl_paths(&self) -> Option<(PathBuf, PathBuf)> { + if self.generate_compute_ssl_certs { + Some(( + self.base_data_dir.join("compute_server.crt"), + self.base_data_dir.join("compute_server.key"), + )) + } else { + None + } + } + + pub fn generate_compute_ssl_cert(&self) -> anyhow::Result<()> { + self.generate_ssl_ca_cert()?; + + let (cert_path, key_path) = self.compute_ssl_paths().unwrap(); + if !fs::exists(&cert_path)? { + generate_ssl_cert( + &cert_path, + &key_path, + self.ssl_ca_cert_path().unwrap().as_path(), + self.ssl_ca_key_path().unwrap().as_path(), + )?; + } + + Ok(()) + } + /// Creates HTTP client with local SSL CA certificates. pub fn create_http_client(&self) -> reqwest::Client { let ssl_ca_certs = self.ssl_ca_cert_path().map(|ssl_ca_file| { @@ -674,6 +710,7 @@ impl LocalEnv { control_plane_compute_hook_api: _, branch_name_mappings, generate_local_ssl_certs, + generate_compute_ssl_certs, endpoint_storage, } = on_disk_config; LocalEnv { @@ -691,6 +728,7 @@ impl LocalEnv { control_plane_hooks_api, branch_name_mappings, generate_local_ssl_certs, + generate_compute_ssl_certs, endpoint_storage, } }; @@ -807,6 +845,7 @@ impl LocalEnv { control_plane_compute_hook_api: None, branch_name_mappings: self.branch_name_mappings.clone(), generate_local_ssl_certs: self.generate_local_ssl_certs, + generate_compute_ssl_certs: self.generate_compute_ssl_certs, endpoint_storage: self.endpoint_storage.clone(), }, ) @@ -861,6 +900,21 @@ impl LocalEnv { Ok(pem) } + /// Get the TLS config if set. + pub fn get_tls_config(&self) -> anyhow::Result> { + match self.compute_ssl_paths() { + Some((cert_path, key_path)) => { + self.generate_compute_ssl_cert()?; + + Ok(Some(TlsConfig { + key_path: key_path.to_str().context("utf8")?.to_string(), + cert_path: cert_path.to_str().context("utf8")?.to_string(), + })) + } + None => Ok(None), + } + } + /// Materialize the [`NeonLocalInitConf`] to disk. Called during [`neon_local init`]. pub fn init(conf: NeonLocalInitConf, force: &InitForceMode) -> anyhow::Result<()> { let base_path = base_path(); @@ -913,6 +967,7 @@ impl LocalEnv { safekeepers, control_plane_api, generate_local_ssl_certs, + generate_compute_ssl_certs, control_plane_hooks_api, endpoint_storage, } = conf; @@ -966,12 +1021,16 @@ impl LocalEnv { control_plane_hooks_api, branch_name_mappings: Default::default(), generate_local_ssl_certs, + generate_compute_ssl_certs, endpoint_storage, }; if generate_local_ssl_certs { env.generate_ssl_ca_cert()?; } + if generate_compute_ssl_certs { + env.generate_compute_ssl_cert()?; + } // create endpoints dir fs::create_dir_all(env.endpoints_path())?; diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 7f59547c73..b2f42c2cf6 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -506,6 +506,8 @@ class NeonEnvBuilder: # Flag to use https listener in storage broker, generate local ssl certs, # and force pageservers and safekeepers to use https for storage broker api. self.use_https_storage_broker_api: bool = False + # Flag to enable TLS for computes + self.use_compute_tls: bool = False self.pageserver_virtual_file_io_engine: str | None = pageserver_virtual_file_io_engine self.pageserver_get_vectored_concurrent_io: str | None = ( @@ -1112,11 +1114,13 @@ class NeonEnv: self.initial_tenant = config.initial_tenant self.initial_timeline = config.initial_timeline + self.generate_compute_ssl_certs = config.use_compute_tls self.generate_local_ssl_certs = ( config.use_https_pageserver_api or config.use_https_safekeeper_api or config.use_https_storage_controller_api or config.use_https_storage_broker_api + or config.use_compute_tls ) self.ssl_ca_file = ( self.repo_dir.joinpath("rootCA.crt") if self.generate_local_ssl_certs else None @@ -1199,6 +1203,7 @@ class NeonEnv: "listen_addr": f"127.0.0.1:{self.port_distributor.get_port()}", }, "generate_local_ssl_certs": self.generate_local_ssl_certs, + "generate_compute_ssl_certs": self.generate_compute_ssl_certs, } if config.use_https_storage_broker_api: diff --git a/test_runner/regress/test_ssl.py b/test_runner/regress/test_ssl.py index 62879834c3..cda35f45b0 100644 --- a/test_runner/regress/test_ssl.py +++ b/test_runner/regress/test_ssl.py @@ -234,3 +234,20 @@ def test_storage_broker_https_api(neon_env_builder: NeonEnvBuilder): workload.init() workload.write_rows(10) workload.validate() + + +def test_compute_tls( + neon_env_builder: NeonEnvBuilder, +): + neon_env_builder.use_compute_tls = True + env = neon_env_builder.init_start() + + env.create_branch("test_compute_tls") + + with env.endpoints.create_start("test_compute_tls") as endpoint: + res = endpoint.safe_psql( + "select ssl from pg_stat_ssl where pid = pg_backend_pid();", + sslmode="verify-full", + sslrootcert=env.ssl_ca_file, + ) + assert res == [(True,)]