merge pg-sni-router into proxy (#11882)

## Problem

We realised that pg-sni-router doesn't need to be separate from proxy.
just a separate port.

## Summary of changes

Add pg-sni-router config to proxy and expose the service.
This commit is contained in:
Conrad Ludgate
2025-05-12 16:48:48 +01:00
committed by GitHub
parent a618056770
commit a77919f4b2
7 changed files with 283 additions and 127 deletions

View File

@@ -423,8 +423,8 @@ async fn refresh_config_inner(
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
&tls_config.key_path,
&tls_config.cert_path,
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)

View File

@@ -1,8 +1,10 @@
/// A stand-alone program that routes connections, e.g. from
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
///
/// This allows connecting to pods/services running in the same Kubernetes cluster from
/// the outside. Similar to an ingress controller for HTTPS.
//! A stand-alone program that routes connections, e.g. from
//! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
//!
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::path::Path;
use std::{net::SocketAddr, sync::Arc};
use anyhow::{Context, anyhow, bail, ensure};
@@ -86,46 +88,7 @@ pub async fn run() -> anyhow::Result<()> {
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
(Some(key_path), Some(cert_path)) => {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys =
rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
PrivateKeyDer::Pkcs8(
keys.pop()
.expect("keys should not be empty")
.context(format!("Failed to read TLS keys at '{key_path}'"))?,
)
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain: Vec<_> = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
}
(Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?,
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -188,7 +151,58 @@ pub async fn run() -> anyhow::Result<()> {
match signal {}
}
async fn task_main(
pub(super) fn parse_tls(
key_path: &Path,
cert_path: &Path,
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
PrivateKeyDer::Pkcs8(
keys.pop()
.expect("keys should not be empty")
.context(format!(
"Failed to read TLS keys at '{}'",
key_path.display()
))?,
)
};
let cert_chain_bytes = std::fs::read(cert_path).context(format!(
"Failed to read TLS cert file at '{}.'",
cert_path.display()
))?;
let cert_chain: Vec<_> = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{}'.",
cert_path.display()
)
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
Ok((tls_config, tls_server_end_point))
}
pub(super) async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,

View File

@@ -1,9 +1,10 @@
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
use futures::future::Either;
use remote_storage::RemoteStorageConfig;
@@ -62,18 +63,18 @@ struct ProxyCliArgs {
region: String,
/// listen for incoming client connections on ip:port
#[clap(short, long, default_value = "127.0.0.1:4432")]
proxy: String,
proxy: SocketAddr,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: String,
mgmt: SocketAddr,
/// listen for incoming http connections (metrics, etc) on ip:port
#[clap(long, default_value = "127.0.0.1:7001")]
http: String,
http: SocketAddr,
/// listen for incoming wss connections on ip:port
#[clap(long)]
wss: Option<String>,
wss: Option<SocketAddr>,
/// redirect unauthenticated users to the given uri in case of console redirect auth
#[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
uri: String,
@@ -99,18 +100,18 @@ struct ProxyCliArgs {
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
#[clap(short = 'k', long, alias = "ssl-key")]
tls_key: Option<String>,
tls_key: Option<PathBuf>,
/// path to TLS cert for client postgres connections
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
#[clap(short = 'c', long, alias = "ssl-cert")]
tls_cert: Option<String>,
tls_cert: Option<PathBuf>,
/// Allow writing TLS session keys to the given file pointed to by the environment variable `SSLKEYLOGFILE`.
#[clap(long, alias = "allow-ssl-keylogfile")]
allow_tls_keylogfile: bool,
/// path to directory with TLS certificates for client postgres connections
#[clap(long)]
certs_dir: Option<String>,
certs_dir: Option<PathBuf>,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
@@ -229,6 +230,9 @@ struct ProxyCliArgs {
// TODO: rename to `console_redirect_confirmation_timeout`.
#[clap(long, default_value = "2m", value_parser = humantime::parse_duration)]
webauth_confirmation_timeout: std::time::Duration,
#[clap(flatten)]
pg_sni_router: PgSniRouterArgs,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -277,6 +281,25 @@ struct SqlOverHttpArgs {
sql_over_http_max_response_size_bytes: usize,
}
#[derive(clap::Args, Clone, Debug)]
struct PgSniRouterArgs {
/// listen for incoming client connections on ip:port
#[clap(id = "sni-router-listen", long, default_value = "127.0.0.1:4432")]
listen: SocketAddr,
/// listen for incoming client connections on ip:port, requiring TLS to compute
#[clap(id = "sni-router-listen-tls", long, default_value = "127.0.0.1:4433")]
listen_tls: SocketAddr,
/// path to TLS key for client postgres connections
#[clap(id = "sni-router-tls-key", long)]
tls_key: Option<PathBuf>,
/// path to TLS cert for client postgres connections
#[clap(id = "sni-router-tls-cert", long)]
tls_cert: Option<PathBuf>,
/// append this domain zone to the SNI hostname to get the destination address
#[clap(id = "sni-router-destination", long)]
dest: Option<String>,
}
pub async fn run() -> anyhow::Result<()> {
let _logging_guard = crate::logging::init().await?;
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
@@ -307,73 +330,51 @@ pub async fn run() -> anyhow::Result<()> {
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.to_string(),
port,
elasticache::CredentialsProvider::new(
args.aws_region,
args.redis_cluster_name,
args.redis_user_id,
)
.await,
),
),
(None, None) => {
warn!(
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
);
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
}
};
let redis_notifications_client = if let Some(url) = args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
} else {
regional_redis_client.clone()
};
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
info!("Starting http on {http_address}");
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
info!("Starting http on {}", args.http);
let http_listener = TcpListener::bind(args.http).await?.into_std()?;
let mgmt_address: SocketAddr = args.mgmt.parse()?;
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
info!("Starting mgmt on {}", args.mgmt);
let mgmt_listener = TcpListener::bind(args.mgmt).await?;
let proxy_listener = if args.is_auth_broker {
None
} else {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
info!("Starting proxy on {}", args.proxy);
Some(TcpListener::bind(args.proxy).await?)
};
Some(TcpListener::bind(proxy_address).await?)
let sni_router_listeners = {
let args = &args.pg_sni_router;
if args.dest.is_some() {
ensure!(
args.tls_key.is_some(),
"sni-router-tls-key must be provided"
);
ensure!(
args.tls_cert.is_some(),
"sni-router-tls-cert must be provided"
);
info!(
"Starting pg-sni-router on {} and {}",
args.listen, args.listen_tls
);
Some((
TcpListener::bind(args.listen).await?,
TcpListener::bind(args.listen_tls).await?,
))
} else {
None
}
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
@@ -458,6 +459,37 @@ pub async fn run() -> anyhow::Result<()> {
}
}
// spawn pg-sni-router mode.
if let Some((listen, listen_tls)) = sni_router_listeners {
let args = args.pg_sni_router;
let dest = args.dest.expect("already asserted it is set");
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let (tls_config, tls_server_end_point) =
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
tls_server_end_point,
listen_tls,
cancellation_token.clone(),
));
}
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
@@ -565,7 +597,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
cert_path,
args.certs_dir.as_ref(),
args.certs_dir.as_deref(),
args.allow_tls_keylogfile,
)?),
(None, None) => None,
@@ -811,6 +843,60 @@ fn build_auth_backend(
}
}
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.to_string(),
port,
elasticache::CredentialsProvider::new(
args.aws_region.clone(),
args.redis_cluster_name.clone(),
args.redis_user_id.clone(),
)
.await,
),
),
(None, None) => {
// todo: upgrade to error?
warn!(
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
);
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
}
};
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
};
Ok((regional_redis_client, redis_notifications_client))
}
#[cfg(test)]
mod tests {
use std::time::Duration;