From a77919f4b2668277795d731a343f0955bf144eb7 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 12 May 2025 16:48:48 +0100 Subject: [PATCH] merge pg-sni-router into proxy (#11882) ## Problem We realised that pg-sni-router doesn't need to be separate from proxy. just a separate port. ## Summary of changes Add pg-sni-router config to proxy and expose the service. --- proxy/src/binary/local_proxy.rs | 4 +- proxy/src/binary/pg_sni_router.rs | 106 +++++---- proxy/src/binary/proxy.rs | 212 ++++++++++++------ proxy/src/tls/server_config.rs | 33 +-- test_runner/fixtures/neon_fixtures.py | 25 +++ .../regress/test_proxy_metric_collection.py | 4 + test_runner/regress/test_sni_router.py | 26 ++- 7 files changed, 283 insertions(+), 127 deletions(-) diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index ee7f6ffcd7..a566383390 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -423,8 +423,8 @@ async fn refresh_config_inner( if let Some(tls_config) = data.tls { let tls_config = tokio::task::spawn_blocking(move || { crate::tls::server_config::configure_tls( - &tls_config.key_path, - &tls_config.cert_path, + tls_config.key_path.as_ref(), + tls_config.cert_path.as_ref(), None, false, ) diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 19be058ac3..2239d064b2 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -1,8 +1,10 @@ -/// A stand-alone program that routes connections, e.g. from -/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`. -/// -/// This allows connecting to pods/services running in the same Kubernetes cluster from -/// the outside. Similar to an ingress controller for HTTPS. +//! A stand-alone program that routes connections, e.g. from +//! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`. +//! +//! This allows connecting to pods/services running in the same Kubernetes cluster from +//! the outside. Similar to an ingress controller for HTTPS. + +use std::path::Path; use std::{net::SocketAddr, sync::Arc}; use anyhow::{Context, anyhow, bail, ensure}; @@ -86,46 +88,7 @@ pub async fn run() -> anyhow::Result<()> { args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { - (Some(key_path), Some(cert_path)) => { - let key = { - let key_bytes = std::fs::read(key_path).context("TLS key file")?; - - let mut keys = - rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec(); - - ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); - PrivateKeyDer::Pkcs8( - keys.pop() - .expect("keys should not be empty") - .context(format!("Failed to read TLS keys at '{key_path}'"))?, - ) - }; - - let cert_chain_bytes = std::fs::read(cert_path) - .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; - - let cert_chain: Vec<_> = { - rustls_pemfile::certs(&mut &cert_chain_bytes[..]) - .try_collect() - .with_context(|| { - format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.") - })? - }; - - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - - let tls_config = - rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) - .context("ring should support TLS1.2 and TLS1.3")? - .with_no_client_auth() - .with_single_cert(cert_chain, key)? - .into(); - - (tls_config, tls_server_end_point) - } + (Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?, _ => bail!("tls-key and tls-cert must be specified"), }; @@ -188,7 +151,58 @@ pub async fn run() -> anyhow::Result<()> { match signal {} } -async fn task_main( +pub(super) fn parse_tls( + key_path: &Path, + cert_path: &Path, +) -> anyhow::Result<(Arc, TlsServerEndPoint)> { + let key = { + let key_bytes = std::fs::read(key_path).context("TLS key file")?; + + let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec(); + + ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); + PrivateKeyDer::Pkcs8( + keys.pop() + .expect("keys should not be empty") + .context(format!( + "Failed to read TLS keys at '{}'", + key_path.display() + ))?, + ) + }; + + let cert_chain_bytes = std::fs::read(cert_path).context(format!( + "Failed to read TLS cert file at '{}.'", + cert_path.display() + ))?; + + let cert_chain: Vec<_> = { + rustls_pemfile::certs(&mut &cert_chain_bytes[..]) + .try_collect() + .with_context(|| { + format!( + "Failed to read TLS certificate chain from bytes from file at '{}'.", + cert_path.display() + ) + })? + }; + + // needed for channel bindings + let first_cert = cert_chain.first().context("missing certificate")?; + let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; + + let tls_config = + rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .context("ring should support TLS1.2 and TLS1.3")? + .with_no_client_auth() + .with_single_cert(cert_chain, key)? + .into(); + + Ok((tls_config, tls_server_end_point)) +} + +pub(super) async fn task_main( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index e03f2f33d9..fe0d551f7f 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -1,9 +1,10 @@ use std::net::SocketAddr; +use std::path::PathBuf; use std::pin::pin; use std::sync::Arc; use std::time::Duration; -use anyhow::bail; +use anyhow::{bail, ensure}; use arc_swap::ArcSwapOption; use futures::future::Either; use remote_storage::RemoteStorageConfig; @@ -62,18 +63,18 @@ struct ProxyCliArgs { region: String, /// listen for incoming client connections on ip:port #[clap(short, long, default_value = "127.0.0.1:4432")] - proxy: String, + proxy: SocketAddr, #[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)] auth_backend: AuthBackendType, /// listen for management callback connection on ip:port #[clap(short, long, default_value = "127.0.0.1:7000")] - mgmt: String, + mgmt: SocketAddr, /// listen for incoming http connections (metrics, etc) on ip:port #[clap(long, default_value = "127.0.0.1:7001")] - http: String, + http: SocketAddr, /// listen for incoming wss connections on ip:port #[clap(long)] - wss: Option, + wss: Option, /// redirect unauthenticated users to the given uri in case of console redirect auth #[clap(short, long, default_value = "http://localhost:3000/psql_session/")] uri: String, @@ -99,18 +100,18 @@ struct ProxyCliArgs { /// /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir #[clap(short = 'k', long, alias = "ssl-key")] - tls_key: Option, + tls_key: Option, /// path to TLS cert for client postgres connections /// /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir #[clap(short = 'c', long, alias = "ssl-cert")] - tls_cert: Option, + tls_cert: Option, /// Allow writing TLS session keys to the given file pointed to by the environment variable `SSLKEYLOGFILE`. #[clap(long, alias = "allow-ssl-keylogfile")] allow_tls_keylogfile: bool, /// path to directory with TLS certificates for client postgres connections #[clap(long)] - certs_dir: Option, + certs_dir: Option, /// timeout for the TLS handshake #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] handshake_timeout: tokio::time::Duration, @@ -229,6 +230,9 @@ struct ProxyCliArgs { // TODO: rename to `console_redirect_confirmation_timeout`. #[clap(long, default_value = "2m", value_parser = humantime::parse_duration)] webauth_confirmation_timeout: std::time::Duration, + + #[clap(flatten)] + pg_sni_router: PgSniRouterArgs, } #[derive(clap::Args, Clone, Copy, Debug)] @@ -277,6 +281,25 @@ struct SqlOverHttpArgs { sql_over_http_max_response_size_bytes: usize, } +#[derive(clap::Args, Clone, Debug)] +struct PgSniRouterArgs { + /// listen for incoming client connections on ip:port + #[clap(id = "sni-router-listen", long, default_value = "127.0.0.1:4432")] + listen: SocketAddr, + /// listen for incoming client connections on ip:port, requiring TLS to compute + #[clap(id = "sni-router-listen-tls", long, default_value = "127.0.0.1:4433")] + listen_tls: SocketAddr, + /// path to TLS key for client postgres connections + #[clap(id = "sni-router-tls-key", long)] + tls_key: Option, + /// path to TLS cert for client postgres connections + #[clap(id = "sni-router-tls-cert", long)] + tls_cert: Option, + /// append this domain zone to the SNI hostname to get the destination address + #[clap(id = "sni-router-destination", long)] + dest: Option, +} + pub async fn run() -> anyhow::Result<()> { let _logging_guard = crate::logging::init().await?; let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); @@ -307,73 +330,51 @@ pub async fn run() -> anyhow::Result<()> { Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"), } info!("Using region: {}", args.aws_region); - - // TODO: untangle the config args - let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) { - ("plain", redis_url) => match redis_url { - None => { - bail!("plain auth requires redis_notifications to be set"); - } - Some(url) => { - Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())) - } - }, - ("irsa", _) => match (&args.redis_host, args.redis_port) { - (Some(host), Some(port)) => Some( - ConnectionWithCredentialsProvider::new_with_credentials_provider( - host.to_string(), - port, - elasticache::CredentialsProvider::new( - args.aws_region, - args.redis_cluster_name, - args.redis_user_id, - ) - .await, - ), - ), - (None, None) => { - warn!( - "irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client" - ); - None - } - _ => { - bail!("redis-host and redis-port must be specified together"); - } - }, - _ => { - bail!("unknown auth type given"); - } - }; - - let redis_notifications_client = if let Some(url) = args.redis_notifications { - Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url)) - } else { - regional_redis_client.clone() - }; + let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?; // Check that we can bind to address before further initialization - let http_address: SocketAddr = args.http.parse()?; - info!("Starting http on {http_address}"); - let http_listener = TcpListener::bind(http_address).await?.into_std()?; + info!("Starting http on {}", args.http); + let http_listener = TcpListener::bind(args.http).await?.into_std()?; - let mgmt_address: SocketAddr = args.mgmt.parse()?; - info!("Starting mgmt on {mgmt_address}"); - let mgmt_listener = TcpListener::bind(mgmt_address).await?; + info!("Starting mgmt on {}", args.mgmt); + let mgmt_listener = TcpListener::bind(args.mgmt).await?; let proxy_listener = if args.is_auth_broker { None } else { - let proxy_address: SocketAddr = args.proxy.parse()?; - info!("Starting proxy on {proxy_address}"); + info!("Starting proxy on {}", args.proxy); + Some(TcpListener::bind(args.proxy).await?) + }; - Some(TcpListener::bind(proxy_address).await?) + let sni_router_listeners = { + let args = &args.pg_sni_router; + if args.dest.is_some() { + ensure!( + args.tls_key.is_some(), + "sni-router-tls-key must be provided" + ); + ensure!( + args.tls_cert.is_some(), + "sni-router-tls-cert must be provided" + ); + + info!( + "Starting pg-sni-router on {} and {}", + args.listen, args.listen_tls + ); + + Some(( + TcpListener::bind(args.listen).await?, + TcpListener::bind(args.listen_tls).await?, + )) + } else { + None + } }; // TODO: rename the argument to something like serverless. // It now covers more than just websockets, it also covers SQL over HTTP. let serverless_listener = if let Some(serverless_address) = args.wss { - let serverless_address: SocketAddr = serverless_address.parse()?; info!("Starting wss on {serverless_address}"); Some(TcpListener::bind(serverless_address).await?) } else if args.is_auth_broker { @@ -458,6 +459,37 @@ pub async fn run() -> anyhow::Result<()> { } } + // spawn pg-sni-router mode. + if let Some((listen, listen_tls)) = sni_router_listeners { + let args = args.pg_sni_router; + let dest = args.dest.expect("already asserted it is set"); + let key_path = args.tls_key.expect("already asserted it is set"); + let cert_path = args.tls_cert.expect("already asserted it is set"); + + let (tls_config, tls_server_end_point) = + super::pg_sni_router::parse_tls(&key_path, &cert_path)?; + + let dest = Arc::new(dest); + + client_tasks.spawn(super::pg_sni_router::task_main( + dest.clone(), + tls_config.clone(), + None, + tls_server_end_point, + listen, + cancellation_token.clone(), + )); + + client_tasks.spawn(super::pg_sni_router::task_main( + dest, + tls_config, + Some(config.connect_to_compute.tls.clone()), + tls_server_end_point, + listen_tls, + cancellation_token.clone(), + )); + } + client_tasks.spawn(crate::context::parquet::worker( cancellation_token.clone(), args.parquet_upload, @@ -565,7 +597,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { (Some(key_path), Some(cert_path)) => Some(config::configure_tls( key_path, cert_path, - args.certs_dir.as_ref(), + args.certs_dir.as_deref(), args.allow_tls_keylogfile, )?), (None, None) => None, @@ -811,6 +843,60 @@ fn build_auth_backend( } } +async fn configure_redis( + args: &ProxyCliArgs, +) -> anyhow::Result<( + Option, + Option, +)> { + // TODO: untangle the config args + let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) { + ("plain", redis_url) => match redis_url { + None => { + bail!("plain auth requires redis_notifications to be set"); + } + Some(url) => { + Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())) + } + }, + ("irsa", _) => match (&args.redis_host, args.redis_port) { + (Some(host), Some(port)) => Some( + ConnectionWithCredentialsProvider::new_with_credentials_provider( + host.to_string(), + port, + elasticache::CredentialsProvider::new( + args.aws_region.clone(), + args.redis_cluster_name.clone(), + args.redis_user_id.clone(), + ) + .await, + ), + ), + (None, None) => { + // todo: upgrade to error? + warn!( + "irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client" + ); + None + } + _ => { + bail!("redis-host and redis-port must be specified together"); + } + }, + _ => { + bail!("unknown auth type given"); + } + }; + + let redis_notifications_client = if let Some(url) = &args.redis_notifications { + Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url)) + } else { + regional_redis_client.clone() + }; + + Ok((regional_redis_client, redis_notifications_client)) +} + #[cfg(test)] mod tests { use std::time::Duration; diff --git a/proxy/src/tls/server_config.rs b/proxy/src/tls/server_config.rs index 8f8917ef62..66c53b3aff 100644 --- a/proxy/src/tls/server_config.rs +++ b/proxy/src/tls/server_config.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::path::Path; use std::sync::Arc; use anyhow::{Context, bail}; @@ -21,9 +22,9 @@ pub struct TlsConfig { /// Configure TLS for the main endpoint. pub fn configure_tls( - key_path: &str, - cert_path: &str, - certs_dir: Option<&String>, + key_path: &Path, + cert_path: &Path, + certs_dir: Option<&Path>, allow_tls_keylogfile: bool, ) -> anyhow::Result { // add default certificate @@ -39,8 +40,7 @@ pub fn configure_tls( let key_path = path.join("tls.key"); let cert_path = path.join("tls.crt"); if key_path.exists() && cert_path.exists() { - cert_resolver - .add_cert_path(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?; + cert_resolver.add_cert_path(&key_path, &cert_path)?; } } } @@ -86,7 +86,7 @@ pub struct CertResolver { } impl CertResolver { - fn parse_new(key_path: &str, cert_path: &str) -> anyhow::Result { + fn parse_new(key_path: &Path, cert_path: &Path) -> anyhow::Result { let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?; Self::new(priv_key, cert_chain) } @@ -103,7 +103,7 @@ impl CertResolver { Ok(Self { certs, default }) } - fn add_cert_path(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> { + fn add_cert_path(&mut self, key_path: &Path, cert_path: &Path) -> anyhow::Result<()> { let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?; self.add_cert(priv_key, cert_chain) } @@ -124,26 +124,29 @@ impl CertResolver { } fn parse_key_cert( - key_path: &str, - cert_path: &str, + key_path: &Path, + cert_path: &Path, ) -> anyhow::Result<(PrivateKeyDer<'static>, Vec>)> { let priv_key = { let key_bytes = std::fs::read(key_path) - .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?; + .with_context(|| format!("Failed to read TLS keys at '{}'", key_path.display()))?; rustls_pemfile::private_key(&mut &key_bytes[..]) - .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? - .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? + .with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))? + .with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))? }; - let cert_chain_bytes = std::fs::read(cert_path) - .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; + let cert_chain_bytes = std::fs::read(cert_path).context(format!( + "Failed to read TLS cert file at '{}.'", + cert_path.display() + ))?; let cert_chain = { rustls_pemfile::certs(&mut &cert_chain_bytes[..]) .try_collect() .with_context(|| { format!( - "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." + "Failed to read TLS certificate chain from bytes from file at '{}'.", + cert_path.display() ) })? }; diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 8f56ee4392..2801a0e867 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3607,6 +3607,8 @@ class NeonProxy(PgProtocol): http_port: int, mgmt_port: int, external_http_port: int, + router_port: int, + router_tls_port: int, auth_backend: NeonProxy.AuthBackend, metric_collection_endpoint: str | None = None, metric_collection_interval: str | None = None, @@ -3623,6 +3625,8 @@ class NeonProxy(PgProtocol): self.test_output_dir = test_output_dir self.proxy_port = proxy_port self.mgmt_port = mgmt_port + self.router_port = router_port + self.router_tls_port = router_tls_port self.auth_backend = auth_backend self.metric_collection_endpoint = metric_collection_endpoint self.metric_collection_interval = metric_collection_interval @@ -3637,6 +3641,14 @@ class NeonProxy(PgProtocol): key_path = self.test_output_dir / "proxy.key" generate_proxy_tls_certs("*.local.neon.build", key_path, crt_path) + # generate key for pg-sni-router. + # endpoint.namespace.local.neon.build resolves to 127.0.0.1 + generate_proxy_tls_certs( + "endpoint.namespace.local.neon.build", + self.test_output_dir / "router.key", + self.test_output_dir / "router.crt", + ) + args = [ str(self.neon_binpath / "proxy"), *["--http", f"{self.host}:{self.http_port}"], @@ -3646,6 +3658,11 @@ class NeonProxy(PgProtocol): *["--sql-over-http-timeout", f"{self.http_timeout_seconds}s"], *["-c", str(crt_path)], *["-k", str(key_path)], + *["--sni-router-listen", f"{self.host}:{self.router_port}"], + *["--sni-router-listen-tls", f"{self.host}:{self.router_tls_port}"], + *["--sni-router-tls-cert", str(self.test_output_dir / "router.crt")], + *["--sni-router-tls-key", str(self.test_output_dir / "router.key")], + *["--sni-router-destination", "local.neon.build"], *self.auth_backend.extra_args(), ] @@ -3945,6 +3962,8 @@ def link_proxy( proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() external_http_port = port_distributor.get_port() + router_port = port_distributor.get_port() + router_tls_port = port_distributor.get_port() with NeonProxy( neon_binpath=neon_binpath, @@ -3952,6 +3971,8 @@ def link_proxy( proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, + router_port=router_port, + router_tls_port=router_tls_port, external_http_port=external_http_port, auth_backend=NeonProxy.Link(), ) as proxy: @@ -3985,6 +4006,8 @@ def static_proxy( mgmt_port = port_distributor.get_port() http_port = port_distributor.get_port() external_http_port = port_distributor.get_port() + router_port = port_distributor.get_port() + router_tls_port = port_distributor.get_port() with NeonProxy( neon_binpath=neon_binpath, @@ -3992,6 +4015,8 @@ def static_proxy( proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, + router_port=router_port, + router_tls_port=router_tls_port, external_http_port=external_http_port, auth_backend=NeonProxy.Postgres(auth_endpoint), ) as proxy: diff --git a/test_runner/regress/test_proxy_metric_collection.py b/test_runner/regress/test_proxy_metric_collection.py index 85d8a6daaa..7442d50f68 100644 --- a/test_runner/regress/test_proxy_metric_collection.py +++ b/test_runner/regress/test_proxy_metric_collection.py @@ -52,6 +52,8 @@ def proxy_with_metric_collector( proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() external_http_port = port_distributor.get_port() + router_port = port_distributor.get_port() + router_tls_port = port_distributor.get_port() (host, port) = httpserver_listen_address metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events" @@ -63,6 +65,8 @@ def proxy_with_metric_collector( proxy_port=proxy_port, http_port=http_port, mgmt_port=mgmt_port, + router_port=router_port, + router_tls_port=router_tls_port, external_http_port=external_http_port, metric_collection_endpoint=metric_collection_endpoint, metric_collection_interval=metric_collection_interval, diff --git a/test_runner/regress/test_sni_router.py b/test_runner/regress/test_sni_router.py index 19952fc71b..61893f22ba 100644 --- a/test_runner/regress/test_sni_router.py +++ b/test_runner/regress/test_sni_router.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING import backoff from fixtures.log_helper import log -from fixtures.neon_fixtures import PgProtocol, VanillaPostgres +from fixtures.neon_fixtures import NeonProxy, PgProtocol, VanillaPostgres if TYPE_CHECKING: from pathlib import Path @@ -41,6 +41,7 @@ class PgSniRouter(PgProtocol): self, neon_binpath: Path, port: int, + tls_port: int, destination: str, tls_cert: Path, tls_key: Path, @@ -53,6 +54,7 @@ class PgSniRouter(PgProtocol): self.host = host self.neon_binpath = neon_binpath self.port = port + self.tls_port = tls_port self.destination = destination self.tls_cert = tls_cert self.tls_key = tls_key @@ -64,6 +66,7 @@ class PgSniRouter(PgProtocol): args = [ str(self.neon_binpath / "pg_sni_router"), *["--listen", f"127.0.0.1:{self.port}"], + *["--listen-tls", f"127.0.0.1:{self.tls_port}"], *["--tls-cert", str(self.tls_cert)], *["--tls-key", str(self.tls_key)], *["--destination", self.destination], @@ -127,10 +130,12 @@ def test_pg_sni_router( pg_port = vanilla_pg.default_options["port"] router_port = port_distributor.get_port() + router_tls_port = port_distributor.get_port() with PgSniRouter( neon_binpath=neon_binpath, port=router_port, + tls_port=router_tls_port, destination="local.neon.build", tls_cert=test_output_dir / "router.crt", tls_key=test_output_dir / "router.key", @@ -146,3 +151,22 @@ def test_pg_sni_router( hostaddr="127.0.0.1", ) assert out[0][0] == 1 + + +def test_pg_sni_router_in_proxy( + static_proxy: NeonProxy, + vanilla_pg: VanillaPostgres, +): + # static_proxy starts this. + assert vanilla_pg.is_running() + pg_port = vanilla_pg.default_options["port"] + + out = static_proxy.safe_psql( + "select 1", + dbname="postgres", + sslmode="require", + host=f"endpoint--namespace--{pg_port}.local.neon.build", + hostaddr="127.0.0.1", + port=static_proxy.router_port, + ) + assert out[0][0] == 1