mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-26 23:59:58 +00:00
This rebuilds #11552 on top of the current Cargo.lock. --------- Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
379 lines
12 KiB
Rust
379 lines
12 KiB
Rust
//! A stand-alone program that routes connections, e.g. from
|
|
//! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
|
|
//!
|
|
//! This allows connecting to pods/services running in the same Kubernetes cluster from
|
|
//! the outside. Similar to an ingress controller for HTTPS.
|
|
|
|
use std::io;
|
|
use std::net::SocketAddr;
|
|
use std::path::Path;
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::{Context, anyhow, bail, ensure};
|
|
use clap::Arg;
|
|
use futures::future::Either;
|
|
use futures::{FutureExt, TryFutureExt};
|
|
use itertools::Itertools;
|
|
use rustls::crypto::ring;
|
|
use rustls::pki_types::{DnsName, PrivateKeyDer};
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
use tokio::net::TcpListener;
|
|
use tokio_rustls::TlsConnector;
|
|
use tokio_rustls::server::TlsStream;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{Instrument, error, info};
|
|
use utils::project_git_version;
|
|
use utils::sentry_init::init_sentry;
|
|
|
|
use crate::context::RequestContext;
|
|
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
|
use crate::pglb::TlsRequired;
|
|
use crate::pqproto::FeStartupPacket;
|
|
use crate::protocol2::ConnectionInfo;
|
|
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
|
|
use crate::stream::{PqStream, Stream};
|
|
use crate::util::run_until_cancelled;
|
|
|
|
// Defines the `GIT_VERSION` constant used below as the CLI `--version` string and passed to
// `init_sentry` in `run`.
project_git_version!(GIT_VERSION);
|
|
|
|
fn cli() -> clap::Command {
|
|
clap::Command::new("Neon proxy/router")
|
|
.version(GIT_VERSION)
|
|
.arg(
|
|
Arg::new("listen")
|
|
.short('l')
|
|
.long("listen")
|
|
.help("listen for incoming client connections on ip:port")
|
|
.default_value("127.0.0.1:4432"),
|
|
)
|
|
.arg(
|
|
Arg::new("listen-tls")
|
|
.long("listen-tls")
|
|
.help("listen for incoming client connections on ip:port, requiring TLS to compute")
|
|
.default_value("127.0.0.1:4433"),
|
|
)
|
|
.arg(
|
|
Arg::new("tls-key")
|
|
.short('k')
|
|
.long("tls-key")
|
|
.help("path to TLS key for client postgres connections")
|
|
.required(true),
|
|
)
|
|
.arg(
|
|
Arg::new("tls-cert")
|
|
.short('c')
|
|
.long("tls-cert")
|
|
.help("path to TLS cert for client postgres connections")
|
|
.required(true),
|
|
)
|
|
.arg(
|
|
Arg::new("dest")
|
|
.short('d')
|
|
.long("destination")
|
|
.help("append this domain zone to the SNI hostname to get the destination address")
|
|
.required(true),
|
|
)
|
|
}
|
|
|
|
pub async fn run() -> anyhow::Result<()> {
|
|
let _logging_guard = crate::logging::init()?;
|
|
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
|
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
|
|
|
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
|
|
|
let args = cli().get_matches();
|
|
let destination: String = args
|
|
.get_one::<String>("dest")
|
|
.expect("string argument defined")
|
|
.parse()?;
|
|
|
|
// Configure TLS
|
|
let tls_config = match (
|
|
args.get_one::<String>("tls-key"),
|
|
args.get_one::<String>("tls-cert"),
|
|
) {
|
|
(Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?,
|
|
_ => bail!("tls-key and tls-cert must be specified"),
|
|
};
|
|
|
|
let compute_tls_config =
|
|
Arc::new(crate::tls::client_config::compute_client_config_with_root_certs()?);
|
|
|
|
// Start listening for incoming client connections
|
|
let proxy_address: SocketAddr = args
|
|
.get_one::<String>("listen")
|
|
.expect("listen argument defined")
|
|
.parse()?;
|
|
let proxy_address_compute_tls: SocketAddr = args
|
|
.get_one::<String>("listen-tls")
|
|
.expect("listen-tls argument defined")
|
|
.parse()?;
|
|
|
|
info!("Starting sni router on {proxy_address}");
|
|
info!("Starting sni router on {proxy_address_compute_tls}");
|
|
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
|
let proxy_listener_compute_tls = TcpListener::bind(proxy_address_compute_tls).await?;
|
|
|
|
let cancellation_token = CancellationToken::new();
|
|
let dest = Arc::new(destination);
|
|
|
|
let main = tokio::spawn(task_main(
|
|
dest.clone(),
|
|
tls_config.clone(),
|
|
None,
|
|
proxy_listener,
|
|
cancellation_token.clone(),
|
|
))
|
|
.map(crate::error::flatten_err);
|
|
|
|
let main_tls = tokio::spawn(task_main(
|
|
dest,
|
|
tls_config,
|
|
Some(compute_tls_config),
|
|
proxy_listener_compute_tls,
|
|
cancellation_token.clone(),
|
|
))
|
|
.map(crate::error::flatten_err);
|
|
let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
|
|
|
|
// the signal task cant ever succeed.
|
|
// the main task can error, or can succeed on cancellation.
|
|
// we want to immediately exit on either of these cases
|
|
let main = futures::future::try_join(main, main_tls);
|
|
let signal = match futures::future::select(signals_task, main).await {
|
|
Either::Left((res, _)) => crate::error::flatten_err(res)?,
|
|
Either::Right((res, _)) => {
|
|
res?;
|
|
return Ok(());
|
|
}
|
|
};
|
|
|
|
// maintenance tasks return `Infallible` success values, this is an impossible value
|
|
// so this match statically ensures that there are no possibilities for that value
|
|
match signal {}
|
|
}
|
|
|
|
pub(super) fn parse_tls(
|
|
key_path: &Path,
|
|
cert_path: &Path,
|
|
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
|
|
let key = {
|
|
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
|
|
|
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
|
|
|
|
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
|
PrivateKeyDer::Pkcs8(
|
|
keys.pop()
|
|
.expect("keys should not be empty")
|
|
.context(format!(
|
|
"Failed to read TLS keys at '{}'",
|
|
key_path.display()
|
|
))?,
|
|
)
|
|
};
|
|
|
|
let cert_chain_bytes = std::fs::read(cert_path).context(format!(
|
|
"Failed to read TLS cert file at '{}.'",
|
|
cert_path.display()
|
|
))?;
|
|
|
|
let cert_chain: Vec<_> = {
|
|
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
|
.try_collect()
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to read TLS certificate chain from bytes from file at '{}'.",
|
|
cert_path.display()
|
|
)
|
|
})?
|
|
};
|
|
|
|
let tls_config =
|
|
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
|
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
|
.context("ring should support TLS1.2 and TLS1.3")?
|
|
.with_no_client_auth()
|
|
.with_single_cert(cert_chain, key)?
|
|
.into();
|
|
|
|
Ok(tls_config)
|
|
}
|
|
|
|
pub(super) async fn task_main(
|
|
dest_suffix: Arc<String>,
|
|
tls_config: Arc<rustls::ServerConfig>,
|
|
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
|
listener: tokio::net::TcpListener,
|
|
cancellation_token: CancellationToken,
|
|
) -> anyhow::Result<()> {
|
|
// When set for the server socket, the keepalive setting
|
|
// will be inherited by all accepted client sockets.
|
|
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
|
|
|
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
|
|
|
while let Some(accept_result) =
|
|
run_until_cancelled(listener.accept(), &cancellation_token).await
|
|
{
|
|
let (socket, peer_addr) = accept_result?;
|
|
|
|
let session_id = uuid::Uuid::new_v4();
|
|
let tls_config = Arc::clone(&tls_config);
|
|
let dest_suffix = Arc::clone(&dest_suffix);
|
|
let compute_tls_config = compute_tls_config.clone();
|
|
|
|
connections.spawn(
|
|
async move {
|
|
socket
|
|
.set_nodelay(true)
|
|
.context("failed to set socket option")?;
|
|
|
|
let ctx = RequestContext::new(
|
|
session_id,
|
|
ConnectionInfo {
|
|
addr: peer_addr,
|
|
extra: None,
|
|
},
|
|
crate::metrics::Protocol::SniRouter,
|
|
);
|
|
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
|
|
}
|
|
.unwrap_or_else(|e| {
|
|
if let Some(FirstMessage(io_error)) = e.downcast_ref() {
|
|
// this is noisy. if we get EOF on the very first message that's likely
|
|
// just NLB doing a healthcheck.
|
|
if io_error.kind() == io::ErrorKind::UnexpectedEof {
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Acknowledge that the task has finished with an error.
|
|
error!("per-client task finished with an error: {e:#}");
|
|
})
|
|
.instrument(tracing::info_span!("handle_client", ?session_id)),
|
|
);
|
|
}
|
|
|
|
connections.close();
|
|
drop(listener);
|
|
|
|
connections.wait().await;
|
|
|
|
info!("all client connections have finished");
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
#[error(transparent)]
|
|
struct FirstMessage(io::Error);
|
|
|
|
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
|
ctx: &RequestContext,
|
|
raw_stream: S,
|
|
tls_config: Arc<rustls::ServerConfig>,
|
|
) -> anyhow::Result<TlsStream<S>> {
|
|
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream))
|
|
.await
|
|
.map_err(FirstMessage)?;
|
|
|
|
match msg {
|
|
FeStartupPacket::SslRequest { direct: None } => {
|
|
let raw = stream.accept_tls().await?;
|
|
|
|
Ok(raw
|
|
.upgrade(tls_config, !ctx.has_private_peer_addr())
|
|
.await?)
|
|
}
|
|
unexpected => {
|
|
info!(
|
|
?unexpected,
|
|
"unexpected startup packet, rejecting connection"
|
|
);
|
|
Err(stream.throw_error(TlsRequired, None).await)?
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_client(
|
|
ctx: RequestContext,
|
|
dest_suffix: Arc<String>,
|
|
tls_config: Arc<rustls::ServerConfig>,
|
|
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
|
stream: impl AsyncRead + AsyncWrite + Unpin,
|
|
) -> anyhow::Result<()> {
|
|
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
|
|
|
|
// Cut off first part of the SNI domain
|
|
// We receive required destination details in the format of
|
|
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
|
|
let sni = tls_stream
|
|
.get_ref()
|
|
.1
|
|
.server_name()
|
|
.ok_or(anyhow!("SNI missing"))?;
|
|
let dest: Vec<&str> = sni
|
|
.split_once('.')
|
|
.context("invalid SNI")?
|
|
.0
|
|
.splitn(3, "--")
|
|
.collect();
|
|
let port = dest[2].parse::<u16>().context("invalid port")?;
|
|
let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
|
|
|
|
info!("destination: {}", destination);
|
|
|
|
let mut client = tokio::net::TcpStream::connect(&destination).await?;
|
|
|
|
let client = if let Some(compute_tls_config) = compute_tls_config {
|
|
info!("upgrading TLS");
|
|
|
|
// send SslRequest
|
|
client
|
|
.write_all(b"\x00\x00\x00\x08\x04\xd2\x16\x2f")
|
|
.await?;
|
|
|
|
// wait for S/N respons
|
|
let mut resp = b'N';
|
|
client.read_exact(std::slice::from_mut(&mut resp)).await?;
|
|
|
|
// error if not S
|
|
ensure!(resp == b'S', "compute refused TLS");
|
|
|
|
// upgrade to TLS.
|
|
let domain = DnsName::try_from(destination)?;
|
|
let domain = rustls::pki_types::ServerName::DnsName(domain);
|
|
let client = TlsConnector::from(compute_tls_config)
|
|
.connect(domain, client)
|
|
.await?;
|
|
Connection::Tls(client)
|
|
} else {
|
|
Connection::Raw(client)
|
|
};
|
|
|
|
// doesn't yet matter as pg-sni-router doesn't report analytics logs
|
|
ctx.set_success();
|
|
ctx.log_connect();
|
|
|
|
// Starting from here we only proxy the client's traffic.
|
|
info!("performing the proxy pass...");
|
|
|
|
let res = match client {
|
|
Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
|
Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
|
};
|
|
|
|
match res {
|
|
Ok(_) => Ok(()),
|
|
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
|
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::large_enum_variant)]
|
|
enum Connection {
|
|
Raw(tokio::net::TcpStream),
|
|
Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
|
|
}
|