Compare commits

...

12 Commits

Author SHA1 Message Date
Conrad Ludgate
cc66f78d01 update readme 2025-07-31 11:51:44 +01:00
Conrad Ludgate
f9e6802974 s/ssl/tls 2025-07-30 14:03:22 +01:00
Conrad Ludgate
74afc9d96f refactor pgbouncer tuning 2025-07-30 12:34:36 +01:00
Conrad Ludgate
86fe3150f0 add basic tls test 2025-07-30 12:34:31 +01:00
Conrad Ludgate
52be0146d3 fix runtime 2025-07-30 12:32:23 +01:00
Conrad Ludgate
a3f2a2cae5 add fast path for TLS renewal configuration 2025-07-30 12:29:41 +01:00
Conrad Ludgate
a24a0032ad update certificate files in the watch task 2025-07-30 11:47:34 +01:00
Conrad Ludgate
70cb02742a pass in the tls_config as a param to watch_certs_for_changes, also wait for it to complete before configuring pgbouncer/local_proxy 2025-07-30 11:47:07 +01:00
Conrad Ludgate
a845295cb3 refactor TLS processing. Only use blocking-IO, split out the loading of certificates from the updating of certificates 2025-07-30 10:29:03 +01:00
Conrad Ludgate
e288cd2198 fix concurrent reconfigure while TLS configuration is taking place 2025-07-30 10:14:20 +01:00
Conrad Ludgate
ffa9e595b8 introduce separate reload commands 2025-07-30 10:14:17 +01:00
Conrad Ludgate
e7b1f63f68 add logs for TLS 2025-07-30 10:08:04 +01:00
18 changed files with 407 additions and 192 deletions

View File

@@ -57,6 +57,9 @@ stateDiagram-v2
RefreshConfigurationPending --> RefreshConfiguration: Received compute spec and started configuration
RefreshConfiguration --> Running : Compute has been re-configured
RefreshConfiguration --> RefreshConfigurationPending : Configuration failed and to be retried
Running --> Reloading : Local changes (TLS certificate renewal) were detected and postgres is being reloaded
Reloading --> Running : Postgres was reloaded
Reloading --> Failed : Failed to reload postgres
TerminationPendingFast --> Terminated compute with 30s delay for cplane to inspect status
TerminationPendingImmediate --> Terminated : Terminated compute immediately
Failed --> RefreshConfigurationPending : Received a /refresh_configuration request

View File

@@ -28,7 +28,7 @@ use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::{spawn, sync::watch, task::JoinHandle, time};
@@ -57,7 +57,6 @@ use crate::rsyslog::{
use crate::spec::*;
use crate::swap::resize_swap;
use crate::sync_sk::{check_if_synced, ping_safekeeper};
use crate::tls::watch_cert_for_changes;
use crate::{config, extension_server, local_proxy};
pub static SYNC_SAFEKEEPERS_PID: AtomicU32 = AtomicU32::new(0);
@@ -842,14 +841,11 @@ impl ComputeNode {
let mut pre_tasks = tokio::task::JoinSet::new();
// Make sure TLS certificates are properly loaded and in the right place.
if self.compute_ctl_config.tls.is_some() {
let tls_task = self.compute_ctl_config.tls.as_ref().map(|tls_config| {
let this = self.clone();
pre_tasks.spawn(async move {
this.watch_cert_for_changes().await;
Ok::<(), anyhow::Error>(())
});
}
let tls_config = tls_config.clone();
tokio::task::spawn_blocking(|| this.watch_cert_for_changes(tls_config))
});
let tls_config = self.tls_config(&pspec.spec);
@@ -904,6 +900,13 @@ impl ComputeNode {
});
}
// Wait for TLS certificates to be issued before updating pgbouncer and local proxy.
let rt = tokio::runtime::Handle::current();
if let Some(tls_task) = tls_task {
rt.block_on(tls_task)
.context("TLS certificate renewal task panicked")?;
}
// tune pgbouncer
if let Some(pgbouncer_settings) = &pspec.spec.pgbouncer_settings {
info!("tuning pgbouncer");
@@ -986,7 +989,6 @@ impl ComputeNode {
let _configurator_handle = launch_configurator(self);
// Wait for all the pre-tasks to finish before starting postgres
let rt = tokio::runtime::Handle::current();
while let Some(res) = rt.block_on(pre_tasks.join_next()) {
res??;
}
@@ -1949,10 +1951,7 @@ impl ComputeNode {
.clone(),
);
let mut tls_config = None::<TlsConfig>;
if spec.features.contains(&ComputeFeature::TlsExperimental) {
tls_config = self.compute_ctl_config.tls.clone();
}
let tls_config = self.tls_config(&spec);
self.update_installed_extensions_collection_interval(&spec);
@@ -2134,6 +2133,60 @@ impl ComputeNode {
Ok(())
}
/// Tell postgres/pgbouncer/local_proxy to reload their configurations.
#[instrument(skip_all)]
pub fn reload(&self, spec: ComputeSpec) -> Result<()> {
let rt = tokio::runtime::Handle::current();
if spec.pgbouncer_settings.is_some() {
rt.block_on(reload_pgbouncer())?;
}
if spec.local_proxy_config.is_some() {
local_proxy::reload()?;
}
self.pg_reload_conf()?;
let unknown_op = "unknown".to_string();
let op_id = spec.operation_uuid.as_ref().unwrap_or(&unknown_op);
info!("finished reload of compute node for operation {op_id}");
Ok(())
}
/// Acquire the "reloading" lock while running the supplied function.
///
/// This ensures that this thread is the only thread that
/// can issue signals to postgres.
///
/// If the supplied function errors, the compute status is marked as failed.
pub fn lock_while_reloading<T>(
&self,
mut state: MutexGuard<'_, ComputeState>,
f: impl FnOnce(ComputeSpec) -> Result<T>,
) -> Result<T> {
let old_status = state.status;
// transition to the reloading state.
state.set_status(ComputeStatus::Reloading, &self.state_changed);
let spec = state.pspec.as_ref().unwrap().spec.clone();
// unlock while reloading, so we don't block other tasks.
drop(state);
let res = f(spec);
let new_status = if res.is_ok() {
old_status
} else {
ComputeStatus::Failed
};
let mut state = self.state.lock().unwrap();
// make sure our invariants are upheld
assert_eq!(state.status, ComputeStatus::Reloading);
state.set_status(new_status, &self.state_changed);
res
}
#[instrument(skip_all)]
pub fn configure_as_primary(&self, compute_state: &ComputeState) -> Result<()> {
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
@@ -2168,57 +2221,103 @@ impl ComputeNode {
Ok(())
}
pub async fn watch_cert_for_changes(self: Arc<Self>) {
// update status on cert renewal
if let Some(tls_config) = &self.compute_ctl_config.tls {
let tls_config = tls_config.clone();
pub fn watch_cert_for_changes(self: Arc<Self>, tls_config: TlsConfig) {
// wait until the cert exists.
let mut digest = crate::tls::compute_digest(&tls_config.cert_path);
info!(
cert_path = tls_config.cert_path,
key_path = tls_config.key_path,
"TLS certificates found"
);
// wait until the cert exists.
let mut cert_watch = watch_cert_for_changes(tls_config.cert_path.clone()).await;
// ensure the keys are saved before continuing.
let key_pair = crate::tls::load_certs_blocking(&tls_config);
while let Err(e) =
crate::tls::update_key_path_blocking(Path::new(&self.params.pgdata), &key_pair)
{
error!("could not save TLS certificates: {e}");
std::thread::sleep(Duration::from_millis(20));
}
tokio::task::spawn_blocking(move || {
let handle = tokio::runtime::Handle::current();
'cert_update: loop {
// let postgres/pgbouncer/local_proxy know the new cert/key exists.
// we need to wait until it's configurable first.
tokio::task::spawn_blocking(move || {
'cert_update: loop {
// wait for a new certificate update
let new_digest = crate::tls::wait_until_cert_changed(digest, &tls_config.cert_path);
let mut state = self.state.lock().unwrap();
'status_update: loop {
match state.status {
// let's update the state to config pending
ComputeStatus::ConfigurationPending | ComputeStatus::Running => {
state.set_status(
ComputeStatus::ConfigurationPending,
&self.state_changed,
);
break 'status_update;
}
// load the corresponding keys
let key_pair = crate::tls::load_certs_blocking(&tls_config);
// exit loop
ComputeStatus::Failed
| ComputeStatus::TerminationPendingFast
| ComputeStatus::TerminationPendingImmediate
| ComputeStatus::Terminated => break 'cert_update,
// let postgres/pgbouncer/local_proxy know the new cert/key exists.
// we need to wait until it's configurable first.
// wait
ComputeStatus::Init
| ComputeStatus::Configuration
| ComputeStatus::RefreshConfiguration
| ComputeStatus::RefreshConfigurationPending
| ComputeStatus::Empty => {
state = self.state_changed.wait(state).unwrap();
}
let mut state = self.state.lock().unwrap();
'status_update: loop {
match state.status {
// let's update the state to config pending
ComputeStatus::Running => {
info!("reloading compute due to TLS certificate renewal");
break 'status_update;
}
// exit loop
ComputeStatus::Failed
| ComputeStatus::TerminationPendingFast
| ComputeStatus::TerminationPendingImmediate
| ComputeStatus::Terminated => break 'cert_update,
// wait
ComputeStatus::Init
| ComputeStatus::Configuration
| ComputeStatus::ConfigurationPending
| ComputeStatus::RefreshConfiguration
| ComputeStatus::RefreshConfigurationPending
| ComputeStatus::Reloading
| ComputeStatus::Empty => {
state = self.state_changed.wait(state).unwrap();
}
}
drop(state);
}
// wait for a new certificate update
if handle.block_on(cert_watch.changed()).is_err() {
break;
let result = self.lock_while_reloading(state, |spec| {
// ensure the keys are saved before continuing.
// we do this while holding the 'reloading' state so that we know we're not interfering with any
// active configuration stages.
if let Err(e) = crate::tls::update_key_path_blocking(
Path::new(&self.params.pgdata),
&key_pair,
) {
return Ok(Err(e));
}
// reload postgres/pgbouncer/local_proxy to pick up our new certificates.
self.reload(spec)?;
Ok(Ok(()))
});
match result {
// Reload failed. Compute is in a bad state.
Err(e) => {
error!("could not reload compute node: {}", e);
return;
}
// Updating the certificates failed. Retry
Ok(Err(e)) => {
error!("could not save TLS certificates: {e}");
std::thread::sleep(Duration::from_millis(20));
}
// Successful. Acknowledge that we've saved these certificates.
Ok(Ok(())) => {
digest = new_digest;
info!(
cert_path = tls_config.cert_path,
key_path = tls_config.key_path,
"TLS certificates renewed",
);
}
}
});
}
}
});
}
pub fn tls_config(&self, spec: &ComputeSpec) -> &Option<TlsConfig> {

View File

@@ -16,7 +16,7 @@ use crate::pg_helpers::{
DatabricksSettingsExt as _, GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize,
escape_conf_value,
};
use crate::tls::{self, SERVER_CRT, SERVER_KEY};
use crate::tls::{SERVER_CRT, SERVER_KEY};
use utils::shard::{ShardIndex, ShardNumber};
@@ -178,14 +178,9 @@ pub fn write_postgres_conf(
}
// tls
if let Some(tls_config) = tls_config {
if tls_config.is_some() {
writeln!(file, "ssl = on")?;
// postgres requires the keyfile to be in a secure file,
// currently too complicated to ensure that at the VM level,
// so we just copy them to another file instead. :shrug:
tls::update_key_path_blocking(pgdata_path, tls_config);
// these are the default, but good to be explicit.
writeln!(file, "ssl_cert_file = '{SERVER_CRT}'")?;
writeln!(file, "ssl_key_file = '{SERVER_KEY}'")?;

View File

@@ -12,8 +12,10 @@ use crate::http::JsonResponse;
/// Check that the compute is currently running.
pub(in crate::http) async fn is_writable(State(compute): State<Arc<ComputeNode>>) -> Response {
let status = compute.get_status();
if status != ComputeStatus::Running {
return JsonResponse::invalid_status(status);
match status {
// If we are running, or just reloading the config, we are ok to write a new config.
ComputeStatus::Running | ComputeStatus::Reloading => {}
_ => return JsonResponse::invalid_status(status),
}
match check_writability(&compute).await {

View File

@@ -27,32 +27,6 @@ pub(in crate::http) async fn configure(
Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e),
};
// XXX: wrap state update under lock in a code block. Otherwise, we will try
// to `Send` `mut state` into the spawned thread bellow, which will cause
// the following rustc error:
//
// error: future cannot be sent between threads safely
{
let mut state = compute.state.lock().unwrap();
if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) {
return JsonResponse::invalid_status(state.status);
}
// Pass the tracing span to the main thread that performs the startup,
// so that the start_compute operation is considered a child of this
// configure request for tracing purposes.
state.startup_span = Some(tracing::Span::current());
if compute.params.lakebase_mode {
ComputeNode::set_spec(&compute.params, &mut state, pspec);
} else {
state.pspec = Some(pspec);
}
state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed);
drop(state);
}
// Spawn a blocking thread to wait for compute to become Running. This is
// needed to not block the main pool of workers and to be able to serve
// other requests while some particular request is waiting for compute to
@@ -60,6 +34,32 @@ pub(in crate::http) async fn configure(
let c = compute.clone();
let completed = task::spawn_blocking(move || {
let mut state = c.state.lock().unwrap();
loop {
match state.status {
// ideal state.
ComputeStatus::Empty | ComputeStatus::Running => break,
// we need to wait until reloaded
ComputeStatus::Reloading => {
state = c.state_changed.wait(state).unwrap();
}
// All other cases are unexpected.
_ => return Err(JsonResponse::invalid_status(state.status)),
}
}
// Pass the tracing span to the main thread that performs the startup,
// so that the start_compute operation is considered a child of this
// configure request for tracing purposes.
state.startup_span = Some(tracing::Span::current());
if c.params.lakebase_mode {
ComputeNode::set_spec(&c.params, &mut state, pspec);
} else {
state.pspec = Some(pspec);
}
state.set_status(ComputeStatus::ConfigurationPending, &c.state_changed);
while state.status != ComputeStatus::Running {
state = c.state_changed.wait(state).unwrap();
info!(
@@ -71,7 +71,7 @@ pub(in crate::http) async fn configure(
if state.status == ComputeStatus::Failed {
let err = state.error.as_ref().map_or("unknown error", |x| x);
let msg = format!("compute configuration failed: {err:?}");
return Err(msg);
return Err(JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, msg));
}
}
@@ -81,7 +81,7 @@ pub(in crate::http) async fn configure(
.unwrap();
if let Err(e) = completed {
return JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, e);
return e;
}
// Return current compute state if everything went well.

View File

@@ -11,9 +11,11 @@ use utils::pid_file::{self, PidFileRead};
pub fn configure(local_proxy: &LocalProxySpec) -> Result<()> {
write_local_proxy_conf("/etc/local_proxy/config.json".as_ref(), local_proxy)?;
notify_local_proxy("/etc/local_proxy/pid".as_ref())?;
reload()
}
Ok(())
pub fn reload() -> Result<()> {
notify_local_proxy("/etc/local_proxy/pid".as_ref())
}
/// Create or completely rewrite configuration file specified by `path`

View File

@@ -466,13 +466,7 @@ fn update_pgbouncer_ini(
Ok(())
}
/// Tune pgbouncer.
/// 1. Apply new config using pgbouncer admin console
/// 2. Add new values to pgbouncer.ini to preserve them after restart
pub async fn tune_pgbouncer(
mut pgbouncer_config: IndexMap<String, String>,
tls_config: Option<TlsConfig>,
) -> Result<()> {
async fn connect() -> Result<tokio_postgres::Client> {
let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() {
// for VMs use pgbouncer specific way to connect to
// pgbouncer admin console without password
@@ -518,18 +512,17 @@ pub async fn tune_pgbouncer(
}
};
if let Some(tls_config) = tls_config {
// pgbouncer starts in a half-ok state if it cannot find these files.
// It will default to client_tls_sslmode=deny, which causes proxy to error.
// There is a small window at startup where these files don't yet exist in the VM.
// Best to wait until it exists.
loop {
if let Ok(true) = tokio::fs::try_exists(&tls_config.key_path).await {
break;
}
tokio::time::sleep(Duration::from_millis(500)).await
}
Ok(client)
}
/// Tune pgbouncer.
/// 1. Apply new config to pgbouncer.ini
/// 2. Notify pgbouncer to reload
pub async fn tune_pgbouncer(
mut pgbouncer_config: IndexMap<String, String>,
tls_config: Option<TlsConfig>,
) -> Result<()> {
if let Some(tls_config) = tls_config {
pgbouncer_config.insert("client_tls_cert_file".to_string(), tls_config.cert_path);
pgbouncer_config.insert("client_tls_key_file".to_string(), tls_config.key_path);
pgbouncer_config.insert("client_tls_sslmode".to_string(), "allow".to_string());
@@ -550,10 +543,17 @@ pub async fn tune_pgbouncer(
info!("Applying pgbouncer setting change");
reload_pgbouncer().await
}
/// Reload pgbouncer.
pub async fn reload_pgbouncer() -> Result<()> {
let client = connect().await?;
if let Err(err) = client.simple_query("RELOAD").await {
// Don't fail on error, just print it into log
error!("Failed to apply pgbouncer setting change, {err}",);
};
error!("Failed to apply pgbouncer setting change: {err}",);
}
Ok(())
}

View File

@@ -3,42 +3,43 @@ use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
use anyhow::{Context, Result, bail};
use compute_api::responses::TlsConfig;
use ring::digest;
use x509_cert::Certificate;
#[derive(Clone, Copy)]
pub struct CertDigest(digest::Digest);
pub async fn watch_cert_for_changes(cert_path: String) -> tokio::sync::watch::Receiver<CertDigest> {
let mut digest = compute_digest(&cert_path).await;
let (tx, rx) = tokio::sync::watch::channel(digest);
tokio::spawn(async move {
while !tx.is_closed() {
let new_digest = compute_digest(&cert_path).await;
if digest.0.as_ref() != new_digest.0.as_ref() {
digest = new_digest;
_ = tx.send(digest);
}
tokio::time::sleep(Duration::from_secs(60)).await
}
});
rx
impl PartialEq for CertDigest {
fn eq(&self, other: &Self) -> bool {
self.0.as_ref() == other.0.as_ref()
}
}
async fn compute_digest(cert_path: &str) -> CertDigest {
pub fn wait_until_cert_changed(digest: CertDigest, cert_path: &str) -> CertDigest {
loop {
match try_compute_digest(cert_path).await {
let new_digest = compute_digest(cert_path);
if digest != new_digest {
break new_digest;
}
// Wait a while before checking the certificates.
// We renew on a daily basis, so there's no rush.
std::thread::sleep(Duration::from_secs(60));
}
}
pub fn compute_digest(cert_path: &str) -> CertDigest {
loop {
match try_compute_digest(cert_path) {
Ok(d) => break d,
Err(e) => {
tracing::error!("could not read cert file {e:?}");
tokio::time::sleep(Duration::from_secs(1)).await
std::thread::sleep(Duration::from_secs(1))
}
}
}
}
async fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
let data = tokio::fs::read(cert_path).await?;
fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
let data = std::fs::read(cert_path)?;
// sha256 is extremely collision resistent. can safely assume the digest to be unique
Ok(CertDigest(digest::digest(&digest::SHA256, &data)))
}
@@ -46,28 +47,37 @@ async fn try_compute_digest(cert_path: &str) -> Result<CertDigest> {
pub const SERVER_CRT: &str = "server.crt";
pub const SERVER_KEY: &str = "server.key";
pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) {
pub struct KeyPair {
crt: String,
key: String,
}
pub fn load_certs_blocking(tls_config: &TlsConfig) -> KeyPair {
loop {
match try_update_key_path_blocking(pg_data, tls_config) {
Ok(()) => break,
match try_load_certs_blocking(tls_config) {
Ok(key_pair) => break key_pair,
Err(e) => {
tracing::error!(error = ?e, "could not create key file");
tracing::error!(error = ?e, "could not load certs");
std::thread::sleep(Duration::from_secs(1))
}
}
}
}
// Postgres requires the keypath be "secure". This means
// 1. Owned by the postgres user.
// 2. Have permission 600.
fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Result<()> {
fn try_load_certs_blocking(tls_config: &TlsConfig) -> Result<KeyPair> {
let key = std::fs::read_to_string(&tls_config.key_path)?;
let crt = std::fs::read_to_string(&tls_config.cert_path)?;
// to mitigate a race condition during renewal.
verify_key_cert(&key, &crt)?;
Ok(KeyPair { key, crt })
}
// Postgres requires the keypath be "secure". This means
// 1. Owned by the postgres user.
// 2. Have permission 600.
pub fn update_key_path_blocking(pg_data: &Path, key_pair: &KeyPair) -> Result<()> {
let mut key_file = std::fs::OpenOptions::new()
.write(true)
.create(true)
@@ -82,14 +92,22 @@ fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Resul
.mode(0o600)
.open(pg_data.join(SERVER_CRT))?;
key_file.write_all(key.as_bytes())?;
crt_file.write_all(crt.as_bytes())?;
// NOTE: We currently ensure that an explicit reload does not happen during TLS renewal, but
// there's a chance that postgres/pgbouncer/local_proxy reloads implicitly halfway between
// these writes. This could allow them to reads the wrong keys to the wrong certs.
// There doesn't seem to be any way to prevent that. However, we will issue a reload shortly
// after which should at least correct it.
key_file.write_all(key_pair.key.as_bytes())?;
crt_file.write_all(key_pair.crt.as_bytes())?;
Ok(())
}
fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
use x509_cert::Certificate;
use x509_cert::der::oid::db::rfc5912::ECDSA_WITH_SHA_256;
use x509_cert::der::oid::db::rfc8410::ID_ED_25519;
use x509_cert::der::pem;
let certs = Certificate::load_pem_chain(cert.as_bytes())
.context("decoding PEM encoded certificates")?;
@@ -100,22 +118,30 @@ fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
bail!("no certificates found");
};
let pubkey = cert
.tbs_certificate
.subject_public_key_info
.subject_public_key
.raw_bytes();
match cert.signature_algorithm.oid {
ECDSA_WITH_SHA_256 => {
let key = p256::SecretKey::from_sec1_pem(key).context("parse key")?;
let a = key.public_key().to_sec1_bytes();
let b = cert
.tbs_certificate
.subject_public_key_info
.subject_public_key
.raw_bytes();
if *a != *b {
if *key.public_key().to_sec1_bytes() != *pubkey {
bail!("private key file does not match certificate")
}
}
_ => bail!("unknown TLS key type"),
ID_ED_25519 => {
use ring::signature::{Ed25519KeyPair, KeyPair};
let (_, bytes) = pem::decode_vec(key.as_bytes())
.map_err(|_| anyhow::anyhow!("invalid key encoding"))?;
let key = Ed25519KeyPair::from_pkcs8_maybe_unchecked(&bytes).context("parse key")?;
if *key.public_key().as_ref() != *pubkey {
bail!("private key file does not match certificate")
}
}
oid => bail!("unknown TLS key type: {oid}"),
}
Ok(())

View File

@@ -1089,7 +1089,8 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
default_tenant_id: TenantId::from_array(std::array::from_fn(|_| 0)),
storage_controller: None,
control_plane_hooks_api: None,
generate_local_ssl_certs: false,
generate_local_tls_certs: false,
generate_compute_tls_certs: false,
}
};

View File

@@ -23,7 +23,7 @@ impl StorageBroker {
}
pub fn initialize(&self) -> anyhow::Result<()> {
if self.env.generate_local_ssl_certs {
if self.env.generate_local_tls_certs {
self.env.generate_ssl_cert(
&self.env.storage_broker_data_dir().join("server.crt"),
&self.env.storage_broker_data_dir().join("server.key"),

View File

@@ -54,7 +54,6 @@ use compute_api::requests::{
};
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TerminateResponse,
TlsConfig,
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol,
@@ -213,8 +212,13 @@ impl ComputeControlPlane {
let internal_http_port = internal_http_port.unwrap_or_else(|| external_http_port + 1);
let compute_ctl_config = ComputeCtlConfig {
jwks: Self::create_jwks_from_pem(&self.env.read_public_key()?)?,
tls: None::<TlsConfig>,
tls: self.env.get_tls_config()?,
};
let mut features = vec![];
if compute_ctl_config.tls.is_some() {
features.push(ComputeFeature::TlsExperimental);
}
let ep = Arc::new(Endpoint {
endpoint_id: endpoint_id.to_owned(),
pg_address: SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), pg_port),
@@ -241,7 +245,7 @@ impl ComputeControlPlane {
drop_subscriptions_before_start,
grpc,
reconfigure_concurrency: 1,
features: vec![],
features: features.clone(),
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
@@ -263,7 +267,7 @@ impl ComputeControlPlane {
skip_pg_catalog_updates,
drop_subscriptions_before_start,
reconfigure_concurrency: 1,
features: vec![],
features,
cluster: None,
compute_ctl_config,
privileged_role_name,
@@ -953,7 +957,7 @@ impl Endpoint {
}
// keep retrying
}
ComputeStatus::Running => {
ComputeStatus::Reloading | ComputeStatus::Running => {
// All good!
break;
}

View File

@@ -12,6 +12,7 @@ use std::{env, fs};
use anyhow::{Context, bail};
use clap::ValueEnum;
use compute_api::responses::TlsConfig;
use pageserver_api::config::PostHogConfig;
use pem::Pem;
use postgres_backend::AuthType;
@@ -95,7 +96,10 @@ pub struct LocalEnv {
/// Flag to generate SSL certificates for components that need it.
/// Also generates root CA certificate that is used to sign all other certificates.
pub generate_local_ssl_certs: bool,
pub generate_local_tls_certs: bool,
/// Flag to generate SSL certificates for compute.
pub generate_compute_tls_certs: bool,
}
/// On-disk state stored in `.neon/config`.
@@ -123,7 +127,11 @@ pub struct OnDiskConfig {
// Note: skip serializing because in compat tests old storage controller fails
// to load new config file. May be removed after this field is in release branch.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub generate_local_ssl_certs: bool,
pub generate_local_tls_certs: bool,
// Note: skip serializing because in compat tests old storage controller fails
// to load new config file. May be removed after this field is in release branch.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub generate_compute_tls_certs: bool,
}
fn fail_if_pageservers_field_specified<'de, D>(_: D) -> Result<Vec<PageServerConf>, D::Error>
@@ -152,7 +160,8 @@ pub struct NeonLocalInitConf {
pub endpoint_storage: EndpointStorageConf,
pub control_plane_api: Option<Url>,
pub control_plane_hooks_api: Option<Url>,
pub generate_local_ssl_certs: bool,
pub generate_local_tls_certs: bool,
pub generate_compute_tls_certs: bool,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -511,7 +520,7 @@ impl LocalEnv {
}
pub fn ssl_ca_cert_path(&self) -> Option<PathBuf> {
if self.generate_local_ssl_certs {
if self.generate_local_tls_certs {
Some(self.base_data_dir.join("rootCA.crt"))
} else {
None
@@ -519,7 +528,7 @@ impl LocalEnv {
}
pub fn ssl_ca_key_path(&self) -> Option<PathBuf> {
if self.generate_local_ssl_certs {
if self.generate_local_tls_certs {
Some(self.base_data_dir.join("rootCA.key"))
} else {
None
@@ -545,6 +554,33 @@ impl LocalEnv {
)
}
fn compute_ssl_paths(&self) -> Option<(PathBuf, PathBuf)> {
if self.generate_compute_tls_certs {
Some((
self.base_data_dir.join("compute_server.crt"),
self.base_data_dir.join("compute_server.key"),
))
} else {
None
}
}
pub fn generate_compute_ssl_cert(&self) -> anyhow::Result<()> {
self.generate_ssl_ca_cert()?;
let (cert_path, key_path) = self.compute_ssl_paths().unwrap();
if !fs::exists(&cert_path)? {
generate_ssl_cert(
&cert_path,
&key_path,
self.ssl_ca_cert_path().unwrap().as_path(),
self.ssl_ca_key_path().unwrap().as_path(),
)?;
}
Ok(())
}
/// Creates HTTP client with local SSL CA certificates.
pub fn create_http_client(&self) -> reqwest::Client {
let ssl_ca_certs = self.ssl_ca_cert_path().map(|ssl_ca_file| {
@@ -673,7 +709,8 @@ impl LocalEnv {
control_plane_hooks_api,
control_plane_compute_hook_api: _,
branch_name_mappings,
generate_local_ssl_certs,
generate_local_tls_certs,
generate_compute_tls_certs,
endpoint_storage,
} = on_disk_config;
LocalEnv {
@@ -690,7 +727,8 @@ impl LocalEnv {
control_plane_api: control_plane_api.unwrap(),
control_plane_hooks_api,
branch_name_mappings,
generate_local_ssl_certs,
generate_local_tls_certs,
generate_compute_tls_certs,
endpoint_storage,
}
};
@@ -806,7 +844,8 @@ impl LocalEnv {
control_plane_hooks_api: self.control_plane_hooks_api.clone(),
control_plane_compute_hook_api: None,
branch_name_mappings: self.branch_name_mappings.clone(),
generate_local_ssl_certs: self.generate_local_ssl_certs,
generate_local_tls_certs: self.generate_local_tls_certs,
generate_compute_tls_certs: self.generate_compute_tls_certs,
endpoint_storage: self.endpoint_storage.clone(),
},
)
@@ -861,6 +900,21 @@ impl LocalEnv {
Ok(pem)
}
/// Get the TLS config if set.
pub fn get_tls_config(&self) -> anyhow::Result<Option<TlsConfig>> {
match self.compute_ssl_paths() {
Some((cert_path, key_path)) => {
self.generate_compute_ssl_cert()?;
Ok(Some(TlsConfig {
key_path: key_path.to_str().context("utf8")?.to_string(),
cert_path: cert_path.to_str().context("utf8")?.to_string(),
}))
}
None => Ok(None),
}
}
/// Materialize the [`NeonLocalInitConf`] to disk. Called during [`neon_local init`].
pub fn init(conf: NeonLocalInitConf, force: &InitForceMode) -> anyhow::Result<()> {
let base_path = base_path();
@@ -912,7 +966,8 @@ impl LocalEnv {
pageservers,
safekeepers,
control_plane_api,
generate_local_ssl_certs,
generate_local_tls_certs,
generate_compute_tls_certs,
control_plane_hooks_api,
endpoint_storage,
} = conf;
@@ -965,13 +1020,17 @@ impl LocalEnv {
control_plane_api: control_plane_api.unwrap(),
control_plane_hooks_api,
branch_name_mappings: Default::default(),
generate_local_ssl_certs,
generate_local_tls_certs,
generate_compute_tls_certs,
endpoint_storage,
};
if generate_local_ssl_certs {
if generate_local_tls_certs {
env.generate_ssl_ca_cert()?;
}
if generate_compute_tls_certs {
env.generate_compute_ssl_cert()?;
}
// create endpoints dir
fs::create_dir_all(env.endpoints_path())?;

View File

@@ -241,7 +241,7 @@ impl PageServerNode {
.context("write identity toml")?;
drop(identity_toml);
if self.env.generate_local_ssl_certs {
if self.env.generate_local_tls_certs {
self.env.generate_ssl_cert(
datadir.join("server.crt").as_path(),
datadir.join("server.key").as_path(),

View File

@@ -102,7 +102,7 @@ impl SafekeeperNode {
/// Initializes a safekeeper node by creating all necessary files,
/// e.g. SSL certificates and JWT token file.
pub fn initialize(&self) -> anyhow::Result<()> {
if self.env.generate_local_ssl_certs {
if self.env.generate_local_tls_certs {
self.env.generate_ssl_cert(
&self.datadir_path().join("server.crt"),
&self.datadir_path().join("server.key"),

View File

@@ -353,7 +353,7 @@ impl StorageController {
}
}
if self.env.generate_local_ssl_certs {
if self.env.generate_local_tls_certs {
self.env.generate_ssl_cert(
&instance_dir.join("server.crt"),
&instance_dir.join("server.key"),

View File

@@ -27,7 +27,6 @@ pub struct ComputeConfig {
pub spec: Option<ComputeSpec>,
/// The compute_ctl configuration
#[allow(dead_code)]
pub compute_ctl_config: ComputeCtlConfig,
}
@@ -155,6 +154,8 @@ pub enum ComputeStatus {
Empty,
// Compute configuration was requested.
ConfigurationPending,
// Postgres, pgbouncer, and local_proxy is currently being reloaded.
Reloading,
// Compute node has spec and initial startup and
// configuration is in progress.
Init,
@@ -189,6 +190,7 @@ impl Display for ComputeStatus {
match self {
ComputeStatus::Empty => f.write_str("empty"),
ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"),
ComputeStatus::Reloading => f.write_str("reloading"),
ComputeStatus::RefreshConfiguration => f.write_str("refresh-configuration"),
ComputeStatus::RefreshConfigurationPending => {
f.write_str("refresh-configuration-pending")

View File

@@ -506,6 +506,8 @@ class NeonEnvBuilder:
# Flag to use https listener in storage broker, generate local ssl certs,
# and force pageservers and safekeepers to use https for storage broker api.
self.use_https_storage_broker_api: bool = False
# Flag to enable TLS for computes
self.use_compute_tls: bool = False
self.pageserver_virtual_file_io_engine: str | None = pageserver_virtual_file_io_engine
self.pageserver_get_vectored_concurrent_io: str | None = (
@@ -1112,14 +1114,16 @@ class NeonEnv:
self.initial_tenant = config.initial_tenant
self.initial_timeline = config.initial_timeline
self.generate_local_ssl_certs = (
self.generate_compute_tls_certs = config.use_compute_tls
self.generate_local_tls_certs = (
config.use_https_pageserver_api
or config.use_https_safekeeper_api
or config.use_https_storage_controller_api
or config.use_https_storage_broker_api
or config.use_compute_tls
)
self.ssl_ca_file = (
self.repo_dir.joinpath("rootCA.crt") if self.generate_local_ssl_certs else None
self.tls_ca_file = (
self.repo_dir.joinpath("rootCA.crt") if self.generate_local_tls_certs else None
)
neon_local_env_vars = {}
@@ -1198,7 +1202,8 @@ class NeonEnv:
"endpoint_storage": {
"listen_addr": f"127.0.0.1:{self.port_distributor.get_port()}",
},
"generate_local_ssl_certs": self.generate_local_ssl_certs,
"generate_local_tls_certs": self.generate_local_tls_certs,
"generate_compute_tls_certs": self.generate_compute_tls_certs,
}
if config.use_https_storage_broker_api:
@@ -1942,7 +1947,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
self.auth_enabled = auth_enabled
self.allowed_errors: list[str] = DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS
self.logfile = self.env.repo_dir / "storage_controller_1" / "storage_controller.log"
self.ssl_ca_file = env.ssl_ca_file
self.tls_ca_file = env.tls_ca_file
def start(
self,
@@ -2015,8 +2020,8 @@ class NeonStorageController(MetricsGetter, LogUtils):
return PageserverHttpClient(self.port, lambda: True, auth_token, *args, **kwargs)
def request(self, method, *args, **kwargs) -> requests.Response:
if self.ssl_ca_file is not None:
kwargs["verify"] = self.ssl_ca_file
if self.tls_ca_file is not None:
kwargs["verify"] = self.tls_ca_file
resp = requests.request(method, *args, **kwargs)
NeonStorageController.raise_api_exception(resp)

View File

@@ -19,7 +19,7 @@ def test_pageserver_https_api(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
addr = f"https://localhost:{env.pageserver.service_port.https}/v1/status"
requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(addr, verify=str(env.tls_ca_file)).raise_for_status()
def test_safekeeper_https_api(neon_env_builder: NeonEnvBuilder):
@@ -37,7 +37,7 @@ def test_safekeeper_https_api(neon_env_builder: NeonEnvBuilder):
# 1. Make simple https request.
addr = f"https://localhost:{sk.port.https}/v1/status"
requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(addr, verify=str(env.tls_ca_file)).raise_for_status()
# Note: http_port is intentionally wrong.
# Storcon should not use it if use_https is on.
@@ -83,7 +83,7 @@ def test_storage_controller_https_api(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
addr = f"https://localhost:{env.storage_controller.port}/status"
requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(addr, verify=str(env.tls_ca_file)).raise_for_status()
def test_certificate_rotation(neon_env_builder: NeonEnvBuilder):
@@ -111,7 +111,7 @@ def test_certificate_rotation(neon_env_builder: NeonEnvBuilder):
# 1. Check if https works.
addr = f"https://localhost:{port}/v1/status"
requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(addr, verify=str(env.tls_ca_file)).raise_for_status()
ps_cert_path = env.pageserver.workdir / "server.crt"
ps_key_path = env.pageserver.workdir / "server.key"
@@ -136,7 +136,7 @@ def test_certificate_rotation(neon_env_builder: NeonEnvBuilder):
wait_until(error_reloading_cert)
# 4. Check that it uses old cert.
requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(addr, verify=str(env.tls_ca_file)).raise_for_status()
cur_cert = ssl.get_server_certificate(("localhost", port))
assert cur_cert == ps_cert
@@ -150,7 +150,7 @@ def test_certificate_rotation(neon_env_builder: NeonEnvBuilder):
wait_until(cert_reloaded)
# 6. Check that server returns new cert.
requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(addr, verify=str(env.tls_ca_file)).raise_for_status()
cur_cert = ssl.get_server_certificate(("localhost", port))
assert cur_cert == sk_cert
@@ -174,7 +174,7 @@ def test_server_and_cert_metrics(neon_env_builder: NeonEnvBuilder):
)
addr = f"https://localhost:{env.pageserver.service_port.https}/v1/status"
requests.get(addr, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(addr, verify=str(env.tls_ca_file)).raise_for_status()
new_https_conn_count = (
ps_client.get_metric_value("http_server_connection_started_total", filter_https) or 0
@@ -227,10 +227,27 @@ def test_storage_broker_https_api(neon_env_builder: NeonEnvBuilder):
# 1. Simple check that HTTPS is enabled and works.
url = env.broker.client_url() + "/status"
assert url.startswith("https://")
requests.get(url, verify=str(env.ssl_ca_file)).raise_for_status()
requests.get(url, verify=str(env.tls_ca_file)).raise_for_status()
# 2. Simple workload to check that SK -> broker -> PS communication works over HTTPS.
workload = Workload(env, env.initial_tenant, env.initial_timeline)
workload.init()
workload.write_rows(10)
workload.validate()
def test_compute_tls(
neon_env_builder: NeonEnvBuilder,
):
neon_env_builder.use_compute_tls = True
env = neon_env_builder.init_start()
env.create_branch("test_compute_tls")
with env.endpoints.create_start("test_compute_tls") as endpoint:
res = endpoint.safe_psql(
"select ssl from pg_stat_ssl where pid = pg_backend_pid();",
sslmode="verify-full",
sslrootcert=env.tls_ca_file,
)
assert res == [(True,)]