diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 62cf734373..8aa1923f61 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -49,6 +49,13 @@ fn cli() -> clap::Command { .help("append this domain zone to the SNI hostname to get the destination address") .required(true), ) + .arg( + Arg::new("dest-port") + .long("destination-port") + .help("destination port to connect to") + .default_value("5432") + .value_parser(clap::value_parser!(u16)), + ) } #[tokio::main] @@ -58,6 +65,8 @@ async fn main() -> anyhow::Result<()> { let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); let args = cli().get_matches(); + let destination: String = args.get_one::("dest").unwrap().parse()?; + let destination_port: u16 = *args.get_one::("dest-port").unwrap(); // Configure TLS let tls_config: Arc = match ( @@ -98,8 +107,6 @@ async fn main() -> anyhow::Result<()> { _ => bail!("tls-key and tls-cert must be specified"), }; - let destination: String = args.get_one::("dest").unwrap().parse()?; - // Start listening for incoming client connections let proxy_address: SocketAddr = args.get_one::("listen").unwrap().parse()?; info!("Starting proxy on {proxy_address}"); @@ -110,6 +117,7 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(proxy::handle_signals(cancellation_token.clone())), tokio::spawn(task_main( Arc::new(destination), + destination_port, tls_config, proxy_listener, cancellation_token.clone(), @@ -123,6 +131,7 @@ async fn main() -> anyhow::Result<()> { async fn task_main( dest_suffix: Arc, + dest_port: u16, tls_config: Arc, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, @@ -157,7 +166,7 @@ async fn task_main( .set_nodelay(true) .context("failed to set socket option")?; - handle_client(dest_suffix, tls_config, &cancel_map, session_id, socket).await + handle_client(dest_suffix, dest_port, tls_config, &cancel_map, session_id, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -185,6 +194,7 @@ async fn task_main( #[tracing::instrument(fields(session_id = ?session_id), skip_all)] async fn handle_client( dest_suffix: Arc, + dest_port: u16, tls: Arc, cancel_map: &CancelMap, session_id: uuid::Uuid, @@ -214,9 +224,10 @@ async fn handle_client( let destination = format!("{}.{}", dest, dest_suffix); - info!("destination: {:?}", destination); + info!("destination: {}:{}", destination, dest_port); conn_cfg.host(destination.as_str()); + conn_cfg.port(dest_port); let mut conn = conn_cfg.connect() .or_else(|e| stream.throw_error(e))