Add an option for destination port.

Makes it easier to test locally.
This commit is contained in:
Heikki Linnakangas
2023-04-26 15:22:02 +03:00
parent 620efed7f6
commit ae25d4ab35

View File

@@ -49,6 +49,13 @@ fn cli() -> clap::Command {
.help("append this domain zone to the SNI hostname to get the destination address")
.required(true),
)
.arg(
Arg::new("dest-port")
.long("destination-port")
.help("destination port to connect to")
.default_value("5432")
.value_parser(clap::value_parser!(u16)),
)
}
#[tokio::main]
@@ -58,6 +65,8 @@ async fn main() -> anyhow::Result<()> {
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
let args = cli().get_matches();
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
let destination_port: u16 = *args.get_one::<u16>("dest-port").unwrap();
// Configure TLS
let tls_config: Arc<rustls::ServerConfig> = match (
@@ -98,8 +107,6 @@ async fn main() -> anyhow::Result<()> {
_ => bail!("tls-key and tls-cert must be specified"),
};
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Start listening for incoming client connections
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
info!("Starting proxy on {proxy_address}");
@@ -110,6 +117,7 @@ async fn main() -> anyhow::Result<()> {
tokio::spawn(proxy::handle_signals(cancellation_token.clone())),
tokio::spawn(task_main(
Arc::new(destination),
destination_port,
tls_config,
proxy_listener,
cancellation_token.clone(),
@@ -123,6 +131,7 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
dest_port: u16,
tls_config: Arc<rustls::ServerConfig>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
@@ -157,7 +166,7 @@ async fn task_main(
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(dest_suffix, tls_config, &cancel_map, session_id, socket).await
handle_client(dest_suffix, dest_port, tls_config, &cancel_map, session_id, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -185,6 +194,7 @@ async fn task_main(
#[tracing::instrument(fields(session_id = ?session_id), skip_all)]
async fn handle_client(
dest_suffix: Arc<String>,
dest_port: u16,
tls: Arc<rustls::ServerConfig>,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
@@ -214,9 +224,10 @@ async fn handle_client(
let destination = format!("{}.{}", dest, dest_suffix);
info!("destination: {:?}", destination);
info!("destination: {}:{}", destination, dest_port);
conn_cfg.host(destination.as_str());
conn_cfg.port(dest_port);
let mut conn = conn_cfg.connect()
.or_else(|e| stream.throw_error(e))