diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs new file mode 100644 index 0000000000..adaf065e1b --- /dev/null +++ b/proxy/src/bin/pg_sni_router.rs @@ -0,0 +1,226 @@ +use std::{net::SocketAddr, sync::Arc}; +use tokio::{net::TcpListener, io::AsyncWriteExt}; + +use anyhow::{bail, ensure, Context}; +use clap::{self, Arg}; +use futures::TryFutureExt; +use proxy::{cancellation::CancelMap, auth::{AuthFlow, self}, compute::ConnCfg, console::messages::MetricsAuxInfo}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::sync::CancellationToken; +use utils::{project_git_version, sentry_init::init_sentry}; + +use tracing::{error, info, warn}; + +project_git_version!(GIT_VERSION); + +fn cli() -> clap::Command { + clap::Command::new("Neon proxy/router") + .disable_help_flag(true) + .version(GIT_VERSION) + .arg( + Arg::new("listen") + .short('l') + .long("listen") + .help("listen for incoming client connections on ip:port") + .default_value("127.0.0.1:4432"), + ) + .arg( + Arg::new("tls-key") + .short('k') + .long("tls-key") + .help("path to TLS key for client postgres connections"), + ) + .arg( + Arg::new("tls-cert") + .short('c') + .long("tls-cert") + .help("path to TLS cert for client postgres connections"), + ) + .arg( + Arg::new("dest") + .short('d') + .long("destination") + .help("append this domain zone to the SNI hostname to get the destination address"), + ) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let _logging_guard = proxy::logging::init().await?; + let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); + let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); + + let args = cli().get_matches(); + + // Configure TLS + let tls_config: Arc = match ( + args.get_one::("tls-key"), + args.get_one::("tls-cert"), + ) { + (Some(key_path), Some(cert_path)) => { + let key = { + let key_bytes = std::fs::read(key_path).context("TLS key file")?; + let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) + .context(format!("Failed to read TLS keys at '{key_path}'"))?; + + ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); + keys.pop().map(rustls::PrivateKey).unwrap() + }; + + let cert_chain_bytes = std::fs::read(cert_path) + .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; + + let cert_chain = { + rustls_pemfile::certs(&mut &cert_chain_bytes[..]) + .context(format!( + "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." + ))? + .into_iter() + .map(rustls::Certificate) + .collect() + }; + + rustls::ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])? + .with_no_client_auth() + .with_single_cert(cert_chain, key)? + .into() + } + _ => bail!("tls-key and tls-cert must be specified"), + }; + + let destination: String = args.get_one::("dest").unwrap().parse()?; + + // Start listening for incoming client connections + let proxy_address: SocketAddr = args.get_one::("listen").unwrap().parse()?; + info!("Starting proxy on {proxy_address}"); + let proxy_listener = TcpListener::bind(proxy_address).await?; + + let cancellation_token = CancellationToken::new(); + let tasks = vec![ + tokio::spawn(proxy::handle_signals(cancellation_token.clone())), + tokio::spawn(task_main( + Arc::new(destination), + tls_config, + proxy_listener, + cancellation_token.clone(), + )), + ]; + + let _tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err)).await?; + + Ok(()) +} + +async fn task_main( + dest_suffix: Arc, + tls_config: Arc, + listener: tokio::net::TcpListener, + cancellation_token: CancellationToken, +) -> anyhow::Result<()> { + scopeguard::defer! { + info!("proxy has shut down"); + } + + // When set for the server socket, the keepalive setting + // will be inherited by all accepted client sockets. + socket2::SockRef::from(&listener).set_keepalive(true)?; + + let mut connections = tokio::task::JoinSet::new(); + let cancel_map = Arc::new(CancelMap::default()); + + loop { + tokio::select! { + accept_result = listener.accept() => { + let (socket, peer_addr) = accept_result?; + info!("accepted postgres client connection from {peer_addr}"); + + let session_id = uuid::Uuid::new_v4(); + let cancel_map = Arc::clone(&cancel_map); + let tls_config = Arc::clone(&tls_config); + let dest_suffix = Arc::clone(&dest_suffix); + + connections.spawn( + async move { + info!("spawned a task for {peer_addr}"); + + socket + .set_nodelay(true) + .context("failed to set socket option")?; + + handle_client(dest_suffix, tls_config, &cancel_map, session_id, socket).await + } + .unwrap_or_else(|e| { + // Acknowledge that the task has finished with an error. + error!("per-client task finished with an error: {e:#}"); + }), + ); + } + _ = cancellation_token.cancelled() => { + drop(listener); + break; + } + } + } + // Drain connections + while let Some(res) = connections.join_next().await { + if let Err(e) = res { + if !e.is_panic() && !e.is_cancelled() { + warn!("unexpected error from joined connection task: {e:?}"); + } + } + } + Ok(()) +} + +#[tracing::instrument(fields(session_id = ?session_id), skip_all)] +async fn handle_client( + dest_suffix: Arc, + tls: Arc, + cancel_map: &CancelMap, + session_id: uuid::Uuid, + stream: impl AsyncRead + AsyncWrite + Unpin, +) -> anyhow::Result<()> { + let do_handshake = proxy::proxy::handshake(stream, Some(tls), cancel_map); + let (mut stream, params) = match do_handshake.await? { + Some(x) => x, + None => return Ok(()), // it's a cancellation request + }; + + let password = AuthFlow::new(&mut stream) + .begin(auth::CleartextPassword) + .await? + .authenticate() + .await?; + + let mut conn_cfg = ConnCfg::new(); + conn_cfg.set_startup_params(¶ms); + conn_cfg.password(password); + + // cut off first part of the sni domain + let sni = stream.get_ref().sni_hostname().unwrap(); + let dest = sni + .split_once('.').context("invalid sni")?.0 + .replace("--", "."); + + let destination = format!("{}.{}", dest, dest_suffix); + + info!("destination: {:?}", destination); + + conn_cfg.host(destination.as_str()); + + let mut conn = conn_cfg.connect() + .or_else(|e| stream.throw_error(e)) + .await?; + + cancel_map.with_session(|session| async { + proxy::proxy::prepare_client_connection(&conn, false, session, &mut stream).await?; + let (stream, read_buf) = stream.into_inner(); + conn.stream.write_all(&read_buf).await?; + let metrics_aux: MetricsAuxInfo = Default::default(); + proxy::proxy::proxy_pass(stream, conn.stream, &metrics_aux).await + }) + .await +} diff --git a/proxy/src/main.rs b/proxy/src/bin/proxy.rs similarity index 79% rename from proxy/src/main.rs rename to proxy/src/bin/proxy.rs index 1fd13c9f68..4c66845db6 100644 --- a/proxy/src/main.rs +++ b/proxy/src/bin/proxy.rs @@ -1,49 +1,22 @@ -//! Postgres protocol proxy/router. -//! -//! This service listens psql port and can check auth via external service -//! (control plane API in our case) and can create new databases and accounts -//! in somewhat transparent manner (again via communication with control plane API). +use proxy::auth; +use proxy::console; +use proxy::http; +use proxy::metrics; -mod auth; -mod cache; -mod cancellation; -mod compute; -mod config; -mod console; -mod error; -mod http; -mod logging; -mod metrics; -mod parse; -mod proxy; -mod sasl; -mod scram; -mod stream; -mod url; -mod waiters; - -use anyhow::{bail, Context}; +use anyhow::bail; use clap::{self, Arg}; -use config::ProxyConfig; -use futures::FutureExt; -use std::{borrow::Cow, future::Future, net::SocketAddr}; -use tokio::{net::TcpListener, task::JoinError}; +use proxy::config::{self, ProxyConfig}; +use std::{borrow::Cow, net::SocketAddr}; +use tokio::{net::TcpListener}; use tokio_util::sync::CancellationToken; -use tracing::{info, warn}; +use tracing::info; use utils::{project_git_version, sentry_init::init_sentry}; project_git_version!(GIT_VERSION); -/// Flattens `Result>` into `Result`. -async fn flatten_err( - f: impl Future, JoinError>>, -) -> anyhow::Result<()> { - f.map(|r| r.context("join error").and_then(|x| x)).await -} - #[tokio::main] async fn main() -> anyhow::Result<()> { - let _logging_guard = logging::init().await?; + let _logging_guard = proxy::logging::init().await?; let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); @@ -69,7 +42,7 @@ async fn main() -> anyhow::Result<()> { let proxy_listener = TcpListener::bind(proxy_address).await?; let cancellation_token = CancellationToken::new(); - let mut client_tasks = vec![tokio::spawn(proxy::task_main( + let mut client_tasks = vec![tokio::spawn(proxy::proxy::task_main( config, proxy_listener, cancellation_token.clone(), @@ -88,7 +61,7 @@ async fn main() -> anyhow::Result<()> { } let mut tasks = vec![ - tokio::spawn(handle_signals(cancellation_token)), + tokio::spawn(proxy::handle_signals(cancellation_token)), tokio::spawn(http::server::task_main(http_listener)), tokio::spawn(console::mgmt::task_main(mgmt_listener)), ]; @@ -97,8 +70,9 @@ async fn main() -> anyhow::Result<()> { tasks.push(tokio::spawn(metrics::task_main(metrics_config))); } - let tasks = futures::future::try_join_all(tasks.into_iter().map(flatten_err)); - let client_tasks = futures::future::try_join_all(client_tasks.into_iter().map(flatten_err)); + let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err)); + let client_tasks = + futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err)); tokio::select! { // We are only expecting an error from these forever tasks res = tasks => { res?; }, @@ -107,33 +81,6 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -/// Handle unix signals appropriately. -async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> { - use tokio::signal::unix::{signal, SignalKind}; - - let mut hangup = signal(SignalKind::hangup())?; - let mut interrupt = signal(SignalKind::interrupt())?; - let mut terminate = signal(SignalKind::terminate())?; - - loop { - tokio::select! { - // Hangup is commonly used for config reload. - _ = hangup.recv() => { - warn!("received SIGHUP; config reload is not supported"); - } - // Shut down the whole application. - _ = interrupt.recv() => { - warn!("received SIGINT, exiting immediately"); - bail!("interrupted"); - } - _ = terminate.recv() => { - warn!("received SIGTERM, shutting down once all existing connections have closed"); - token.cancel(); - } - } - } -} - /// ProxyConfig is created at proxy startup, and lives forever. fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> { let tls_config = match ( diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs new file mode 100644 index 0000000000..148ee67d90 --- /dev/null +++ b/proxy/src/lib.rs @@ -0,0 +1,57 @@ +use anyhow::{bail, Context}; +use futures::{Future, FutureExt}; +use tokio::task::JoinError; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +pub mod auth; +pub mod cache; +pub mod cancellation; +pub mod compute; +pub mod config; +pub mod console; +pub mod error; +pub mod http; +pub mod logging; +pub mod metrics; +pub mod parse; +pub mod proxy; +pub mod sasl; +pub mod scram; +pub mod stream; +pub mod url; +pub mod waiters; + +/// Handle unix signals appropriately. +pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> { + use tokio::signal::unix::{signal, SignalKind}; + + let mut hangup = signal(SignalKind::hangup())?; + let mut interrupt = signal(SignalKind::interrupt())?; + let mut terminate = signal(SignalKind::terminate())?; + + loop { + tokio::select! { + // Hangup is commonly used for config reload. + _ = hangup.recv() => { + warn!("received SIGHUP; config reload is not supported"); + } + // Shut down the whole application. + _ = interrupt.recv() => { + warn!("received SIGINT, exiting immediately"); + bail!("interrupted"); + } + _ = terminate.recv() => { + warn!("received SIGTERM, shutting down once all existing connections have closed"); + token.cancel(); + } + } + } +} + +/// Flattens `Result>` into `Result`. +pub async fn flatten_err( + f: impl Future, JoinError>>, +) -> anyhow::Result<()> { + f.map(|r| r.context("join error").and_then(|x| x)).await +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1169d76160..e20c31e74c 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -5,7 +5,7 @@ use crate::{ auth::{self, backend::AuthSuccess}, cancellation::{self, CancelMap}, compute::{self, PostgresConnection}, - config::{ProxyConfig, TlsConfig}, + config::ProxyConfig, console::{self, messages::MetricsAuxInfo}, error::io_error, stream::{PqStream, Stream}, @@ -174,7 +174,7 @@ async fn handle_client( NUM_CONNECTIONS_CLOSED_COUNTER.inc(); } - let tls = config.tls_config.as_ref(); + let tls = config.tls_config.as_ref().map(|t| t.to_server_config()); let do_handshake = handshake(stream, tls, cancel_map); let (mut stream, params) = match do_handshake.await? { Some(x) => x, @@ -184,7 +184,10 @@ async fn handle_client( // Extract credentials which we're going to use for auth. let creds = { let sni = stream.get_ref().sni_hostname(); - let common_names = tls.and_then(|tls| tls.common_names.clone()); + let common_names = config + .tls_config + .as_ref() + .and_then(|tls| tls.common_names.clone()); let result = config .auth_backend .as_ref() @@ -205,13 +208,14 @@ async fn handle_client( /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -async fn handshake( +pub async fn handshake( stream: S, - mut tls: Option<&TlsConfig>, + tls: Option>, cancel_map: &CancelMap, ) -> anyhow::Result>, StartupMessageParams)>> { // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); + let mut tls_upgraded = false; let mut stream = PqStream::new(Stream::from_raw(stream)); loop { @@ -226,8 +230,9 @@ async fn handshake( // We can't perform TLS handshake without a config let enc = tls.is_some(); + stream.write_message(&Be::EncryptionResponse(enc)).await?; - if let Some(tls) = tls.take() { + if let Some(tls) = tls.clone() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. @@ -241,7 +246,8 @@ async fn handshake( if !read_buf.is_empty() { bail!("data is sent before server replied with EncryptionResponse"); } - stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?); + stream = PqStream::new(raw.upgrade(tls).await?); + tls_upgraded = true; } } _ => bail!(ERR_PROTO_VIOLATION), @@ -256,9 +262,8 @@ async fn handshake( _ => bail!(ERR_PROTO_VIOLATION), }, StartupMessage { params, .. } => { - // Check that the config has been consumed during upgrade - // OR we didn't provide it at all (for dev purposes). - if tls.is_some() { + // Check that tls was actually upgraded + if !tls_upgraded { stream.throw_error_str(ERR_INSECURE_CONNECTION).await?; } @@ -340,7 +345,7 @@ async fn connect_to_compute( /// Finish client connection initialization: confirm auth success, send params, etc. #[tracing::instrument(skip_all)] -async fn prepare_client_connection( +pub async fn prepare_client_connection( node: &compute::PostgresConnection, reported_auth_ok: bool, session: cancellation::Session<'_>, @@ -378,7 +383,7 @@ async fn prepare_client_connection( /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] -async fn proxy_pass( +pub async fn proxy_pass( client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, aux: &MetricsAuxInfo,