mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-06 13:02:55 +00:00
Add SNI-based proxy router
In order to not to create NodePorts for each compute we can setup services that accept connections on wildcard domains and then use information from domain name to route connection to some internal service. There are ready solutions for HTTPS and TLS connections but postgresql protocol uses opportunistic TLS and we haven't found any ready solutions. This patch introduces `pg_sni_router` which routes connections to `aaa--bbb--123.external.domain` to `aaa.bbb.123.internal.domain`. In the long run we can avoid console -> compute psql communications, but now this router seems to be the easier way forward.
This commit is contained in:
226
proxy/src/bin/pg_sni_router.rs
Normal file
226
proxy/src/bin/pg_sni_router.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::{net::TcpListener, io::AsyncWriteExt};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use clap::{self, Arg};
|
||||
use futures::TryFutureExt;
|
||||
use proxy::{cancellation::CancelMap, auth::{AuthFlow, self}, compute::ConnCfg, console::messages::MetricsAuxInfo};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
fn cli() -> clap::Command {
|
||||
clap::Command::new("Neon proxy/router")
|
||||
.disable_help_flag(true)
|
||||
.version(GIT_VERSION)
|
||||
.arg(
|
||||
Arg::new("listen")
|
||||
.short('l')
|
||||
.long("listen")
|
||||
.help("listen for incoming client connections on ip:port")
|
||||
.default_value("127.0.0.1:4432"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-key")
|
||||
.short('k')
|
||||
.long("tls-key")
|
||||
.help("path to TLS key for client postgres connections"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-cert")
|
||||
.short('c')
|
||||
.long("tls-cert")
|
||||
.help("path to TLS cert for client postgres connections"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("dest")
|
||||
.short('d')
|
||||
.long("destination")
|
||||
.help("append this domain zone to the SNI hostname to get the destination address"),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let _logging_guard = proxy::logging::init().await?;
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
let args = cli().get_matches();
|
||||
|
||||
// Configure TLS
|
||||
let tls_config: Arc<rustls::ServerConfig> = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context(format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
};
|
||||
|
||||
rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into()
|
||||
}
|
||||
_ => bail!("tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
|
||||
|
||||
// Start listening for incoming client connections
|
||||
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let tasks = vec![
|
||||
tokio::spawn(proxy::handle_signals(cancellation_token.clone())),
|
||||
tokio::spawn(task_main(
|
||||
Arc::new(destination),
|
||||
tls_config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
)),
|
||||
];
|
||||
|
||||
let _tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err)).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
}
|
||||
|
||||
// When set for the server socket, the keepalive setting
|
||||
// will be inherited by all accepted client sockets.
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let mut connections = tokio::task::JoinSet::new();
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
accept_result = listener.accept() => {
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
info!("accepted postgres client connection from {peer_addr}");
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
let tls_config = Arc::clone(&tls_config);
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
|
||||
connections.spawn(
|
||||
async move {
|
||||
info!("spawned a task for {peer_addr}");
|
||||
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(dest_suffix, tls_config, &cancel_map, session_id, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
}),
|
||||
);
|
||||
}
|
||||
_ = cancellation_token.cancelled() => {
|
||||
drop(listener);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Drain connections
|
||||
while let Some(res) = connections.join_next().await {
|
||||
if let Err(e) = res {
|
||||
if !e.is_panic() && !e.is_cancelled() {
|
||||
warn!("unexpected error from joined connection task: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(fields(session_id = ?session_id), skip_all)]
|
||||
async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls: Arc<rustls::ServerConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let do_handshake = proxy::proxy::handshake(stream, Some(tls), cancel_map);
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
let password = AuthFlow::new(&mut stream)
|
||||
.begin(auth::CleartextPassword)
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
let mut conn_cfg = ConnCfg::new();
|
||||
conn_cfg.set_startup_params(¶ms);
|
||||
conn_cfg.password(password);
|
||||
|
||||
// cut off first part of the sni domain
|
||||
let sni = stream.get_ref().sni_hostname().unwrap();
|
||||
let dest = sni
|
||||
.split_once('.').context("invalid sni")?.0
|
||||
.replace("--", ".");
|
||||
|
||||
let destination = format!("{}.{}", dest, dest_suffix);
|
||||
|
||||
info!("destination: {:?}", destination);
|
||||
|
||||
conn_cfg.host(destination.as_str());
|
||||
|
||||
let mut conn = conn_cfg.connect()
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
cancel_map.with_session(|session| async {
|
||||
proxy::proxy::prepare_client_connection(&conn, false, session, &mut stream).await?;
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
conn.stream.write_all(&read_buf).await?;
|
||||
let metrics_aux: MetricsAuxInfo = Default::default();
|
||||
proxy::proxy::proxy_pass(stream, conn.stream, &metrics_aux).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
@@ -1,49 +1,22 @@
|
||||
//! Postgres protocol proxy/router.
|
||||
//!
|
||||
//! This service listens psql port and can check auth via external service
|
||||
//! (control plane API in our case) and can create new databases and accounts
|
||||
//! in somewhat transparent manner (again via communication with control plane API).
|
||||
use proxy::auth;
|
||||
use proxy::console;
|
||||
use proxy::http;
|
||||
use proxy::metrics;
|
||||
|
||||
mod auth;
|
||||
mod cache;
|
||||
mod cancellation;
|
||||
mod compute;
|
||||
mod config;
|
||||
mod console;
|
||||
mod error;
|
||||
mod http;
|
||||
mod logging;
|
||||
mod metrics;
|
||||
mod parse;
|
||||
mod proxy;
|
||||
mod sasl;
|
||||
mod scram;
|
||||
mod stream;
|
||||
mod url;
|
||||
mod waiters;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use anyhow::bail;
|
||||
use clap::{self, Arg};
|
||||
use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::{borrow::Cow, future::Future, net::SocketAddr};
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use proxy::config::{self, ProxyConfig};
|
||||
use std::{borrow::Cow, net::SocketAddr};
|
||||
use tokio::{net::TcpListener};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
use tracing::info;
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
async fn flatten_err(
|
||||
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
|
||||
) -> anyhow::Result<()> {
|
||||
f.map(|r| r.context("join error").and_then(|x| x)).await
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let _logging_guard = logging::init().await?;
|
||||
let _logging_guard = proxy::logging::init().await?;
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
@@ -69,7 +42,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let mut client_tasks = vec![tokio::spawn(proxy::task_main(
|
||||
let mut client_tasks = vec![tokio::spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
@@ -88,7 +61,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
let mut tasks = vec![
|
||||
tokio::spawn(handle_signals(cancellation_token)),
|
||||
tokio::spawn(proxy::handle_signals(cancellation_token)),
|
||||
tokio::spawn(http::server::task_main(http_listener)),
|
||||
tokio::spawn(console::mgmt::task_main(mgmt_listener)),
|
||||
];
|
||||
@@ -97,8 +70,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
|
||||
}
|
||||
|
||||
let tasks = futures::future::try_join_all(tasks.into_iter().map(flatten_err));
|
||||
let client_tasks = futures::future::try_join_all(client_tasks.into_iter().map(flatten_err));
|
||||
let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err));
|
||||
let client_tasks =
|
||||
futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
|
||||
tokio::select! {
|
||||
// We are only expecting an error from these forever tasks
|
||||
res = tasks => { res?; },
|
||||
@@ -107,33 +81,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
let mut hangup = signal(SignalKind::hangup())?;
|
||||
let mut interrupt = signal(SignalKind::interrupt())?;
|
||||
let mut terminate = signal(SignalKind::terminate())?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Hangup is commonly used for config reload.
|
||||
_ = hangup.recv() => {
|
||||
warn!("received SIGHUP; config reload is not supported");
|
||||
}
|
||||
// Shut down the whole application.
|
||||
_ = interrupt.recv() => {
|
||||
warn!("received SIGINT, exiting immediately");
|
||||
bail!("interrupted");
|
||||
}
|
||||
_ = terminate.recv() => {
|
||||
warn!("received SIGTERM, shutting down once all existing connections have closed");
|
||||
token.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let tls_config = match (
|
||||
57
proxy/src/lib.rs
Normal file
57
proxy/src/lib.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use anyhow::{bail, Context};
|
||||
use futures::{Future, FutureExt};
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
pub mod auth;
|
||||
pub mod cache;
|
||||
pub mod cancellation;
|
||||
pub mod compute;
|
||||
pub mod config;
|
||||
pub mod console;
|
||||
pub mod error;
|
||||
pub mod http;
|
||||
pub mod logging;
|
||||
pub mod metrics;
|
||||
pub mod parse;
|
||||
pub mod proxy;
|
||||
pub mod sasl;
|
||||
pub mod scram;
|
||||
pub mod stream;
|
||||
pub mod url;
|
||||
pub mod waiters;
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
let mut hangup = signal(SignalKind::hangup())?;
|
||||
let mut interrupt = signal(SignalKind::interrupt())?;
|
||||
let mut terminate = signal(SignalKind::terminate())?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Hangup is commonly used for config reload.
|
||||
_ = hangup.recv() => {
|
||||
warn!("received SIGHUP; config reload is not supported");
|
||||
}
|
||||
// Shut down the whole application.
|
||||
_ = interrupt.recv() => {
|
||||
warn!("received SIGINT, exiting immediately");
|
||||
bail!("interrupted");
|
||||
}
|
||||
_ = terminate.recv() => {
|
||||
warn!("received SIGTERM, shutting down once all existing connections have closed");
|
||||
token.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
pub async fn flatten_err(
|
||||
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
|
||||
) -> anyhow::Result<()> {
|
||||
f.map(|r| r.context("join error").and_then(|x| x)).await
|
||||
}
|
||||
@@ -5,7 +5,7 @@ use crate::{
|
||||
auth::{self, backend::AuthSuccess},
|
||||
cancellation::{self, CancelMap},
|
||||
compute::{self, PostgresConnection},
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
config::ProxyConfig,
|
||||
console::{self, messages::MetricsAuxInfo},
|
||||
error::io_error,
|
||||
stream::{PqStream, Stream},
|
||||
@@ -174,7 +174,7 @@ async fn handle_client(
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
let tls = config.tls_config.as_ref().map(|t| t.to_server_config());
|
||||
let do_handshake = handshake(stream, tls, cancel_map);
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
@@ -184,7 +184,10 @@ async fn handle_client(
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let sni = stream.get_ref().sni_hostname();
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
let common_names = config
|
||||
.tls_config
|
||||
.as_ref()
|
||||
.and_then(|tls| tls.common_names.clone());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
@@ -205,13 +208,14 @@ async fn handle_client(
|
||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
tls: Option<Arc<rustls::ServerConfig>>,
|
||||
cancel_map: &CancelMap,
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
let mut tls_upgraded = false;
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
@@ -226,8 +230,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
// We can't perform TLS handshake without a config
|
||||
let enc = tls.is_some();
|
||||
|
||||
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
||||
if let Some(tls) = tls.take() {
|
||||
if let Some(tls) = tls.clone() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
@@ -241,7 +246,8 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
|
||||
stream = PqStream::new(raw.upgrade(tls).await?);
|
||||
tls_upgraded = true;
|
||||
}
|
||||
}
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
@@ -256,9 +262,8 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
},
|
||||
StartupMessage { params, .. } => {
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
// Check that tls was actually upgraded
|
||||
if !tls_upgraded {
|
||||
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
|
||||
}
|
||||
|
||||
@@ -340,7 +345,7 @@ async fn connect_to_compute(
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn prepare_client_connection(
|
||||
pub async fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
reported_auth_ok: bool,
|
||||
session: cancellation::Session<'_>,
|
||||
@@ -378,7 +383,7 @@ async fn prepare_client_connection(
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn proxy_pass(
|
||||
pub async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: &MetricsAuxInfo,
|
||||
|
||||
Reference in New Issue
Block a user