diff --git a/Cargo.lock b/Cargo.lock index 1721c185f0..898ff1eabb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1309,6 +1309,7 @@ version = "0.1.0" dependencies = [ "anyhow", "chrono", + "indexmap 2.0.1", "jsonwebtoken", "regex", "remote_storage", @@ -1339,6 +1340,7 @@ dependencies = [ "flate2", "futures", "http 1.1.0", + "indexmap 2.0.1", "jsonwebtoken", "metrics", "nix 0.27.1", @@ -1347,17 +1349,20 @@ dependencies = [ "once_cell", "opentelemetry", "opentelemetry_sdk", + "p256 0.13.2", "postgres", "postgres_initdb", "regex", "remote_storage", "reqwest", + "ring", "rlimit", "rust-ini", "serde", "serde_json", "serde_with", "signal-hook", + "spki 0.7.3", "tar", "thiserror 1.0.69", "tokio", @@ -1377,6 +1382,7 @@ dependencies = [ "vm_monitor", "walkdir", "workspace_hack", + "x509-cert", "zstd", ] @@ -1801,6 +1807,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" dependencies = [ "const-oid", + "der_derive", + "flagset", "pem-rfc7468", "zeroize", ] @@ -1819,6 +1827,17 @@ dependencies = [ "rusticata-macros", ] +[[package]] +name = "der_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8034092389675178f570469e6c3b0465d3d30b4505c294a6550db47f3c17ad18" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "deranged" version = "0.3.11" @@ -2282,6 +2301,12 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "flagset" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3ea1ec5f8307826a5b71094dd91fc04d4ae75d5709b20ad351c7fb4815c86ec" + [[package]] name = "flate2" version = "1.0.26" @@ -6425,9 +6450,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -7135,6 +7160,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tls_codec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive", + "zeroize", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "tokio" version = "1.43.0" @@ -8387,12 +8433,15 @@ dependencies = [ "chrono", "clap", "clap_builder", + "const-oid", "crypto-bigint 0.5.5", "der 0.7.8", "deranged", "digest", "displaydoc", + "ecdsa 0.16.9", "either", + "elliptic-curve 0.13.8", "env_filter", "env_logger", "fail", @@ -8427,6 +8476,7 @@ dependencies = [ "num-rational", "num-traits", "once_cell", + "p256 0.13.2", "parquet", "prettyplease", "proc-macro2", @@ -8439,6 +8489,7 @@ dependencies = [ "reqwest", "rustls 0.23.18", "scopeguard", + "sec1 0.7.3", "serde", "serde_json", "sha2", @@ -8484,6 +8535,18 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "x509-cert" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1301e935010a701ae5f8655edc0ad17c44bad3ac5ce8c39185f75453b720ae94" +dependencies = [ + "const-oid", + "der 0.7.8", + "spki 0.7.3", + "tls_codec", +] + [[package]] name = "x509-certificate" version = "0.23.1" @@ -8612,9 +8675,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" dependencies = [ "serde", "zeroize_derive", diff --git a/Cargo.toml b/Cargo.toml index 7b86a64e9a..82fb463182 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,7 +112,7 @@ hyper0 = { package = "hyper", version = "0.14" } hyper = "1.4" hyper-util = "0.1" tokio-tungstenite = "0.21.0" -indexmap = "2" +indexmap = { version = "2", features = ["serde"] } indoc = "2" ipnet = "2.10.0" itertools = "0.10" diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 6e46185e36..d5483018b4 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -1735,6 +1735,8 @@ RUN set -e \ libevent-dev \ libtool \ pkg-config \ + libcurl4-openssl-dev \ + libssl-dev \ && apt clean && rm -rf /var/lib/apt/lists/* # Use `dist_man_MANS=` to skip manpage generation (which requires python3/pandoc) @@ -1743,7 +1745,7 @@ RUN set -e \ && git clone --recurse-submodules --depth 1 --branch ${PGBOUNCER_TAG} https://github.com/pgbouncer/pgbouncer.git pgbouncer \ && cd pgbouncer \ && ./autogen.sh \ - && ./configure --prefix=/usr/local/pgbouncer --without-openssl \ + && ./configure --prefix=/usr/local/pgbouncer \ && make -j $(nproc) dist_man_MANS= \ && make install dist_man_MANS= diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index dd2896714d..90951e7ddb 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -26,6 +26,7 @@ fail.workspace = true flate2.workspace = true futures.workspace = true http.workspace = true +indexmap.workspace = true jsonwebtoken.workspace = true metrics.workspace = true nix.workspace = true @@ -34,16 +35,19 @@ num_cpus.workspace = true once_cell.workspace = true opentelemetry.workspace = true opentelemetry_sdk.workspace = true +p256 = { version = "0.13", features = ["pem"] } postgres.workspace = true regex.workspace = true +reqwest = { workspace = true, features = ["json"] } +ring = "0.17" serde.workspace = true serde_with.workspace = true serde_json.workspace = true signal-hook.workspace = true +spki = { version = "0.7.3", features = ["std"] } tar.workspace = true tower.workspace = true tower-http.workspace = true -reqwest = { workspace = true, features = ["json"] } tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tokio-postgres.workspace = true tokio-util.workspace = true @@ -57,6 +61,7 @@ thiserror.workspace = true url.workspace = true uuid.workspace = true walkdir.workspace = true +x509-cert = { version = "0.2.5" } postgres_initdb.workspace = true compute_api.workspace = true diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index c2a3e38ed6..a0654ea0e4 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -41,6 +41,7 @@ use crate::rsyslog::configure_audit_rsyslog; use crate::spec::*; use crate::swap::resize_swap; use crate::sync_sk::{check_if_synced, ping_safekeeper}; +use crate::tls::watch_cert_for_changes; use crate::{config, extension_server, local_proxy}; pub static SYNC_SAFEKEEPERS_PID: AtomicU32 = AtomicU32::new(0); @@ -112,6 +113,7 @@ pub struct ComputeNode { // key: ext_archive_name, value: started download time, download_completed? pub ext_download_progress: RwLock, bool)>>, + pub compute_ctl_config: ComputeCtlConfig, } // store some metrics about download size that might impact startup time @@ -135,8 +137,6 @@ pub struct ComputeState { /// passed by the control plane with a /configure HTTP request. pub pspec: Option, - pub compute_ctl_config: ComputeCtlConfig, - /// If the spec is passed by a /configure request, 'startup_span' is the /// /configure request's tracing span. The main thread enters it when it /// processes the compute startup, so that the compute startup is considered @@ -160,7 +160,6 @@ impl ComputeState { last_active: None, error: None, pspec: None, - compute_ctl_config: ComputeCtlConfig::default(), startup_span: None, metrics: ComputeMetrics::default(), } @@ -314,7 +313,6 @@ impl ComputeNode { let pspec = ParsedSpec::try_from(cli_spec).map_err(|msg| anyhow::anyhow!(msg))?; new_state.pspec = Some(pspec); } - new_state.compute_ctl_config = compute_ctl_config; Ok(ComputeNode { params, @@ -323,6 +321,7 @@ impl ComputeNode { state: Mutex::new(new_state), state_changed: Condvar::new(), ext_download_progress: RwLock::new(HashMap::new()), + compute_ctl_config, }) } @@ -345,7 +344,7 @@ impl ComputeNode { // requests while configuration is still in progress. crate::http::server::Server::External { port: this.params.external_http_port, - jwks: this.state.lock().unwrap().compute_ctl_config.jwks.clone(), + config: this.compute_ctl_config.clone(), compute_id: this.params.compute_id.clone(), } .launch(&this); @@ -524,6 +523,16 @@ impl ComputeNode { // Collect all the tasks that must finish here let mut pre_tasks = tokio::task::JoinSet::new(); + // Make sure TLS certificates are properly loaded and in the right place. + if self.compute_ctl_config.tls.is_some() { + let this = self.clone(); + pre_tasks.spawn(async move { + this.watch_cert_for_changes().await; + + Ok::<(), anyhow::Error>(()) + }); + } + // If there are any remote extensions in shared_preload_libraries, start downloading them if pspec.spec.remote_extensions.is_some() { let (this, spec) = (self.clone(), pspec.spec.clone()); @@ -579,11 +588,13 @@ impl ComputeNode { if let Some(pgbouncer_settings) = &pspec.spec.pgbouncer_settings { info!("tuning pgbouncer"); + let pgbouncer_settings = pgbouncer_settings.clone(); + let tls_config = self.compute_ctl_config.tls.clone(); + // Spawn a background task to do the tuning, // so that we don't block the main thread that starts Postgres. - let pgbouncer_settings = pgbouncer_settings.clone(); let _handle = tokio::spawn(async move { - let res = tune_pgbouncer(pgbouncer_settings).await; + let res = tune_pgbouncer(pgbouncer_settings, tls_config).await; if let Err(err) = res { error!("error while tuning pgbouncer: {err:?}"); // Continue with the startup anyway @@ -1105,9 +1116,10 @@ impl ComputeNode { // Remove/create an empty pgdata directory and put configuration there. self.create_pgdata()?; config::write_postgres_conf( - &pgdata_path.join("postgresql.conf"), + pgdata_path, &pspec.spec, self.params.internal_http_port, + &self.compute_ctl_config.tls, )?; // Syncing safekeepers is only safe with primary nodes: if a primary @@ -1489,11 +1501,13 @@ impl ComputeNode { if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings { info!("tuning pgbouncer"); + let pgbouncer_settings = pgbouncer_settings.clone(); + let tls_config = self.compute_ctl_config.tls.clone(); + // Spawn a background task to do the tuning, // so that we don't block the main thread that starts Postgres. - let pgbouncer_settings = pgbouncer_settings.clone(); tokio::spawn(async move { - let res = tune_pgbouncer(pgbouncer_settings).await; + let res = tune_pgbouncer(pgbouncer_settings, tls_config).await; if let Err(err) = res { error!("error while tuning pgbouncer: {err:?}"); } @@ -1505,7 +1519,8 @@ impl ComputeNode { // Spawn a background task to do the configuration, // so that we don't block the main thread that starts Postgres. - let local_proxy = local_proxy.clone(); + let mut local_proxy = local_proxy.clone(); + local_proxy.tls = self.compute_ctl_config.tls.clone(); tokio::spawn(async move { if let Err(err) = local_proxy::configure(&local_proxy) { error!("error while configuring local_proxy: {err:?}"); @@ -1515,8 +1530,12 @@ impl ComputeNode { // Write new config let pgdata_path = Path::new(&self.params.pgdata); - let postgresql_conf_path = pgdata_path.join("postgresql.conf"); - config::write_postgres_conf(&postgresql_conf_path, &spec, self.params.internal_http_port)?; + config::write_postgres_conf( + pgdata_path, + &spec, + self.params.internal_http_port, + &self.compute_ctl_config.tls, + )?; if !spec.skip_pg_catalog_updates { let max_concurrent_connections = spec.reconfigure_concurrency; @@ -1587,6 +1606,56 @@ impl ComputeNode { Ok(()) } + pub async fn watch_cert_for_changes(self: Arc) { + // update status on cert renewal + if let Some(tls_config) = &self.compute_ctl_config.tls { + let tls_config = tls_config.clone(); + + // wait until the cert exists. + let mut cert_watch = watch_cert_for_changes(tls_config.cert_path.clone()).await; + + tokio::task::spawn_blocking(move || { + let handle = tokio::runtime::Handle::current(); + 'cert_update: loop { + // let postgres/pgbouncer/local_proxy know the new cert/key exists. + // we need to wait until it's configurable first. + + let mut state = self.state.lock().unwrap(); + 'status_update: loop { + match state.status { + // let's update the state to config pending + ComputeStatus::ConfigurationPending | ComputeStatus::Running => { + state.set_status( + ComputeStatus::ConfigurationPending, + &self.state_changed, + ); + break 'status_update; + } + + // exit loop + ComputeStatus::Failed + | ComputeStatus::TerminationPending + | ComputeStatus::Terminated => break 'cert_update, + + // wait + ComputeStatus::Init + | ComputeStatus::Configuration + | ComputeStatus::Empty => { + state = self.state_changed.wait(state).unwrap(); + } + } + } + drop(state); + + // wait for a new certificate update + if handle.block_on(cert_watch.changed()).is_err() { + break; + } + } + }); + } + } + /// Update the `last_active` in the shared state, but ensure that it's a more recent one. pub fn update_last_active(&self, last_active: Option>) { let mut state = self.state.lock().unwrap(); diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index 0760568ff8..7aa7360f9d 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -6,11 +6,13 @@ use std::io::Write; use std::io::prelude::*; use std::path::Path; +use compute_api::responses::TlsConfig; use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption}; use crate::pg_helpers::{ GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value, }; +use crate::tls::{self, SERVER_CRT, SERVER_KEY}; /// Check that `line` is inside a text file and put it there if it is not. /// Create file if it doesn't exist. @@ -38,10 +40,12 @@ pub fn line_in_file(path: &Path, line: &str) -> Result { /// Create or completely rewrite configuration file specified by `path` pub fn write_postgres_conf( - path: &Path, + pgdata_path: &Path, spec: &ComputeSpec, extension_server_port: u16, + tls_config: &Option, ) -> Result<()> { + let path = pgdata_path.join("postgresql.conf"); // File::create() destroys the file content if it exists. let mut file = File::create(path)?; @@ -86,6 +90,20 @@ pub fn write_postgres_conf( )?; } + // tls + if let Some(tls_config) = tls_config { + writeln!(file, "ssl = on")?; + + // postgres requires the keyfile to be in a secure file, + // currently too complicated to ensure that at the VM level, + // so we just copy them to another file instead. :shrug: + tls::update_key_path_blocking(pgdata_path, tls_config); + + // these are the default, but good to be explicit. + writeln!(file, "ssl_cert_file = '{}'", SERVER_CRT)?; + writeln!(file, "ssl_key_file = '{}'", SERVER_KEY)?; + } + // Locales if cfg!(target_os = "macos") { writeln!(file, "lc_messages='C'")?; diff --git a/compute_tools/src/http/server.rs b/compute_tools/src/http/server.rs index b70b6c619c..10f767e97c 100644 --- a/compute_tools/src/http/server.rs +++ b/compute_tools/src/http/server.rs @@ -8,8 +8,8 @@ use axum::Router; use axum::middleware::{self}; use axum::response::IntoResponse; use axum::routing::{get, post}; +use compute_api::responses::ComputeCtlConfig; use http::StatusCode; -use jsonwebtoken::jwk::JwkSet; use tokio::net::TcpListener; use tower::ServiceBuilder; use tower_http::{ @@ -41,7 +41,7 @@ pub enum Server { }, External { port: u16, - jwks: JwkSet, + config: ComputeCtlConfig, compute_id: String, }, } @@ -79,7 +79,7 @@ impl From<&Server> for Router> { router } Server::External { - jwks, compute_id, .. + config, compute_id, .. } => { let unauthenticated_router = Router::>::new().route("/metrics", get(metrics::get_metrics)); @@ -95,7 +95,7 @@ impl From<&Server> for Router> { .route("/terminate", post(terminate::terminate)) .layer(AsyncRequireAuthorizationLayer::new(Authorize::new( compute_id.clone(), - jwks.clone(), + config.jwks.clone(), ))); router diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index 5c78bbcd02..a681fad0b0 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -26,3 +26,4 @@ pub mod spec; mod spec_apply; pub mod swap; pub mod sync_sk; +pub mod tls; diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index dd8d8e9b8b..802e3e93d9 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -10,8 +10,10 @@ use std::str::FromStr; use std::time::{Duration, Instant}; use anyhow::{Result, bail}; +use compute_api::responses::TlsConfig; use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role}; use futures::StreamExt; +use indexmap::IndexMap; use ini::Ini; use notify::{RecursiveMode, Watcher}; use postgres::config::Config; @@ -406,7 +408,7 @@ pub fn create_pgdata(pgdata: &str) -> Result<()> { /// Update pgbouncer.ini with provided options fn update_pgbouncer_ini( - pgbouncer_config: HashMap, + pgbouncer_config: IndexMap, pgbouncer_ini_path: &str, ) -> Result<()> { let mut conf = Ini::load_from_file(pgbouncer_ini_path)?; @@ -427,7 +429,10 @@ fn update_pgbouncer_ini( /// Tune pgbouncer. /// 1. Apply new config using pgbouncer admin console /// 2. Add new values to pgbouncer.ini to preserve them after restart -pub async fn tune_pgbouncer(pgbouncer_config: HashMap) -> Result<()> { +pub async fn tune_pgbouncer( + mut pgbouncer_config: IndexMap, + tls_config: Option, +) -> Result<()> { let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() { // for VMs use pgbouncer specific way to connect to // pgbouncer admin console without password @@ -473,19 +478,21 @@ pub async fn tune_pgbouncer(pgbouncer_config: HashMap) -> Result } }; - // Apply new config - for (option_name, value) in pgbouncer_config.iter() { - let query = format!("SET {}={}", option_name, value); - // keep this log line for debugging purposes - info!("Applying pgbouncer setting change: {}", query); + if let Some(tls_config) = tls_config { + // pgbouncer starts in a half-ok state if it cannot find these files. + // It will default to client_tls_sslmode=deny, which causes proxy to error. + // There is a small window at startup where these files don't yet exist in the VM. + // Best to wait until it exists. + loop { + if let Ok(true) = tokio::fs::try_exists(&tls_config.key_path).await { + break; + } + tokio::time::sleep(Duration::from_millis(500)).await + } - if let Err(err) = client.simple_query(&query).await { - // Don't fail on error, just print it into log - error!( - "Failed to apply pgbouncer setting change: {}, {}", - query, err - ); - }; + pgbouncer_config.insert("client_tls_cert_file".to_string(), tls_config.cert_path); + pgbouncer_config.insert("client_tls_key_file".to_string(), tls_config.key_path); + pgbouncer_config.insert("client_tls_sslmode".to_string(), "allow".to_string()); } // save values to pgbouncer.ini @@ -501,6 +508,13 @@ pub async fn tune_pgbouncer(pgbouncer_config: HashMap) -> Result }; update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?; + info!("Applying pgbouncer setting change"); + + if let Err(err) = client.simple_query("RELOAD").await { + // Don't fail on error, just print it into log + error!("Failed to apply pgbouncer setting change, {err}",); + }; + Ok(()) } diff --git a/compute_tools/src/tls.rs b/compute_tools/src/tls.rs new file mode 100644 index 0000000000..5a310d8ac4 --- /dev/null +++ b/compute_tools/src/tls.rs @@ -0,0 +1,118 @@ +use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration}; + +use anyhow::{Context, Result, bail}; +use compute_api::responses::TlsConfig; +use ring::digest; +use spki::ObjectIdentifier; +use spki::der::{Decode, PemReader}; +use x509_cert::Certificate; + +#[derive(Clone, Copy)] +pub struct CertDigest(digest::Digest); + +pub async fn watch_cert_for_changes(cert_path: String) -> tokio::sync::watch::Receiver { + let mut digest = compute_digest(&cert_path).await; + let (tx, rx) = tokio::sync::watch::channel(digest); + tokio::spawn(async move { + while !tx.is_closed() { + let new_digest = compute_digest(&cert_path).await; + if digest.0.as_ref() != new_digest.0.as_ref() { + digest = new_digest; + _ = tx.send(digest); + } + + tokio::time::sleep(Duration::from_secs(60)).await + } + }); + rx +} + +async fn compute_digest(cert_path: &str) -> CertDigest { + loop { + match try_compute_digest(cert_path).await { + Ok(d) => break d, + Err(e) => { + tracing::error!("could not read cert file {e:?}"); + tokio::time::sleep(Duration::from_secs(1)).await + } + } + } +} + +async fn try_compute_digest(cert_path: &str) -> Result { + let data = tokio::fs::read(cert_path).await?; + // sha256 is extremely collision resistent. can safely assume the digest to be unique + Ok(CertDigest(digest::digest(&digest::SHA256, &data))) +} + +pub const SERVER_CRT: &str = "server.crt"; +pub const SERVER_KEY: &str = "server.key"; + +pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) { + loop { + match try_update_key_path_blocking(pg_data, tls_config) { + Ok(()) => break, + Err(e) => { + tracing::error!("could not create key file {e:?}"); + std::thread::sleep(Duration::from_secs(1)) + } + } + } +} + +// Postgres requires the keypath be "secure". This means +// 1. Owned by the postgres user. +// 2. Have permission 600. +fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Result<()> { + let key = std::fs::read_to_string(&tls_config.key_path)?; + let crt = std::fs::read_to_string(&tls_config.cert_path)?; + + // to mitigate a race condition during renewal. + verify_key_cert(&key, &crt)?; + + let mut key_file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(pg_data.join(SERVER_KEY))?; + + let mut crt_file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(pg_data.join(SERVER_CRT))?; + + key_file.write_all(key.as_bytes())?; + crt_file.write_all(crt.as_bytes())?; + + Ok(()) +} + +fn verify_key_cert(key: &str, cert: &str) -> Result<()> { + const ECDSA_WITH_SHA256: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.4.3.2"); + + let cert = Certificate::decode(&mut PemReader::new(cert.as_bytes()).context("pem reader")?) + .context("decode cert")?; + + match cert.signature_algorithm.oid { + ECDSA_WITH_SHA256 => { + let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?; + + let a = key.public_key().to_sec1_bytes(); + let b = cert + .tbs_certificate + .subject_public_key_info + .subject_public_key + .raw_bytes(); + + if *a != *b { + bail!("private key file does not match certificate") + } + } + _ => bail!("unknown TLS key type"), + } + + Ok(()) +} diff --git a/libs/compute_api/Cargo.toml b/libs/compute_api/Cargo.toml index 0d1618c1b2..81b0cd19a1 100644 --- a/libs/compute_api/Cargo.toml +++ b/libs/compute_api/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] anyhow.workspace = true chrono.workspace = true +indexmap.workspace = true jsonwebtoken.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index 3300fbf7dd..c8f6019c5c 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -139,6 +139,7 @@ pub struct ComputeCtlConfig { /// Set of JSON web keys that the compute can use to authenticate /// communication from the control plane. pub jwks: JwkSet, + pub tls: Option, } impl Default for ComputeCtlConfig { @@ -147,10 +148,17 @@ impl Default for ComputeCtlConfig { jwks: JwkSet { keys: Vec::default(), }, + tls: None, } } } +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TlsConfig { + pub key_path: String, + pub cert_path: String, +} + /// Response of the `/computes/{compute_id}/spec` control-plane API. #[derive(Deserialize, Debug)] pub struct ControlPlaneSpecResponse { diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 77f2e1e631..af4264f8d2 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -5,12 +5,15 @@ //! and connect it to the storage nodes. use std::collections::HashMap; +use indexmap::IndexMap; use regex::Regex; use remote_storage::RemotePath; use serde::{Deserialize, Serialize}; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; +use crate::responses::TlsConfig; + /// String type alias representing Postgres identifier and /// intended to be used for DB / role names. pub type PgIdent = String; @@ -125,7 +128,7 @@ pub struct ComputeSpec { // information about available remote extensions pub remote_extensions: Option, - pub pgbouncer_settings: Option>, + pub pgbouncer_settings: Option>, // Stripe size for pageserver sharding, in pages #[serde(default)] @@ -357,6 +360,9 @@ pub struct LocalProxySpec { #[serde(default)] #[serde(skip_serializing_if = "Option::is_none")] pub jwks: Option>, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub tls: Option, } #[derive(Clone, Debug, Deserialize, Serialize)] diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index dedd225cba..ee7f6ffcd7 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{Context, bail, ensure}; +use arc_swap::ArcSwapOption; use camino::{Utf8Path, Utf8PathBuf}; use clap::Parser; use compute_api::spec::LocalProxySpec; @@ -27,6 +28,7 @@ use crate::config::{ }; use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; +use crate::ext::TaskExt; use crate::http::health_server::AppMetrics; use crate::intern::RoleNameInt; use crate::metrics::{Metrics, ThreadPoolMetrics}; @@ -190,7 +192,11 @@ pub async fn run() -> anyhow::Result<()> { // 2. The config file is written but the signal hook is not yet received // 3. local_proxy completes startup but has no config loaded, despite there being a registerd config. refresh_config_notify.notify_one(); - tokio::spawn(refresh_config_loop(args.config_path, refresh_config_notify)); + tokio::spawn(refresh_config_loop( + config, + args.config_path, + refresh_config_notify, + )); maintenance_tasks.spawn(crate::http::health_server::task_main( metrics_listener, @@ -269,7 +275,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig }; Ok(Box::leak(Box::new(ProxyConfig { - tls_config: None, + tls_config: ArcSwapOption::from(None), metric_collection: None, http_config, authentication_config: AuthenticationConfig { @@ -311,14 +317,16 @@ enum RefreshConfigError { Parse(#[from] serde_json::Error), #[error(transparent)] Validate(anyhow::Error), + #[error(transparent)] + Tls(anyhow::Error), } -async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc) { +async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc) { let mut init = true; loop { rx.notified().await; - match refresh_config_inner(&path).await { + match refresh_config_inner(config, &path).await { Ok(()) => {} // don't log for file not found errors if this is the first time we are checking // for computes that don't use local_proxy, this is not an error. @@ -327,6 +335,9 @@ async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc) { { debug!(error=?e, ?path, "could not read config file"); } + Err(RefreshConfigError::Tls(e)) => { + error!(error=?e, ?path, "could not read TLS certificates"); + } Err(e) => { error!(error=?e, ?path, "could not read config file"); } @@ -336,7 +347,10 @@ async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc) { } } -async fn refresh_config_inner(path: &Utf8Path) -> Result<(), RefreshConfigError> { +async fn refresh_config_inner( + config: &ProxyConfig, + path: &Utf8Path, +) -> Result<(), RefreshConfigError> { let bytes = tokio::fs::read(&path).await?; let data: LocalProxySpec = serde_json::from_slice(&bytes)?; @@ -406,5 +420,20 @@ async fn refresh_config_inner(path: &Utf8Path) -> Result<(), RefreshConfigError> info!("successfully loaded new config"); JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set }))); + if let Some(tls_config) = data.tls { + let tls_config = tokio::task::spawn_blocking(move || { + crate::tls::server_config::configure_tls( + &tls_config.key_path, + &tls_config.cert_path, + None, + false, + ) + }) + .await + .propagate_task_panic() + .map_err(RefreshConfigError::Tls)?; + config.tls_config.store(Some(Arc::new(tls_config))); + } + Ok(()) } diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index eec0bf8f99..feca5ccf88 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::time::Duration; use anyhow::bail; +use arc_swap::ArcSwapOption; use futures::future::Either; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; @@ -563,6 +564,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { (None, None) => None, _ => bail!("either both or neither tls-key and tls-cert must be specified"), }; + let tls_config = ArcSwapOption::from(tls_config.map(Arc::new)); let backup_metric_collection_config = config::MetricBackupCollectionConfig { remote_storage_config: args.metric_backup_collection_remote_storage.clone(), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 1bcd22e98f..ad398c122c 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{Context, Ok, bail, ensure}; +use arc_swap::ArcSwapOption; use clap::ValueEnum; use remote_storage::RemoteStorageConfig; @@ -17,7 +18,7 @@ pub use crate::tls::server_config::{TlsConfig, configure_tls}; use crate::types::Host; pub struct ProxyConfig { - pub tls_config: Option, + pub tls_config: ArcSwapOption, pub metric_collection: Option, pub http_config: HttpConfig, pub authentication_config: AuthenticationConfig, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 4662860b3f..1156545f34 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -177,7 +177,8 @@ pub(crate) async fn handle_client( let proto = ctx.protocol(); let request_gauge = metrics.connection_requests.guard(proto); - let tls = config.tls_config.as_ref(); + let tls = config.tls_config.load(); + let tls = tls.as_deref(); let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 955f754497..2582e4c069 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -114,7 +114,7 @@ pub(crate) async fn handshake( let mut read_buf = read_buf.reader(); let mut res = Ok(()); - let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config()) + let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone()) .accept_with(raw, |session| { // push the early data to the tls session while !read_buf.get_ref().is_empty() { diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0c6d352600..2e7d332a8b 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -278,7 +278,8 @@ pub(crate) async fn handle_client( let proto = ctx.protocol(); let request_gauge = metrics.connection_requests.guard(proto); - let tls = config.tls_config.as_ref(); + let tls = config.tls_config.load(); + let tls = tls.as_deref(); let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index e0b7539538..2c3e70138d 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -96,16 +96,18 @@ fn generate_tls_config<'a>( .with_safe_default_protocol_versions() .context("ring should support the default protocol versions")? .with_no_client_auth() - .with_single_cert(vec![cert.clone()], key.clone_key())? - .into(); + .with_single_cert(vec![cert.clone()], key.clone_key())?; let mut cert_resolver = CertResolver::new(); cert_resolver.add_cert(key, vec![cert], true)?; let common_names = cert_resolver.get_common_names(); + let config = Arc::new(config); + TlsConfig { - config, + http_config: config.clone(), + pg_config: config, common_names, cert_resolver: Arc::new(cert_resolver), } diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index a7f46cbe58..00164d631a 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -19,6 +19,7 @@ use std::pin::{Pin, pin}; use std::sync::Arc; use anyhow::Context; +use arc_swap::ArcSwapOption; use async_trait::async_trait; use atomic_take::AtomicTake; use bytes::Bytes; @@ -117,18 +118,7 @@ pub async fn task_main( auth_backend, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); - let tls_acceptor: Arc = match config.tls_config.as_ref() { - Some(config) => { - let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config()); - // prefer http2, but support http/1.1 - tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - Arc::new(tls_server_config) - } - None => { - warn!("TLS config is missing"); - Arc::new(NoTls) - } - }; + let tls_acceptor: Arc = Arc::new(&config.tls_config); let connections = tokio_util::task::task_tracker::TaskTracker::new(); connections.close(); // allows `connections.wait to complete` @@ -216,22 +206,20 @@ pub(crate) type AsyncRW = Pin>; #[async_trait] trait MaybeTlsAcceptor: Send + Sync + 'static { - async fn accept(self: Arc, conn: ChainRW) -> std::io::Result; + async fn accept(&self, conn: ChainRW) -> std::io::Result; } #[async_trait] -impl MaybeTlsAcceptor for rustls::ServerConfig { - async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { - Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?)) - } -} - -struct NoTls; - -#[async_trait] -impl MaybeTlsAcceptor for NoTls { - async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { - Ok(Box::pin(conn)) +impl MaybeTlsAcceptor for &'static ArcSwapOption { + async fn accept(&self, conn: ChainRW) -> std::io::Result { + match &*self.load() { + Some(config) => Ok(Box::pin( + TlsAcceptor::from(config.http_config.clone()) + .accept(conn) + .await?, + )), + None => Ok(Box::pin(conn)), + } } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 47009086c3..a79a478126 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -614,7 +614,9 @@ async fn handle_inner( &config.authentication_config, ctx, request.headers(), - config.tls_config.as_ref(), + // todo: race condition? + // we're unlikely to change the common names. + config.tls_config.load().as_deref(), )?; info!( user = conn_info.conn_info.user_info.user.as_str(), diff --git a/proxy/src/tls/server_config.rs b/proxy/src/tls/server_config.rs index 903c0b712b..4cbd0474c2 100644 --- a/proxy/src/tls/server_config.rs +++ b/proxy/src/tls/server_config.rs @@ -9,17 +9,14 @@ use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint}; pub struct TlsConfig { - pub config: Arc, + // unfortunate split since we cannot change the ALPN on demand. + // + pub http_config: Arc, + pub pg_config: Arc, pub common_names: HashSet, pub cert_resolver: Arc, } -impl TlsConfig { - pub fn to_server_config(&self) -> Arc { - self.config.clone() - } -} - /// Configure TLS for the main endpoint. pub fn configure_tls( key_path: &str, @@ -71,8 +68,15 @@ pub fn configure_tls( config.key_log = Arc::new(rustls::KeyLogFile::new()); } + let mut http_config = config.clone(); + let mut pg_config = config; + + http_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + pg_config.alpn_protocols = vec![b"postgresql".to_vec()]; + Ok(TlsConfig { - config: Arc::new(config), + http_config: Arc::new(http_config), + pg_config: Arc::new(pg_config), common_names, cert_resolver, }) diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index f1696c5ff9..6a726f0585 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -26,11 +26,14 @@ camino = { version = "1", default-features = false, features = ["serde1"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } clap = { version = "4", features = ["derive", "env", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] } +const-oid = { version = "0.9", default-features = false, features = ["db", "std"] } crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] } -der = { version = "0.7", default-features = false, features = ["oid", "pem", "std"] } +der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] } deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] } digest = { version = "0.10", features = ["mac", "oid", "std"] } +ecdsa = { version = "0.16", features = ["pem", "signing", "std", "verifying"] } either = { version = "1" } +elliptic-curve = { version = "0.13", default-features = false, features = ["digest", "hazmat", "jwk", "pem", "std"] } env_filter = { version = "0.1", default-features = false, features = ["regex"] } env_logger = { version = "0.11" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } @@ -65,6 +68,7 @@ num-iter = { version = "0.1", default-features = false, features = ["i128", "std num-rational = { version = "0.4", default-features = false, features = ["num-bigint-std", "std"] } num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } +p256 = { version = "0.13", features = ["jwk"] } parquet = { version = "53", default-features = false, features = ["zstd"] } prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } @@ -74,6 +78,7 @@ regex-syntax = { version = "0.8" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "rustls-tls-native-roots", "stream"] } rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std", "tls12"] } scopeguard = { version = "1" } +sec1 = { version = "0.7", features = ["pem", "serde", "std", "subtle"] } serde = { version = "1", features = ["alloc", "derive"] } serde_json = { version = "1", features = ["alloc", "raw_value"] } sha2 = { version = "0.10", features = ["asm", "oid"] }