diff --git a/Cargo.lock b/Cargo.lock index f552e0a1bb..6d8c276bc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2809,6 +2809,7 @@ name = "http-utils" version = "0.1.0" dependencies = [ "anyhow", + "arc-swap", "bytes", "camino", "fail", @@ -2821,6 +2822,7 @@ dependencies = [ "pprof", "regex", "routerify", + "rustls 0.23.18", "rustls-pemfile 2.1.1", "serde", "serde_json", diff --git a/libs/http-utils/Cargo.toml b/libs/http-utils/Cargo.toml index 331ae4a9b8..6d24ee352a 100644 --- a/libs/http-utils/Cargo.toml +++ b/libs/http-utils/Cargo.toml @@ -6,6 +6,7 @@ license.workspace = true [dependencies] anyhow.workspace = true +arc-swap.workspace = true bytes.workspace = true camino.workspace = true fail.workspace = true @@ -18,14 +19,15 @@ pprof.workspace = true regex.workspace = true routerify.workspace = true rustls-pemfile.workspace = true -serde.workspace = true +rustls.workspace = true serde_json.workspace = true serde_path_to_error.workspace = true +serde.workspace = true thiserror.workspace = true -tracing.workspace = true -tokio.workspace = true tokio-rustls.workspace = true tokio-util.workspace = true +tokio.workspace = true +tracing.workspace = true url.workspace = true uuid.workspace = true diff --git a/libs/http-utils/src/tls_certs.rs b/libs/http-utils/src/tls_certs.rs index ad9b989956..0c18d84d98 100644 --- a/libs/http-utils/src/tls_certs.rs +++ b/libs/http-utils/src/tls_certs.rs @@ -1,24 +1,124 @@ +use std::{sync::Arc, time::Duration}; + use anyhow::Context; +use arc_swap::ArcSwap; use camino::Utf8Path; -use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::{ + pki_types::{CertificateDer, PrivateKeyDer}, + server::{ClientHello, ResolvesServerCert}, + sign::CertifiedKey, +}; -pub fn load_cert_chain(filename: &Utf8Path) -> anyhow::Result>> { - let file = std::fs::File::open(filename) - .context(format!("Failed to open certificate file {filename:?}"))?; - let mut reader = std::io::BufReader::new(file); +pub async fn load_cert_chain(filename: &Utf8Path) -> anyhow::Result>> { + let cert_data = tokio::fs::read(filename) + .await + .context(format!("failed reading certificate file {filename:?}"))?; + let mut reader = std::io::Cursor::new(&cert_data); - Ok(rustls_pemfile::certs(&mut reader).collect::, _>>()?) + let cert_chain = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .context(format!("failed parsing certificate from file {filename:?}"))?; + + Ok(cert_chain) } -pub fn load_private_key(filename: &Utf8Path) -> anyhow::Result> { - let file = std::fs::File::open(filename) - .context(format!("Failed to open private key file {filename:?}"))?; - let mut reader = std::io::BufReader::new(file); +pub async fn load_private_key(filename: &Utf8Path) -> anyhow::Result> { + let key_data = tokio::fs::read(filename) + .await + .context(format!("failed reading private key file {filename:?}"))?; + let mut reader = std::io::Cursor::new(&key_data); - let key = rustls_pemfile::private_key(&mut reader)?; + let key = rustls_pemfile::private_key(&mut reader) + .context(format!("failed parsing private key from file {filename:?}"))?; key.ok_or(anyhow::anyhow!( "no private key found in {}", filename.as_str(), )) } + +pub async fn load_certified_key( + key_filename: &Utf8Path, + cert_filename: &Utf8Path, +) -> anyhow::Result { + let cert_chain = load_cert_chain(cert_filename).await?; + let key = load_private_key(key_filename).await?; + + let key = rustls::crypto::ring::default_provider() + .key_provider + .load_private_key(key)?; + + let certified_key = CertifiedKey::new(cert_chain, key); + certified_key.keys_match()?; + Ok(certified_key) +} + +/// Implementation of [`rustls::server::ResolvesServerCert`] which reloads certificates from +/// the disk periodically. +#[derive(Debug)] +pub struct ReloadingCertificateResolver { + certified_key: ArcSwap, +} + +impl ReloadingCertificateResolver { + /// Creates a new Resolver by loading certificate and private key from FS and + /// creating tokio::task to reload them with provided reload_period. + pub async fn new( + key_filename: &Utf8Path, + cert_filename: &Utf8Path, + reload_period: Duration, + ) -> anyhow::Result> { + let this = Arc::new(Self { + certified_key: ArcSwap::from_pointee( + load_certified_key(key_filename, cert_filename).await?, + ), + }); + + tokio::spawn({ + let weak_this = Arc::downgrade(&this); + let key_filename = key_filename.to_owned(); + let cert_filename = cert_filename.to_owned(); + async move { + let start = tokio::time::Instant::now() + reload_period; + let mut interval = tokio::time::interval_at(start, reload_period); + let mut last_reload_failed = false; + loop { + interval.tick().await; + let this = match weak_this.upgrade() { + Some(this) => this, + None => break, // Resolver has been destroyed, exit. + }; + match load_certified_key(&key_filename, &cert_filename).await { + Ok(new_certified_key) => { + if new_certified_key.cert == this.certified_key.load().cert { + tracing::debug!("Certificate has not changed since last reloading"); + } else { + tracing::info!("Certificate has been reloaded"); + this.certified_key.store(Arc::new(new_certified_key)); + } + last_reload_failed = false; + } + Err(err) => { + // Note: Reloading certs may fail if it conflicts with the script updating + // the files at the same time. Warn only if the error is persistent. + if last_reload_failed { + tracing::warn!("Error reloading certificate: {err:#}"); + } else { + tracing::info!("Error reloading certificate: {err:#}"); + } + last_reload_failed = true; + } + } + } + } + }); + + Ok(this) + } +} + +impl ResolvesServerCert for ReloadingCertificateResolver { + fn resolve(&self, _client_hello: ClientHello<'_>) -> Option> { + Some(self.certified_key.load_full()) + } +} diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 2c483c9823..0d39a287c9 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -61,6 +61,8 @@ pub struct ConfigToml { pub listen_https_addr: Option, pub ssl_key_file: Utf8PathBuf, pub ssl_cert_file: Utf8PathBuf, + #[serde(with = "humantime_serde")] + pub ssl_cert_reload_period: Duration, pub ssl_ca_file: Option, pub availability_zone: Option, #[serde(with = "humantime_serde")] @@ -440,6 +442,7 @@ impl Default for ConfigToml { listen_https_addr: (None), ssl_key_file: Utf8PathBuf::from(DEFAULT_SSL_KEY_FILE), ssl_cert_file: Utf8PathBuf::from(DEFAULT_SSL_CERT_FILE), + ssl_cert_reload_period: Duration::from_secs(60), ssl_ca_file: None, availability_zone: (None), wait_lsn_timeout: (humantime::parse_duration(DEFAULT_WAIT_LSN_TIMEOUT) diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 6f3af1125f..4cfc0c24f8 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -12,6 +12,7 @@ use std::time::Duration; use anyhow::{Context, anyhow}; use camino::Utf8Path; use clap::{Arg, ArgAction, Command}; +use http_utils::tls_certs::ReloadingCertificateResolver; use metrics::launch_timestamp::{LaunchTimestamp, set_launch_timestamp_metric}; use metrics::set_build_info_metric; use nix::sys::socket::{setsockopt, sockopt}; @@ -622,12 +623,15 @@ fn start_pageserver( let https_task = match https_listener { Some(https_listener) => { - let certs = http_utils::tls_certs::load_cert_chain(&conf.ssl_cert_file)?; - let key = http_utils::tls_certs::load_private_key(&conf.ssl_key_file)?; + let resolver = MGMT_REQUEST_RUNTIME.block_on(ReloadingCertificateResolver::new( + &conf.ssl_key_file, + &conf.ssl_cert_file, + conf.ssl_cert_reload_period, + ))?; let server_config = rustls::ServerConfig::builder() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_cert_resolver(resolver); let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index d35b1748ca..8f05daf5f5 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -56,8 +56,16 @@ pub struct PageServerConf { /// Example: 127.0.0.1:9899 pub listen_https_addr: Option, + /// Path to a file with certificate's private key for https API. + /// Default: server.key pub ssl_key_file: Utf8PathBuf, + /// Path to a file with a X509 certificate for https API. + /// Default: server.crt pub ssl_cert_file: Utf8PathBuf, + /// Period to reload certificate and private key from files. + /// Default: 60s. + pub ssl_cert_reload_period: Duration, + /// Trusted root CA certificate to use in https APIs. pub ssl_ca_cert: Option, /// Current availability zone. Used for traffic metrics. @@ -326,6 +334,7 @@ impl PageServerConf { listen_https_addr, ssl_key_file, ssl_cert_file, + ssl_cert_reload_period, ssl_ca_file, availability_zone, wait_lsn_timeout, @@ -388,6 +397,7 @@ impl PageServerConf { listen_https_addr, ssl_key_file, ssl_cert_file, + ssl_cert_reload_period, availability_zone, wait_lsn_timeout, wal_redo_timeout, diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 9ca79de179..d9b1b76a4c 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -21,7 +21,7 @@ use safekeeper::defaults::{ DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT, DEFAULT_HEARTBEAT_TIMEOUT, DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES, DEFAULT_PARTIAL_BACKUP_CONCURRENCY, DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR, DEFAULT_SSL_CERT_FILE, - DEFAULT_SSL_KEY_FILE, + DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE, }; use safekeeper::{ BROKER_RUNTIME, GlobalTimelines, HTTP_RUNTIME, SafeKeeperConf, WAL_SERVICE_RUNTIME, broker, @@ -214,6 +214,9 @@ struct Args { /// Path to a file with a X509 certificate for https API. #[arg(long, default_value = DEFAULT_SSL_CERT_FILE)] ssl_cert_file: Utf8PathBuf, + /// Period to reload certificate and private key from files. + #[arg(long, value_parser = humantime::parse_duration, default_value = DEFAULT_SSL_CERT_RELOAD_PERIOD)] + pub ssl_cert_reload_period: Duration, /// Trusted root CA certificate to use in https APIs. #[arg(long)] ssl_ca_file: Option, @@ -394,6 +397,7 @@ async fn main() -> anyhow::Result<()> { max_delta_for_fanout: args.max_delta_for_fanout, ssl_key_file: args.ssl_key_file, ssl_cert_file: args.ssl_cert_file, + ssl_cert_reload_period: args.ssl_cert_reload_period, ssl_ca_cert, }); diff --git a/safekeeper/src/http/mod.rs b/safekeeper/src/http/mod.rs index 4908863a4b..003a75faa6 100644 --- a/safekeeper/src/http/mod.rs +++ b/safekeeper/src/http/mod.rs @@ -1,6 +1,7 @@ pub mod routes; use std::sync::Arc; +use http_utils::tls_certs::ReloadingCertificateResolver; pub use routes::make_router; pub use safekeeper_api::models; use tokio_util::sync::CancellationToken; @@ -29,12 +30,16 @@ pub async fn task_main_https( https_listener: std::net::TcpListener, global_timelines: Arc, ) -> anyhow::Result<()> { - let certs = http_utils::tls_certs::load_cert_chain(&conf.ssl_cert_file)?; - let key = http_utils::tls_certs::load_private_key(&conf.ssl_key_file)?; + let cert_resolver = ReloadingCertificateResolver::new( + &conf.ssl_key_file, + &conf.ssl_cert_file, + conf.ssl_cert_reload_period, + ) + .await?; let server_config = rustls::ServerConfig::builder() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_cert_resolver(cert_resolver); let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); diff --git a/safekeeper/src/lib.rs b/safekeeper/src/lib.rs index 7c81f77e55..21c8806349 100644 --- a/safekeeper/src/lib.rs +++ b/safekeeper/src/lib.rs @@ -73,6 +73,7 @@ pub mod defaults { pub const DEFAULT_SSL_KEY_FILE: &str = "server.key"; pub const DEFAULT_SSL_CERT_FILE: &str = "server.crt"; + pub const DEFAULT_SSL_CERT_RELOAD_PERIOD: &str = "60s"; } #[derive(Debug, Clone)] @@ -118,6 +119,7 @@ pub struct SafeKeeperConf { pub max_delta_for_fanout: Option, pub ssl_key_file: Utf8PathBuf, pub ssl_cert_file: Utf8PathBuf, + pub ssl_cert_reload_period: Duration, pub ssl_ca_cert: Option, } @@ -166,6 +168,7 @@ impl SafeKeeperConf { max_delta_for_fanout: None, ssl_key_file: Utf8PathBuf::from(defaults::DEFAULT_SSL_KEY_FILE), ssl_cert_file: Utf8PathBuf::from(defaults::DEFAULT_SSL_CERT_FILE), + ssl_cert_reload_period: Duration::from_secs(60), ssl_ca_cert: None, } } diff --git a/safekeeper/tests/walproposer_sim/safekeeper.rs b/safekeeper/tests/walproposer_sim/safekeeper.rs index 0dfdafcc51..65dfa64512 100644 --- a/safekeeper/tests/walproposer_sim/safekeeper.rs +++ b/safekeeper/tests/walproposer_sim/safekeeper.rs @@ -182,6 +182,7 @@ pub fn run_server(os: NodeOs, disk: Arc) -> Result<()> { max_delta_for_fanout: None, ssl_key_file: Utf8PathBuf::from(""), ssl_cert_file: Utf8PathBuf::from(""), + ssl_cert_reload_period: Duration::ZERO, ssl_ca_cert: None, }; diff --git a/storage_controller/src/main.rs b/storage_controller/src/main.rs index 78f415b19a..5fcf66b464 100644 --- a/storage_controller/src/main.rs +++ b/storage_controller/src/main.rs @@ -7,6 +7,7 @@ use anyhow::{Context, anyhow}; use camino::Utf8PathBuf; use clap::Parser; use futures::future::OptionFuture; +use http_utils::tls_certs::ReloadingCertificateResolver; use hyper0::Uri; use metrics::BuildInfo; use metrics::launch_timestamp::LaunchTimestamp; @@ -43,6 +44,7 @@ pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:21\0 const DEFAULT_SSL_KEY_FILE: &str = "server.key"; const DEFAULT_SSL_CERT_FILE: &str = "server.crt"; +const DEFAULT_SSL_CERT_RELOAD_PERIOD: &str = "60s"; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -195,6 +197,9 @@ struct Cli { /// Path to a file with a X509 certificate for https API. #[arg(long, default_value = DEFAULT_SSL_CERT_FILE)] ssl_cert_file: Utf8PathBuf, + /// Period to reload certificate and private key from files. + #[arg(long, default_value = DEFAULT_SSL_CERT_RELOAD_PERIOD)] + ssl_cert_reload_period: humantime::Duration, /// Trusted root CA certificate to use in https APIs. #[arg(long)] ssl_ca_file: Option, @@ -460,12 +465,17 @@ async fn async_main() -> anyhow::Result<()> { let https_server_task: OptionFuture<_> = match args.listen_https { Some(https_addr) => { let https_listener = tcp_listener::bind(https_addr)?; - let certs = http_utils::tls_certs::load_cert_chain(args.ssl_cert_file.as_path())?; - let key = http_utils::tls_certs::load_private_key(args.ssl_key_file.as_path())?; + + let resolver = ReloadingCertificateResolver::new( + &args.ssl_key_file, + &args.ssl_cert_file, + *args.ssl_cert_reload_period, + ) + .await?; let server_config = rustls::ServerConfig::builder() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_cert_resolver(resolver); let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); let https_server = diff --git a/test_runner/regress/test_ssl.py b/test_runner/regress/test_ssl.py index 2ca23ce6d5..9a7204ca17 100644 --- a/test_runner/regress/test_ssl.py +++ b/test_runner/regress/test_ssl.py @@ -1,3 +1,6 @@ +import os +import ssl + import pytest import requests from fixtures.neon_fixtures import NeonEnvBuilder, StorageControllerApiException @@ -79,3 +82,72 @@ def test_storage_controller_https_api(neon_env_builder: NeonEnvBuilder): addr = f"https://localhost:{env.storage_controller.port}/status" requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status() + + +def test_certificate_rotation(neon_env_builder: NeonEnvBuilder): + """ + Test that pageserver reloads certificates when they are updated on the disk. + Safekeepers and storage controller use the same server implementation, so + testing only pageserver is fine. + 1. Simple check that HTTPS API works. + 2. Check that the cert returned by the server matches the cert in file. + 3. Replace ps's cert (but not the key). + 4. Check that ps uses the old cert (because the new one doesn't match the key). + 5. Replace ps's key. + 6. Check that ps reloaded the cert and key and returns the new one. + """ + neon_env_builder.use_https_pageserver_api = True + # Speed up the test :) + neon_env_builder.pageserver_config_override = "ssl_cert_reload_period='100 ms'" + env = neon_env_builder.init_start() + + # We intentionally set an incorrect key/cert pair during the test to test this error. + env.pageserver.allowed_errors.append(".*Error reloading certificate.*") + + port = env.pageserver.service_port.https + assert port is not None + + # 1. Check if https works. + addr = f"https://localhost:{port}/v1/status" + requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status() + + ps_cert_path = env.pageserver.workdir / "server.crt" + ps_key_path = env.pageserver.workdir / "server.key" + ps_cert = open(ps_cert_path).read() + # We need another valid certificate to update to. + # Let's steal it from safekeeper. + sk_cert_path = env.safekeepers[0].data_dir / "server.crt" + sk_key_path = env.safekeepers[0].data_dir / "server.key" + sk_cert = open(sk_cert_path).read() + + # 2. Check that server's certificate match the cert in the file. + cur_cert = ssl.get_server_certificate(("localhost", port)) + assert cur_cert == ps_cert + + # 3. Replace ps's cert with sk's one. + os.rename(sk_cert_path, ps_cert_path) + + # Cert shouldn't be reloaded because it doesn't match private key. + def error_reloading_cert(): + assert env.pageserver.log_contains("Error reloading certificate: .* KeyMismatch") + + wait_until(error_reloading_cert) + + # 4. Check that it uses old cert. + requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status() + cur_cert = ssl.get_server_certificate(("localhost", port)) + assert cur_cert == ps_cert + + # 5. Replace ps's private key with sk's one. + os.rename(sk_key_path, ps_key_path) + + # Wait till ps reloads certificate. + def cert_reloaded(): + assert env.pageserver.log_contains("Certificate has been reloaded") + + wait_until(cert_reloaded) + + # 6. Check that server returns new cert. + requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status() + cur_cert = ssl.get_server_certificate(("localhost", port)) + assert cur_cert == sk_cert