feat(compute_ctl): use TLS if configured (#10972)

Closes: https://github.com/neondatabase/cloud/issues/22998

If control-plane reports that TLS should be used, load the certificates
(and watch for updates), make sure postgres use them, and detects
updates.

Procedure:
1. Load certificates
2. Reconfigure postgres/pgbouncer
3. Loop on a timer until certificates have loaded
4. Go to 1

Notes:
1. We only run this procedure if requested on startup by control plane.
2. We needed to compile pgbouncer with openssl enabled
3. Postgres doesn't allow tls keys to be globally accessible - must be
read only to the postgres user. I couldn't convince the autoscaling team
to let me put this logic into the VM settings, so instead compute_ctl
will copy the keys to be read-only by postgres.
4. To mitigate a race condition, we also verify that the key matches the
cert.
This commit is contained in:
Conrad Ludgate
2025-03-13 15:03:22 +00:00
committed by GitHub
parent b2286f5bcb
commit 3dec117572
24 changed files with 427 additions and 87 deletions

71
Cargo.lock generated
View File

@@ -1309,6 +1309,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"indexmap 2.0.1",
"jsonwebtoken",
"regex",
"remote_storage",
@@ -1339,6 +1340,7 @@ dependencies = [
"flate2",
"futures",
"http 1.1.0",
"indexmap 2.0.1",
"jsonwebtoken",
"metrics",
"nix 0.27.1",
@@ -1347,17 +1349,20 @@ dependencies = [
"once_cell",
"opentelemetry",
"opentelemetry_sdk",
"p256 0.13.2",
"postgres",
"postgres_initdb",
"regex",
"remote_storage",
"reqwest",
"ring",
"rlimit",
"rust-ini",
"serde",
"serde_json",
"serde_with",
"signal-hook",
"spki 0.7.3",
"tar",
"thiserror 1.0.69",
"tokio",
@@ -1377,6 +1382,7 @@ dependencies = [
"vm_monitor",
"walkdir",
"workspace_hack",
"x509-cert",
"zstd",
]
@@ -1801,6 +1807,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c"
dependencies = [
"const-oid",
"der_derive",
"flagset",
"pem-rfc7468",
"zeroize",
]
@@ -1819,6 +1827,17 @@ dependencies = [
"rusticata-macros",
]
[[package]]
name = "der_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8034092389675178f570469e6c3b0465d3d30b4505c294a6550db47f3c17ad18"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "deranged"
version = "0.3.11"
@@ -2282,6 +2301,12 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flagset"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3ea1ec5f8307826a5b71094dd91fc04d4ae75d5709b20ad351c7fb4815c86ec"
[[package]]
name = "flate2"
version = "1.0.26"
@@ -6425,9 +6450,9 @@ dependencies = [
[[package]]
name = "sha1"
version = "0.10.5"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
@@ -7135,6 +7160,27 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tls_codec"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b"
dependencies = [
"tls_codec_derive",
"zeroize",
]
[[package]]
name = "tls_codec_derive"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "tokio"
version = "1.43.0"
@@ -8387,12 +8433,15 @@ dependencies = [
"chrono",
"clap",
"clap_builder",
"const-oid",
"crypto-bigint 0.5.5",
"der 0.7.8",
"deranged",
"digest",
"displaydoc",
"ecdsa 0.16.9",
"either",
"elliptic-curve 0.13.8",
"env_filter",
"env_logger",
"fail",
@@ -8427,6 +8476,7 @@ dependencies = [
"num-rational",
"num-traits",
"once_cell",
"p256 0.13.2",
"parquet",
"prettyplease",
"proc-macro2",
@@ -8439,6 +8489,7 @@ dependencies = [
"reqwest",
"rustls 0.23.18",
"scopeguard",
"sec1 0.7.3",
"serde",
"serde_json",
"sha2",
@@ -8484,6 +8535,18 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51"
[[package]]
name = "x509-cert"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1301e935010a701ae5f8655edc0ad17c44bad3ac5ce8c39185f75453b720ae94"
dependencies = [
"const-oid",
"der 0.7.8",
"spki 0.7.3",
"tls_codec",
]
[[package]]
name = "x509-certificate"
version = "0.23.1"
@@ -8612,9 +8675,9 @@ dependencies = [
[[package]]
name = "zeroize"
version = "1.7.0"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
dependencies = [
"serde",
"zeroize_derive",

View File

@@ -112,7 +112,7 @@ hyper0 = { package = "hyper", version = "0.14" }
hyper = "1.4"
hyper-util = "0.1"
tokio-tungstenite = "0.21.0"
indexmap = "2"
indexmap = { version = "2", features = ["serde"] }
indoc = "2"
ipnet = "2.10.0"
itertools = "0.10"

View File

@@ -1735,6 +1735,8 @@ RUN set -e \
libevent-dev \
libtool \
pkg-config \
libcurl4-openssl-dev \
libssl-dev \
&& apt clean && rm -rf /var/lib/apt/lists/*
# Use `dist_man_MANS=` to skip manpage generation (which requires python3/pandoc)
@@ -1743,7 +1745,7 @@ RUN set -e \
&& git clone --recurse-submodules --depth 1 --branch ${PGBOUNCER_TAG} https://github.com/pgbouncer/pgbouncer.git pgbouncer \
&& cd pgbouncer \
&& ./autogen.sh \
&& ./configure --prefix=/usr/local/pgbouncer --without-openssl \
&& ./configure --prefix=/usr/local/pgbouncer \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=

View File

@@ -26,6 +26,7 @@ fail.workspace = true
flate2.workspace = true
futures.workspace = true
http.workspace = true
indexmap.workspace = true
jsonwebtoken.workspace = true
metrics.workspace = true
nix.workspace = true
@@ -34,16 +35,19 @@ num_cpus.workspace = true
once_cell.workspace = true
opentelemetry.workspace = true
opentelemetry_sdk.workspace = true
p256 = { version = "0.13", features = ["pem"] }
postgres.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["json"] }
ring = "0.17"
serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
spki = { version = "0.7.3", features = ["std"] }
tar.workspace = true
tower.workspace = true
tower-http.workspace = true
reqwest = { workspace = true, features = ["json"] }
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tokio-postgres.workspace = true
tokio-util.workspace = true
@@ -57,6 +61,7 @@ thiserror.workspace = true
url.workspace = true
uuid.workspace = true
walkdir.workspace = true
x509-cert = { version = "0.2.5" }
postgres_initdb.workspace = true
compute_api.workspace = true

View File

@@ -41,6 +41,7 @@ use crate::rsyslog::configure_audit_rsyslog;
use crate::spec::*;
use crate::swap::resize_swap;
use crate::sync_sk::{check_if_synced, ping_safekeeper};
use crate::tls::watch_cert_for_changes;
use crate::{config, extension_server, local_proxy};
pub static SYNC_SAFEKEEPERS_PID: AtomicU32 = AtomicU32::new(0);
@@ -112,6 +113,7 @@ pub struct ComputeNode {
// key: ext_archive_name, value: started download time, download_completed?
pub ext_download_progress: RwLock<HashMap<String, (DateTime<Utc>, bool)>>,
pub compute_ctl_config: ComputeCtlConfig,
}
// store some metrics about download size that might impact startup time
@@ -135,8 +137,6 @@ pub struct ComputeState {
/// passed by the control plane with a /configure HTTP request.
pub pspec: Option<ParsedSpec>,
pub compute_ctl_config: ComputeCtlConfig,
/// If the spec is passed by a /configure request, 'startup_span' is the
/// /configure request's tracing span. The main thread enters it when it
/// processes the compute startup, so that the compute startup is considered
@@ -160,7 +160,6 @@ impl ComputeState {
last_active: None,
error: None,
pspec: None,
compute_ctl_config: ComputeCtlConfig::default(),
startup_span: None,
metrics: ComputeMetrics::default(),
}
@@ -314,7 +313,6 @@ impl ComputeNode {
let pspec = ParsedSpec::try_from(cli_spec).map_err(|msg| anyhow::anyhow!(msg))?;
new_state.pspec = Some(pspec);
}
new_state.compute_ctl_config = compute_ctl_config;
Ok(ComputeNode {
params,
@@ -323,6 +321,7 @@ impl ComputeNode {
state: Mutex::new(new_state),
state_changed: Condvar::new(),
ext_download_progress: RwLock::new(HashMap::new()),
compute_ctl_config,
})
}
@@ -345,7 +344,7 @@ impl ComputeNode {
// requests while configuration is still in progress.
crate::http::server::Server::External {
port: this.params.external_http_port,
jwks: this.state.lock().unwrap().compute_ctl_config.jwks.clone(),
config: this.compute_ctl_config.clone(),
compute_id: this.params.compute_id.clone(),
}
.launch(&this);
@@ -524,6 +523,16 @@ impl ComputeNode {
// Collect all the tasks that must finish here
let mut pre_tasks = tokio::task::JoinSet::new();
// Make sure TLS certificates are properly loaded and in the right place.
if self.compute_ctl_config.tls.is_some() {
let this = self.clone();
pre_tasks.spawn(async move {
this.watch_cert_for_changes().await;
Ok::<(), anyhow::Error>(())
});
}
// If there are any remote extensions in shared_preload_libraries, start downloading them
if pspec.spec.remote_extensions.is_some() {
let (this, spec) = (self.clone(), pspec.spec.clone());
@@ -579,11 +588,13 @@ impl ComputeNode {
if let Some(pgbouncer_settings) = &pspec.spec.pgbouncer_settings {
info!("tuning pgbouncer");
let pgbouncer_settings = pgbouncer_settings.clone();
let tls_config = self.compute_ctl_config.tls.clone();
// Spawn a background task to do the tuning,
// so that we don't block the main thread that starts Postgres.
let pgbouncer_settings = pgbouncer_settings.clone();
let _handle = tokio::spawn(async move {
let res = tune_pgbouncer(pgbouncer_settings).await;
let res = tune_pgbouncer(pgbouncer_settings, tls_config).await;
if let Err(err) = res {
error!("error while tuning pgbouncer: {err:?}");
// Continue with the startup anyway
@@ -1105,9 +1116,10 @@ impl ComputeNode {
// Remove/create an empty pgdata directory and put configuration there.
self.create_pgdata()?;
config::write_postgres_conf(
&pgdata_path.join("postgresql.conf"),
pgdata_path,
&pspec.spec,
self.params.internal_http_port,
&self.compute_ctl_config.tls,
)?;
// Syncing safekeepers is only safe with primary nodes: if a primary
@@ -1489,11 +1501,13 @@ impl ComputeNode {
if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings {
info!("tuning pgbouncer");
let pgbouncer_settings = pgbouncer_settings.clone();
let tls_config = self.compute_ctl_config.tls.clone();
// Spawn a background task to do the tuning,
// so that we don't block the main thread that starts Postgres.
let pgbouncer_settings = pgbouncer_settings.clone();
tokio::spawn(async move {
let res = tune_pgbouncer(pgbouncer_settings).await;
let res = tune_pgbouncer(pgbouncer_settings, tls_config).await;
if let Err(err) = res {
error!("error while tuning pgbouncer: {err:?}");
}
@@ -1505,7 +1519,8 @@ impl ComputeNode {
// Spawn a background task to do the configuration,
// so that we don't block the main thread that starts Postgres.
let local_proxy = local_proxy.clone();
let mut local_proxy = local_proxy.clone();
local_proxy.tls = self.compute_ctl_config.tls.clone();
tokio::spawn(async move {
if let Err(err) = local_proxy::configure(&local_proxy) {
error!("error while configuring local_proxy: {err:?}");
@@ -1515,8 +1530,12 @@ impl ComputeNode {
// Write new config
let pgdata_path = Path::new(&self.params.pgdata);
let postgresql_conf_path = pgdata_path.join("postgresql.conf");
config::write_postgres_conf(&postgresql_conf_path, &spec, self.params.internal_http_port)?;
config::write_postgres_conf(
pgdata_path,
&spec,
self.params.internal_http_port,
&self.compute_ctl_config.tls,
)?;
if !spec.skip_pg_catalog_updates {
let max_concurrent_connections = spec.reconfigure_concurrency;
@@ -1587,6 +1606,56 @@ impl ComputeNode {
Ok(())
}
pub async fn watch_cert_for_changes(self: Arc<Self>) {
// update status on cert renewal
if let Some(tls_config) = &self.compute_ctl_config.tls {
let tls_config = tls_config.clone();
// wait until the cert exists.
let mut cert_watch = watch_cert_for_changes(tls_config.cert_path.clone()).await;
tokio::task::spawn_blocking(move || {
let handle = tokio::runtime::Handle::current();
'cert_update: loop {
// let postgres/pgbouncer/local_proxy know the new cert/key exists.
// we need to wait until it's configurable first.
let mut state = self.state.lock().unwrap();
'status_update: loop {
match state.status {
// let's update the state to config pending
ComputeStatus::ConfigurationPending | ComputeStatus::Running => {
state.set_status(
ComputeStatus::ConfigurationPending,
&self.state_changed,
);
break 'status_update;
}
// exit loop
ComputeStatus::Failed
| ComputeStatus::TerminationPending
| ComputeStatus::Terminated => break 'cert_update,
// wait
ComputeStatus::Init
| ComputeStatus::Configuration
| ComputeStatus::Empty => {
state = self.state_changed.wait(state).unwrap();
}
}
}
drop(state);
// wait for a new certificate update
if handle.block_on(cert_watch.changed()).is_err() {
break;
}
}
});
}
}
/// Update the `last_active` in the shared state, but ensure that it's a more recent one.
pub fn update_last_active(&self, last_active: Option<DateTime<Utc>>) {
let mut state = self.state.lock().unwrap();

View File

@@ -6,11 +6,13 @@ use std::io::Write;
use std::io::prelude::*;
use std::path::Path;
use compute_api::responses::TlsConfig;
use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
use crate::pg_helpers::{
GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
};
use crate::tls::{self, SERVER_CRT, SERVER_KEY};
/// Check that `line` is inside a text file and put it there if it is not.
/// Create file if it doesn't exist.
@@ -38,10 +40,12 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
/// Create or completely rewrite configuration file specified by `path`
pub fn write_postgres_conf(
path: &Path,
pgdata_path: &Path,
spec: &ComputeSpec,
extension_server_port: u16,
tls_config: &Option<TlsConfig>,
) -> Result<()> {
let path = pgdata_path.join("postgresql.conf");
// File::create() destroys the file content if it exists.
let mut file = File::create(path)?;
@@ -86,6 +90,20 @@ pub fn write_postgres_conf(
)?;
}
// tls
if let Some(tls_config) = tls_config {
writeln!(file, "ssl = on")?;
// postgres requires the keyfile to be in a secure file,
// currently too complicated to ensure that at the VM level,
// so we just copy them to another file instead. :shrug:
tls::update_key_path_blocking(pgdata_path, tls_config);
// these are the default, but good to be explicit.
writeln!(file, "ssl_cert_file = '{}'", SERVER_CRT)?;
writeln!(file, "ssl_key_file = '{}'", SERVER_KEY)?;
}
// Locales
if cfg!(target_os = "macos") {
writeln!(file, "lc_messages='C'")?;

View File

@@ -8,8 +8,8 @@ use axum::Router;
use axum::middleware::{self};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use compute_api::responses::ComputeCtlConfig;
use http::StatusCode;
use jsonwebtoken::jwk::JwkSet;
use tokio::net::TcpListener;
use tower::ServiceBuilder;
use tower_http::{
@@ -41,7 +41,7 @@ pub enum Server {
},
External {
port: u16,
jwks: JwkSet,
config: ComputeCtlConfig,
compute_id: String,
},
}
@@ -79,7 +79,7 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
router
}
Server::External {
jwks, compute_id, ..
config, compute_id, ..
} => {
let unauthenticated_router =
Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
@@ -95,7 +95,7 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
.route("/terminate", post(terminate::terminate))
.layer(AsyncRequireAuthorizationLayer::new(Authorize::new(
compute_id.clone(),
jwks.clone(),
config.jwks.clone(),
)));
router

View File

@@ -26,3 +26,4 @@ pub mod spec;
mod spec_apply;
pub mod swap;
pub mod sync_sk;
pub mod tls;

View File

@@ -10,8 +10,10 @@ use std::str::FromStr;
use std::time::{Duration, Instant};
use anyhow::{Result, bail};
use compute_api::responses::TlsConfig;
use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
use futures::StreamExt;
use indexmap::IndexMap;
use ini::Ini;
use notify::{RecursiveMode, Watcher};
use postgres::config::Config;
@@ -406,7 +408,7 @@ pub fn create_pgdata(pgdata: &str) -> Result<()> {
/// Update pgbouncer.ini with provided options
fn update_pgbouncer_ini(
pgbouncer_config: HashMap<String, String>,
pgbouncer_config: IndexMap<String, String>,
pgbouncer_ini_path: &str,
) -> Result<()> {
let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
@@ -427,7 +429,10 @@ fn update_pgbouncer_ini(
/// Tune pgbouncer.
/// 1. Apply new config using pgbouncer admin console
/// 2. Add new values to pgbouncer.ini to preserve them after restart
pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result<()> {
pub async fn tune_pgbouncer(
mut pgbouncer_config: IndexMap<String, String>,
tls_config: Option<TlsConfig>,
) -> Result<()> {
let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() {
// for VMs use pgbouncer specific way to connect to
// pgbouncer admin console without password
@@ -473,19 +478,21 @@ pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result
}
};
// Apply new config
for (option_name, value) in pgbouncer_config.iter() {
let query = format!("SET {}={}", option_name, value);
// keep this log line for debugging purposes
info!("Applying pgbouncer setting change: {}", query);
if let Some(tls_config) = tls_config {
// pgbouncer starts in a half-ok state if it cannot find these files.
// It will default to client_tls_sslmode=deny, which causes proxy to error.
// There is a small window at startup where these files don't yet exist in the VM.
// Best to wait until it exists.
loop {
if let Ok(true) = tokio::fs::try_exists(&tls_config.key_path).await {
break;
}
tokio::time::sleep(Duration::from_millis(500)).await
}
if let Err(err) = client.simple_query(&query).await {
// Don't fail on error, just print it into log
error!(
"Failed to apply pgbouncer setting change: {}, {}",
query, err
);
};
pgbouncer_config.insert("client_tls_cert_file".to_string(), tls_config.cert_path);
pgbouncer_config.insert("client_tls_key_file".to_string(), tls_config.key_path);
pgbouncer_config.insert("client_tls_sslmode".to_string(), "allow".to_string());
}
// save values to pgbouncer.ini
@@ -501,6 +508,13 @@ pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result
};
update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
info!("Applying pgbouncer setting change");
if let Err(err) = client.simple_query("RELOAD").await {
// Don't fail on error, just print it into log
error!("Failed to apply pgbouncer setting change, {err}",);
};
Ok(())
}

118
compute_tools/src/tls.rs Normal file
View File

@@ -0,0 +1,118 @@
use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
use anyhow::{Context, Result, bail};
use compute_api::responses::TlsConfig;
use ring::digest;
use spki::ObjectIdentifier;
use spki::der::{Decode, PemReader};
use x509_cert::Certificate;
#[derive(Clone, Copy)]
pub struct CertDigest(digest::Digest);
pub async fn watch_cert_for_changes(cert_path: String) -> tokio::sync::watch::Receiver<CertDigest> {
let mut digest = compute_digest(&cert_path).await;
let (tx, rx) = tokio::sync::watch::channel(digest);
tokio::spawn(async move {
while !tx.is_closed() {
let new_digest = compute_digest(&cert_path).await;
if digest.0.as_ref() != new_digest.0.as_ref() {
digest = new_digest;
_ = tx.send(digest);
}
tokio::time::sleep(Duration::from_secs(60)).await
}
});
rx
}
async fn compute_digest(cert_path: &str) -> CertDigest {
loop {
match try_compute_digest(cert_path).await {
Ok(d) => break d,
Err(e) => {
tracing::error!("could not read cert file {e:?}");
tokio::time::sleep(Duration::from_secs(1)).await
}
}
}
}
async fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
let data = tokio::fs::read(cert_path).await?;
// sha256 is extremely collision resistent. can safely assume the digest to be unique
Ok(CertDigest(digest::digest(&digest::SHA256, &data)))
}
pub const SERVER_CRT: &str = "server.crt";
pub const SERVER_KEY: &str = "server.key";
pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) {
loop {
match try_update_key_path_blocking(pg_data, tls_config) {
Ok(()) => break,
Err(e) => {
tracing::error!("could not create key file {e:?}");
std::thread::sleep(Duration::from_secs(1))
}
}
}
}
// Postgres requires the keypath be "secure". This means
// 1. Owned by the postgres user.
// 2. Have permission 600.
fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Result<()> {
let key = std::fs::read_to_string(&tls_config.key_path)?;
let crt = std::fs::read_to_string(&tls_config.cert_path)?;
// to mitigate a race condition during renewal.
verify_key_cert(&key, &crt)?;
let mut key_file = std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(pg_data.join(SERVER_KEY))?;
let mut crt_file = std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(pg_data.join(SERVER_CRT))?;
key_file.write_all(key.as_bytes())?;
crt_file.write_all(crt.as_bytes())?;
Ok(())
}
fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
const ECDSA_WITH_SHA256: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.4.3.2");
let cert = Certificate::decode(&mut PemReader::new(cert.as_bytes()).context("pem reader")?)
.context("decode cert")?;
match cert.signature_algorithm.oid {
ECDSA_WITH_SHA256 => {
let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?;
let a = key.public_key().to_sec1_bytes();
let b = cert
.tbs_certificate
.subject_public_key_info
.subject_public_key
.raw_bytes();
if *a != *b {
bail!("private key file does not match certificate")
}
}
_ => bail!("unknown TLS key type"),
}
Ok(())
}

View File

@@ -7,6 +7,7 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
chrono.workspace = true
indexmap.workspace = true
jsonwebtoken.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -139,6 +139,7 @@ pub struct ComputeCtlConfig {
/// Set of JSON web keys that the compute can use to authenticate
/// communication from the control plane.
pub jwks: JwkSet,
pub tls: Option<TlsConfig>,
}
impl Default for ComputeCtlConfig {
@@ -147,10 +148,17 @@ impl Default for ComputeCtlConfig {
jwks: JwkSet {
keys: Vec::default(),
},
tls: None,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TlsConfig {
pub key_path: String,
pub cert_path: String,
}
/// Response of the `/computes/{compute_id}/spec` control-plane API.
#[derive(Deserialize, Debug)]
pub struct ControlPlaneSpecResponse {

View File

@@ -5,12 +5,15 @@
//! and connect it to the storage nodes.
use std::collections::HashMap;
use indexmap::IndexMap;
use regex::Regex;
use remote_storage::RemotePath;
use serde::{Deserialize, Serialize};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use crate::responses::TlsConfig;
/// String type alias representing Postgres identifier and
/// intended to be used for DB / role names.
pub type PgIdent = String;
@@ -125,7 +128,7 @@ pub struct ComputeSpec {
// information about available remote extensions
pub remote_extensions: Option<RemoteExtSpec>,
pub pgbouncer_settings: Option<HashMap<String, String>>,
pub pgbouncer_settings: Option<IndexMap<String, String>>,
// Stripe size for pageserver sharding, in pages
#[serde(default)]
@@ -357,6 +360,9 @@ pub struct LocalProxySpec {
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks: Option<Vec<JwksSettings>>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub tls: Option<TlsConfig>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail, ensure};
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use clap::Parser;
use compute_api::spec::LocalProxySpec;
@@ -27,6 +28,7 @@ use crate::config::{
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
@@ -190,7 +192,11 @@ pub async fn run() -> anyhow::Result<()> {
// 2. The config file is written but the signal hook is not yet received
// 3. local_proxy completes startup but has no config loaded, despite there being a registerd config.
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(args.config_path, refresh_config_notify));
tokio::spawn(refresh_config_loop(
config,
args.config_path,
refresh_config_notify,
));
maintenance_tasks.spawn(crate::http::health_server::task_main(
metrics_listener,
@@ -269,7 +275,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
};
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
tls_config: ArcSwapOption::from(None),
metric_collection: None,
http_config,
authentication_config: AuthenticationConfig {
@@ -311,14 +317,16 @@ enum RefreshConfigError {
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(&path).await {
match refresh_config_inner(config, &path).await {
Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
@@ -327,6 +335,9 @@ async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
@@ -336,7 +347,10 @@ async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
}
}
async fn refresh_config_inner(path: &Utf8Path) -> Result<(), RefreshConfigError> {
async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
@@ -406,5 +420,20 @@ async fn refresh_config_inner(path: &Utf8Path) -> Result<(), RefreshConfigError>
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
&tls_config.key_path,
&tls_config.cert_path,
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
Ok(())
}

View File

@@ -4,6 +4,7 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use arc_swap::ArcSwapOption;
use futures::future::Either;
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
@@ -563,6 +564,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
(None, None) => None,
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};
let tls_config = ArcSwapOption::from(tls_config.map(Arc::new));
let backup_metric_collection_config = config::MetricBackupCollectionConfig {
remote_storage_config: args.metric_backup_collection_remote_storage.clone(),

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Ok, bail, ensure};
use arc_swap::ArcSwapOption;
use clap::ValueEnum;
use remote_storage::RemoteStorageConfig;
@@ -17,7 +18,7 @@ pub use crate::tls::server_config::{TlsConfig, configure_tls};
use crate::types::Host;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub tls_config: ArcSwapOption<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,

View File

@@ -177,7 +177,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);

View File

@@ -114,7 +114,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
.accept_with(raw, |session| {
// push the early data to the tls session
while !read_buf.get_ref().is_empty() {

View File

@@ -278,7 +278,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);

View File

@@ -96,16 +96,18 @@ fn generate_tls_config<'a>(
.with_safe_default_protocol_versions()
.context("ring should support the default protocol versions")?
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone_key())?
.into();
.with_single_cert(vec![cert.clone()], key.clone_key())?;
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = cert_resolver.get_common_names();
let config = Arc::new(config);
TlsConfig {
config,
http_config: config.clone(),
pg_config: config,
common_names,
cert_resolver: Arc::new(cert_resolver),
}

View File

@@ -19,6 +19,7 @@ use std::pin::{Pin, pin};
use std::sync::Arc;
use anyhow::Context;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
@@ -117,18 +118,7 @@ pub async fn task_main(
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
Some(config) => {
let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(tls_server_config)
}
None => {
warn!("TLS config is missing");
Arc::new(NoTls)
}
};
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = Arc::new(&config.tls_config);
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
@@ -216,22 +206,20 @@ pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
async fn accept(&self, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
}
#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
}
}
struct NoTls;
#[async_trait]
impl MaybeTlsAcceptor for NoTls {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(conn))
impl MaybeTlsAcceptor for &'static ArcSwapOption<crate::config::TlsConfig> {
async fn accept(&self, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
match &*self.load() {
Some(config) => Ok(Box::pin(
TlsAcceptor::from(config.http_config.clone())
.accept(conn)
.await?,
)),
None => Ok(Box::pin(conn)),
}
}
}

View File

@@ -614,7 +614,9 @@ async fn handle_inner(
&config.authentication_config,
ctx,
request.headers(),
config.tls_config.as_ref(),
// todo: race condition?
// we're unlikely to change the common names.
config.tls_config.load().as_deref(),
)?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),

View File

@@ -9,17 +9,14 @@ use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
// unfortunate split since we cannot change the ALPN on demand.
// <https://github.com/rustls/rustls/issues/2260>
pub http_config: Arc<rustls::ServerConfig>,
pub pg_config: Arc<rustls::ServerConfig>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
@@ -71,8 +68,15 @@ pub fn configure_tls(
config.key_log = Arc::new(rustls::KeyLogFile::new());
}
let mut http_config = config.clone();
let mut pg_config = config;
http_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
pg_config.alpn_protocols = vec![b"postgresql".to_vec()];
Ok(TlsConfig {
config: Arc::new(config),
http_config: Arc::new(http_config),
pg_config: Arc::new(pg_config),
common_names,
cert_resolver,
})

View File

@@ -26,11 +26,14 @@ camino = { version = "1", default-features = false, features = ["serde1"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
clap = { version = "4", features = ["derive", "env", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
const-oid = { version = "0.9", default-features = false, features = ["db", "std"] }
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
der = { version = "0.7", default-features = false, features = ["oid", "pem", "std"] }
der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] }
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
digest = { version = "0.10", features = ["mac", "oid", "std"] }
ecdsa = { version = "0.16", features = ["pem", "signing", "std", "verifying"] }
either = { version = "1" }
elliptic-curve = { version = "0.13", default-features = false, features = ["digest", "hazmat", "jwk", "pem", "std"] }
env_filter = { version = "0.1", default-features = false, features = ["regex"] }
env_logger = { version = "0.11" }
fail = { version = "0.5", default-features = false, features = ["failpoints"] }
@@ -65,6 +68,7 @@ num-iter = { version = "0.1", default-features = false, features = ["i128", "std
num-rational = { version = "0.4", default-features = false, features = ["num-bigint-std", "std"] }
num-traits = { version = "0.2", features = ["i128", "libm"] }
once_cell = { version = "1" }
p256 = { version = "0.13", features = ["jwk"] }
parquet = { version = "53", default-features = false, features = ["zstd"] }
prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] }
rand = { version = "0.8", features = ["small_rng"] }
@@ -74,6 +78,7 @@ regex-syntax = { version = "0.8" }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "rustls-tls-native-roots", "stream"] }
rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std", "tls12"] }
scopeguard = { version = "1" }
sec1 = { version = "0.7", features = ["pem", "serde", "std", "subtle"] }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["alloc", "raw_value"] }
sha2 = { version = "0.10", features = ["asm", "oid"] }