mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-06 21:12:55 +00:00
pg-sni-router: support compute TLS on different port (#11732)
## Problem pg-sni-router isn't aware of compute TLS ## Summary of changes If connections come in on port 4433, we require TLS to compute from pg-sni-router
This commit is contained in:
@@ -7,13 +7,14 @@ use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::{Context, anyhow, bail, ensure};
|
||||
use clap::Arg;
|
||||
use futures::TryFutureExt;
|
||||
use futures::future::Either;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types::PrivateKeyDer;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use rustls::pki_types::{DnsName, PrivateKeyDer};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, error, info};
|
||||
use utils::project_git_version;
|
||||
@@ -38,6 +39,12 @@ fn cli() -> clap::Command {
|
||||
.help("listen for incoming client connections on ip:port")
|
||||
.default_value("127.0.0.1:4432"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("listen-tls")
|
||||
.long("listen-tls")
|
||||
.help("listen for incoming client connections on ip:port, requiring TLS to compute")
|
||||
.default_value("127.0.0.1:4433"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-key")
|
||||
.short('k')
|
||||
@@ -122,31 +129,58 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
_ => bail!("tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
let compute_tls_config =
|
||||
Arc::new(crate::tls::client_config::compute_client_config_with_root_certs()?);
|
||||
|
||||
// Start listening for incoming client connections
|
||||
let proxy_address: SocketAddr = args
|
||||
.get_one::<String>("listen")
|
||||
.expect("string argument defined")
|
||||
.expect("listen argument defined")
|
||||
.parse()?;
|
||||
let proxy_address_compute_tls: SocketAddr = args
|
||||
.get_one::<String>("listen-tls")
|
||||
.expect("listen-tls argument defined")
|
||||
.parse()?;
|
||||
|
||||
info!("Starting sni router on {proxy_address}");
|
||||
info!("Starting sni router on {proxy_address_compute_tls}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
let proxy_listener_compute_tls = TcpListener::bind(proxy_address_compute_tls).await?;
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let dest = Arc::new(destination);
|
||||
|
||||
let main = tokio::spawn(task_main(
|
||||
Arc::new(destination),
|
||||
tls_config,
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
))
|
||||
.map(crate::error::flatten_err);
|
||||
|
||||
let main_tls = tokio::spawn(task_main(
|
||||
dest,
|
||||
tls_config,
|
||||
Some(compute_tls_config),
|
||||
tls_server_end_point,
|
||||
proxy_listener_compute_tls,
|
||||
cancellation_token.clone(),
|
||||
))
|
||||
.map(crate::error::flatten_err);
|
||||
let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
|
||||
|
||||
// the signal task cant ever succeed.
|
||||
// the main task can error, or can succeed on cancellation.
|
||||
// we want to immediately exit on either of these cases
|
||||
let main = futures::future::try_join(main, main_tls);
|
||||
let signal = match futures::future::select(signals_task, main).await {
|
||||
Either::Left((res, _)) => crate::error::flatten_err(res)?,
|
||||
Either::Right((res, _)) => return crate::error::flatten_err(res),
|
||||
Either::Right((res, _)) => {
|
||||
res?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// maintenance tasks return `Infallible` success values, this is an impossible value
|
||||
@@ -157,6 +191,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -175,6 +210,7 @@ async fn task_main(
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let tls_config = Arc::clone(&tls_config);
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
let compute_tls_config = compute_tls_config.clone();
|
||||
|
||||
connections.spawn(
|
||||
async move {
|
||||
@@ -192,7 +228,15 @@ async fn task_main(
|
||||
crate::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
|
||||
handle_client(
|
||||
ctx,
|
||||
dest_suffix,
|
||||
tls_config,
|
||||
compute_tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -268,6 +312,7 @@ async fn handle_client(
|
||||
ctx: RequestContext,
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -288,7 +333,33 @@ async fn handle_client(
|
||||
|
||||
info!("destination: {}", destination);
|
||||
|
||||
let mut client = tokio::net::TcpStream::connect(destination).await?;
|
||||
let mut client = tokio::net::TcpStream::connect(&destination).await?;
|
||||
|
||||
let client = if let Some(compute_tls_config) = compute_tls_config {
|
||||
info!("upgrading TLS");
|
||||
|
||||
// send SslRequest
|
||||
client
|
||||
.write_all(b"\x00\x00\x00\x08\x04\xd2\x16\x2f")
|
||||
.await?;
|
||||
|
||||
// wait for S/N respons
|
||||
let mut resp = b'N';
|
||||
client.read_exact(std::slice::from_mut(&mut resp)).await?;
|
||||
|
||||
// error if not S
|
||||
ensure!(resp == b'S', "compute refused TLS");
|
||||
|
||||
// upgrade to TLS.
|
||||
let domain = DnsName::try_from(destination)?;
|
||||
let domain = rustls::pki_types::ServerName::DnsName(domain);
|
||||
let client = TlsConnector::from(compute_tls_config)
|
||||
.connect(domain, client)
|
||||
.await?;
|
||||
Connection::Tls(client)
|
||||
} else {
|
||||
Connection::Raw(client)
|
||||
};
|
||||
|
||||
// doesn't yet matter as pg-sni-router doesn't report analytics logs
|
||||
ctx.set_success();
|
||||
@@ -297,9 +368,19 @@ async fn handle_client(
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
|
||||
match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
|
||||
let res = match client {
|
||||
Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
||||
Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(_) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
}
|
||||
}
|
||||
|
||||
enum Connection {
|
||||
Raw(tokio::net::TcpStream),
|
||||
Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user