mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
merge pg-sni-router into proxy (#11882)
## Problem We realised that pg-sni-router doesn't need to be separate from proxy. just a separate port. ## Summary of changes Add pg-sni-router config to proxy and expose the service.
This commit is contained in:
@@ -423,8 +423,8 @@ async fn refresh_config_inner(
|
||||
if let Some(tls_config) = data.tls {
|
||||
let tls_config = tokio::task::spawn_blocking(move || {
|
||||
crate::tls::server_config::configure_tls(
|
||||
&tls_config.key_path,
|
||||
&tls_config.cert_path,
|
||||
tls_config.key_path.as_ref(),
|
||||
tls_config.cert_path.as_ref(),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
/// A stand-alone program that routes connections, e.g. from
|
||||
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
|
||||
///
|
||||
/// This allows connecting to pods/services running in the same Kubernetes cluster from
|
||||
/// the outside. Similar to an ingress controller for HTTPS.
|
||||
//! A stand-alone program that routes connections, e.g. from
|
||||
//! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
|
||||
//!
|
||||
//! This allows connecting to pods/services running in the same Kubernetes cluster from
|
||||
//! the outside. Similar to an ingress controller for HTTPS.
|
||||
|
||||
use std::path::Path;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::{Context, anyhow, bail, ensure};
|
||||
@@ -86,46 +88,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
|
||||
let mut keys =
|
||||
rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
PrivateKeyDer::Pkcs8(
|
||||
keys.pop()
|
||||
.expect("keys should not be empty")
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?,
|
||||
)
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain: Vec<_> = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
|
||||
})?
|
||||
};
|
||||
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config =
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
.context("ring should support TLS1.2 and TLS1.3")?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into();
|
||||
|
||||
(tls_config, tls_server_end_point)
|
||||
}
|
||||
(Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?,
|
||||
_ => bail!("tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
@@ -188,7 +151,58 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
match signal {}
|
||||
}
|
||||
|
||||
async fn task_main(
|
||||
pub(super) fn parse_tls(
|
||||
key_path: &Path,
|
||||
cert_path: &Path,
|
||||
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
PrivateKeyDer::Pkcs8(
|
||||
keys.pop()
|
||||
.expect("keys should not be empty")
|
||||
.context(format!(
|
||||
"Failed to read TLS keys at '{}'",
|
||||
key_path.display()
|
||||
))?,
|
||||
)
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path).context(format!(
|
||||
"Failed to read TLS cert file at '{}.'",
|
||||
cert_path.display()
|
||||
))?;
|
||||
|
||||
let cert_chain: Vec<_> = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{}'.",
|
||||
cert_path.display()
|
||||
)
|
||||
})?
|
||||
};
|
||||
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config =
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
.context("ring should support TLS1.2 and TLS1.3")?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into();
|
||||
|
||||
Ok((tls_config, tls_server_end_point))
|
||||
}
|
||||
|
||||
pub(super) async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::{bail, ensure};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use futures::future::Either;
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
@@ -62,18 +63,18 @@ struct ProxyCliArgs {
|
||||
region: String,
|
||||
/// listen for incoming client connections on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:4432")]
|
||||
proxy: String,
|
||||
proxy: SocketAddr,
|
||||
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
|
||||
auth_backend: AuthBackendType,
|
||||
/// listen for management callback connection on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:7000")]
|
||||
mgmt: String,
|
||||
mgmt: SocketAddr,
|
||||
/// listen for incoming http connections (metrics, etc) on ip:port
|
||||
#[clap(long, default_value = "127.0.0.1:7001")]
|
||||
http: String,
|
||||
http: SocketAddr,
|
||||
/// listen for incoming wss connections on ip:port
|
||||
#[clap(long)]
|
||||
wss: Option<String>,
|
||||
wss: Option<SocketAddr>,
|
||||
/// redirect unauthenticated users to the given uri in case of console redirect auth
|
||||
#[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
|
||||
uri: String,
|
||||
@@ -99,18 +100,18 @@ struct ProxyCliArgs {
|
||||
///
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
#[clap(short = 'k', long, alias = "ssl-key")]
|
||||
tls_key: Option<String>,
|
||||
tls_key: Option<PathBuf>,
|
||||
/// path to TLS cert for client postgres connections
|
||||
///
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
#[clap(short = 'c', long, alias = "ssl-cert")]
|
||||
tls_cert: Option<String>,
|
||||
tls_cert: Option<PathBuf>,
|
||||
/// Allow writing TLS session keys to the given file pointed to by the environment variable `SSLKEYLOGFILE`.
|
||||
#[clap(long, alias = "allow-ssl-keylogfile")]
|
||||
allow_tls_keylogfile: bool,
|
||||
/// path to directory with TLS certificates for client postgres connections
|
||||
#[clap(long)]
|
||||
certs_dir: Option<String>,
|
||||
certs_dir: Option<PathBuf>,
|
||||
/// timeout for the TLS handshake
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
handshake_timeout: tokio::time::Duration,
|
||||
@@ -229,6 +230,9 @@ struct ProxyCliArgs {
|
||||
// TODO: rename to `console_redirect_confirmation_timeout`.
|
||||
#[clap(long, default_value = "2m", value_parser = humantime::parse_duration)]
|
||||
webauth_confirmation_timeout: std::time::Duration,
|
||||
|
||||
#[clap(flatten)]
|
||||
pg_sni_router: PgSniRouterArgs,
|
||||
}
|
||||
|
||||
#[derive(clap::Args, Clone, Copy, Debug)]
|
||||
@@ -277,6 +281,25 @@ struct SqlOverHttpArgs {
|
||||
sql_over_http_max_response_size_bytes: usize,
|
||||
}
|
||||
|
||||
#[derive(clap::Args, Clone, Debug)]
|
||||
struct PgSniRouterArgs {
|
||||
/// listen for incoming client connections on ip:port
|
||||
#[clap(id = "sni-router-listen", long, default_value = "127.0.0.1:4432")]
|
||||
listen: SocketAddr,
|
||||
/// listen for incoming client connections on ip:port, requiring TLS to compute
|
||||
#[clap(id = "sni-router-listen-tls", long, default_value = "127.0.0.1:4433")]
|
||||
listen_tls: SocketAddr,
|
||||
/// path to TLS key for client postgres connections
|
||||
#[clap(id = "sni-router-tls-key", long)]
|
||||
tls_key: Option<PathBuf>,
|
||||
/// path to TLS cert for client postgres connections
|
||||
#[clap(id = "sni-router-tls-cert", long)]
|
||||
tls_cert: Option<PathBuf>,
|
||||
/// append this domain zone to the SNI hostname to get the destination address
|
||||
#[clap(id = "sni-router-destination", long)]
|
||||
dest: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn run() -> anyhow::Result<()> {
|
||||
let _logging_guard = crate::logging::init().await?;
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
@@ -307,73 +330,51 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
|
||||
}
|
||||
info!("Using region: {}", args.aws_region);
|
||||
|
||||
// TODO: untangle the config args
|
||||
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
|
||||
("plain", redis_url) => match redis_url {
|
||||
None => {
|
||||
bail!("plain auth requires redis_notifications to be set");
|
||||
}
|
||||
Some(url) => {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
|
||||
}
|
||||
},
|
||||
("irsa", _) => match (&args.redis_host, args.redis_port) {
|
||||
(Some(host), Some(port)) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host.to_string(),
|
||||
port,
|
||||
elasticache::CredentialsProvider::new(
|
||||
args.aws_region,
|
||||
args.redis_cluster_name,
|
||||
args.redis_user_id,
|
||||
)
|
||||
.await,
|
||||
),
|
||||
),
|
||||
(None, None) => {
|
||||
warn!(
|
||||
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
|
||||
);
|
||||
None
|
||||
}
|
||||
_ => {
|
||||
bail!("redis-host and redis-port must be specified together");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
bail!("unknown auth type given");
|
||||
}
|
||||
};
|
||||
|
||||
let redis_notifications_client = if let Some(url) = args.redis_notifications {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
|
||||
} else {
|
||||
regional_redis_client.clone()
|
||||
};
|
||||
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
let http_address: SocketAddr = args.http.parse()?;
|
||||
info!("Starting http on {http_address}");
|
||||
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
|
||||
info!("Starting http on {}", args.http);
|
||||
let http_listener = TcpListener::bind(args.http).await?.into_std()?;
|
||||
|
||||
let mgmt_address: SocketAddr = args.mgmt.parse()?;
|
||||
info!("Starting mgmt on {mgmt_address}");
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
||||
info!("Starting mgmt on {}", args.mgmt);
|
||||
let mgmt_listener = TcpListener::bind(args.mgmt).await?;
|
||||
|
||||
let proxy_listener = if args.is_auth_broker {
|
||||
None
|
||||
} else {
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
info!("Starting proxy on {}", args.proxy);
|
||||
Some(TcpListener::bind(args.proxy).await?)
|
||||
};
|
||||
|
||||
Some(TcpListener::bind(proxy_address).await?)
|
||||
let sni_router_listeners = {
|
||||
let args = &args.pg_sni_router;
|
||||
if args.dest.is_some() {
|
||||
ensure!(
|
||||
args.tls_key.is_some(),
|
||||
"sni-router-tls-key must be provided"
|
||||
);
|
||||
ensure!(
|
||||
args.tls_cert.is_some(),
|
||||
"sni-router-tls-cert must be provided"
|
||||
);
|
||||
|
||||
info!(
|
||||
"Starting pg-sni-router on {} and {}",
|
||||
args.listen, args.listen_tls
|
||||
);
|
||||
|
||||
Some((
|
||||
TcpListener::bind(args.listen).await?,
|
||||
TcpListener::bind(args.listen_tls).await?,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
let serverless_listener = if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
Some(TcpListener::bind(serverless_address).await?)
|
||||
} else if args.is_auth_broker {
|
||||
@@ -458,6 +459,37 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// spawn pg-sni-router mode.
|
||||
if let Some((listen, listen_tls)) = sni_router_listeners {
|
||||
let args = args.pg_sni_router;
|
||||
let dest = args.dest.expect("already asserted it is set");
|
||||
let key_path = args.tls_key.expect("already asserted it is set");
|
||||
let cert_path = args.tls_cert.expect("already asserted it is set");
|
||||
|
||||
let (tls_config, tls_server_end_point) =
|
||||
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
|
||||
let dest = Arc::new(dest);
|
||||
|
||||
client_tasks.spawn(super::pg_sni_router::task_main(
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
listen,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
|
||||
client_tasks.spawn(super::pg_sni_router::task_main(
|
||||
dest,
|
||||
tls_config,
|
||||
Some(config.connect_to_compute.tls.clone()),
|
||||
tls_server_end_point,
|
||||
listen_tls,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
client_tasks.spawn(crate::context::parquet::worker(
|
||||
cancellation_token.clone(),
|
||||
args.parquet_upload,
|
||||
@@ -565,7 +597,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
key_path,
|
||||
cert_path,
|
||||
args.certs_dir.as_ref(),
|
||||
args.certs_dir.as_deref(),
|
||||
args.allow_tls_keylogfile,
|
||||
)?),
|
||||
(None, None) => None,
|
||||
@@ -811,6 +843,60 @@ fn build_auth_backend(
|
||||
}
|
||||
}
|
||||
|
||||
async fn configure_redis(
|
||||
args: &ProxyCliArgs,
|
||||
) -> anyhow::Result<(
|
||||
Option<ConnectionWithCredentialsProvider>,
|
||||
Option<ConnectionWithCredentialsProvider>,
|
||||
)> {
|
||||
// TODO: untangle the config args
|
||||
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
|
||||
("plain", redis_url) => match redis_url {
|
||||
None => {
|
||||
bail!("plain auth requires redis_notifications to be set");
|
||||
}
|
||||
Some(url) => {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
|
||||
}
|
||||
},
|
||||
("irsa", _) => match (&args.redis_host, args.redis_port) {
|
||||
(Some(host), Some(port)) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host.to_string(),
|
||||
port,
|
||||
elasticache::CredentialsProvider::new(
|
||||
args.aws_region.clone(),
|
||||
args.redis_cluster_name.clone(),
|
||||
args.redis_user_id.clone(),
|
||||
)
|
||||
.await,
|
||||
),
|
||||
),
|
||||
(None, None) => {
|
||||
// todo: upgrade to error?
|
||||
warn!(
|
||||
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
|
||||
);
|
||||
None
|
||||
}
|
||||
_ => {
|
||||
bail!("redis-host and redis-port must be specified together");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
bail!("unknown auth type given");
|
||||
}
|
||||
};
|
||||
|
||||
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
|
||||
} else {
|
||||
regional_redis_client.clone()
|
||||
};
|
||||
|
||||
Ok((regional_redis_client, redis_notifications_client))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
Reference in New Issue
Block a user