mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-17 21:20:37 +00:00
## Problem Password hashing for sql-over-http takes up a lot of CPU. Perhaps we can get away with temporarily caching some steps so we need fewer rounds, which will save some CPU time. ## Summary of changes The output of pbkdf2 is the XOR of the outputs of each iteration round, e.g. `U1 ^ U2 ^ ... ^ U15 ^ U16 ^ U17 ^ ... ^ Un`. We cache the suffix of the expression `U16 ^ U17 ^ ... ^ Un`. To compute the result from the cached suffix, we only need to compute the prefix `U1 ^ U2 ^ ... ^ U15`. The suffix by itself is useless, which prevents its use in brute-force attacks should this cached memory leak. We are also caching the full 4096 round hash in memory, which can be used for brute-force attacks, where this suffix could be used to speed it up. My hope/expectation is that since these will be in different allocations, it makes any such memory exploitation much, much harder. Since the full hash cache might be invalidated while the suffix is cached, I'm storing the timestamp of the computation as a way to identify the match. I also added `zeroize()` to clear the sensitive state from the stack/heap. For the most security-conscious customers, we hope to roll out OIDC soon, so they can disable passwords entirely. --- The numbers for the threadpool were pretty random, but according to our busiest region for sql-over-http, we only see about 150 unique endpoints every minute. So storing ~100 of the most common endpoints for that minute should be the vast majority of requests. 1 minute was chosen so we don't keep data in memory for too long.
383 lines
12 KiB
Rust
383 lines
12 KiB
Rust
//! A stand-alone program that routes connections, e.g. from
|
|
//! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
|
|
//!
|
|
//! This allows connecting to pods/services running in the same Kubernetes cluster from
|
|
//! the outside. Similar to an ingress controller for HTTPS.
|
|
|
|
use std::io;
|
|
use std::net::SocketAddr;
|
|
use std::path::Path;
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::{Context, anyhow, bail, ensure};
|
|
use clap::Arg;
|
|
use futures::future::Either;
|
|
use futures::{FutureExt, TryFutureExt};
|
|
use itertools::Itertools;
|
|
use rustls::crypto::ring;
|
|
use rustls::pki_types::{DnsName, PrivateKeyDer};
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
use tokio::net::TcpListener;
|
|
use tokio_rustls::TlsConnector;
|
|
use tokio_rustls::server::TlsStream;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{Instrument, error, info};
|
|
use utils::project_git_version;
|
|
use utils::sentry_init::init_sentry;
|
|
|
|
use crate::context::RequestContext;
|
|
use crate::metrics::{Metrics, ServiceInfo};
|
|
use crate::pglb::TlsRequired;
|
|
use crate::pqproto::FeStartupPacket;
|
|
use crate::protocol2::ConnectionInfo;
|
|
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
|
|
use crate::stream::{PqStream, Stream};
|
|
use crate::util::run_until_cancelled;
|
|
|
|
// Declares `GIT_VERSION`, which this file passes to clap as the `--version` string and
// to `init_sentry`. NOTE(review): the value is presumably derived from the git revision
// at build time by the `utils` macro — confirm against `utils::project_git_version`.
project_git_version!(GIT_VERSION);
|
|
|
|
fn cli() -> clap::Command {
|
|
clap::Command::new("Neon proxy/router")
|
|
.version(GIT_VERSION)
|
|
.arg(
|
|
Arg::new("listen")
|
|
.short('l')
|
|
.long("listen")
|
|
.help("listen for incoming client connections on ip:port")
|
|
.default_value("127.0.0.1:4432"),
|
|
)
|
|
.arg(
|
|
Arg::new("listen-tls")
|
|
.long("listen-tls")
|
|
.help("listen for incoming client connections on ip:port, requiring TLS to compute")
|
|
.default_value("127.0.0.1:4433"),
|
|
)
|
|
.arg(
|
|
Arg::new("tls-key")
|
|
.short('k')
|
|
.long("tls-key")
|
|
.help("path to TLS key for client postgres connections")
|
|
.required(true),
|
|
)
|
|
.arg(
|
|
Arg::new("tls-cert")
|
|
.short('c')
|
|
.long("tls-cert")
|
|
.help("path to TLS cert for client postgres connections")
|
|
.required(true),
|
|
)
|
|
.arg(
|
|
Arg::new("dest")
|
|
.short('d')
|
|
.long("destination")
|
|
.help("append this domain zone to the SNI hostname to get the destination address")
|
|
.required(true),
|
|
)
|
|
}
|
|
|
|
pub async fn run() -> anyhow::Result<()> {
|
|
let _logging_guard = crate::logging::init()?;
|
|
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
|
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
|
|
|
let args = cli().get_matches();
|
|
let destination: String = args
|
|
.get_one::<String>("dest")
|
|
.expect("string argument defined")
|
|
.parse()?;
|
|
|
|
// Configure TLS
|
|
let tls_config = match (
|
|
args.get_one::<String>("tls-key"),
|
|
args.get_one::<String>("tls-cert"),
|
|
) {
|
|
(Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?,
|
|
_ => bail!("tls-key and tls-cert must be specified"),
|
|
};
|
|
|
|
let compute_tls_config =
|
|
Arc::new(crate::tls::client_config::compute_client_config_with_root_certs()?);
|
|
|
|
// Start listening for incoming client connections
|
|
let proxy_address: SocketAddr = args
|
|
.get_one::<String>("listen")
|
|
.expect("listen argument defined")
|
|
.parse()?;
|
|
let proxy_address_compute_tls: SocketAddr = args
|
|
.get_one::<String>("listen-tls")
|
|
.expect("listen-tls argument defined")
|
|
.parse()?;
|
|
|
|
info!("Starting sni router on {proxy_address}");
|
|
info!("Starting sni router on {proxy_address_compute_tls}");
|
|
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
|
let proxy_listener_compute_tls = TcpListener::bind(proxy_address_compute_tls).await?;
|
|
|
|
let cancellation_token = CancellationToken::new();
|
|
let dest = Arc::new(destination);
|
|
|
|
let main = tokio::spawn(task_main(
|
|
dest.clone(),
|
|
tls_config.clone(),
|
|
None,
|
|
proxy_listener,
|
|
cancellation_token.clone(),
|
|
))
|
|
.map(crate::error::flatten_err);
|
|
|
|
let main_tls = tokio::spawn(task_main(
|
|
dest,
|
|
tls_config,
|
|
Some(compute_tls_config),
|
|
proxy_listener_compute_tls,
|
|
cancellation_token.clone(),
|
|
))
|
|
.map(crate::error::flatten_err);
|
|
|
|
Metrics::get()
|
|
.service
|
|
.info
|
|
.set_label(ServiceInfo::running());
|
|
|
|
let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
|
|
|
|
// the signal task cant ever succeed.
|
|
// the main task can error, or can succeed on cancellation.
|
|
// we want to immediately exit on either of these cases
|
|
let main = futures::future::try_join(main, main_tls);
|
|
let signal = match futures::future::select(signals_task, main).await {
|
|
Either::Left((res, _)) => crate::error::flatten_err(res)?,
|
|
Either::Right((res, _)) => {
|
|
res?;
|
|
return Ok(());
|
|
}
|
|
};
|
|
|
|
// maintenance tasks return `Infallible` success values, this is an impossible value
|
|
// so this match statically ensures that there are no possibilities for that value
|
|
match signal {}
|
|
}
|
|
|
|
pub(super) fn parse_tls(
|
|
key_path: &Path,
|
|
cert_path: &Path,
|
|
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
|
|
let key = {
|
|
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
|
|
|
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
|
|
|
|
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
|
PrivateKeyDer::Pkcs8(
|
|
keys.pop()
|
|
.expect("keys should not be empty")
|
|
.context(format!(
|
|
"Failed to read TLS keys at '{}'",
|
|
key_path.display()
|
|
))?,
|
|
)
|
|
};
|
|
|
|
let cert_chain_bytes = std::fs::read(cert_path).context(format!(
|
|
"Failed to read TLS cert file at '{}.'",
|
|
cert_path.display()
|
|
))?;
|
|
|
|
let cert_chain: Vec<_> = {
|
|
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
|
.try_collect()
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to read TLS certificate chain from bytes from file at '{}'.",
|
|
cert_path.display()
|
|
)
|
|
})?
|
|
};
|
|
|
|
let tls_config =
|
|
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
|
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
|
.context("ring should support TLS1.2 and TLS1.3")?
|
|
.with_no_client_auth()
|
|
.with_single_cert(cert_chain, key)?
|
|
.into();
|
|
|
|
Ok(tls_config)
|
|
}
|
|
|
|
pub(super) async fn task_main(
|
|
dest_suffix: Arc<String>,
|
|
tls_config: Arc<rustls::ServerConfig>,
|
|
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
|
listener: tokio::net::TcpListener,
|
|
cancellation_token: CancellationToken,
|
|
) -> anyhow::Result<()> {
|
|
// When set for the server socket, the keepalive setting
|
|
// will be inherited by all accepted client sockets.
|
|
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
|
|
|
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
|
|
|
while let Some(accept_result) =
|
|
run_until_cancelled(listener.accept(), &cancellation_token).await
|
|
{
|
|
let (socket, peer_addr) = accept_result?;
|
|
|
|
let session_id = uuid::Uuid::new_v4();
|
|
let tls_config = Arc::clone(&tls_config);
|
|
let dest_suffix = Arc::clone(&dest_suffix);
|
|
let compute_tls_config = compute_tls_config.clone();
|
|
|
|
connections.spawn(
|
|
async move {
|
|
socket
|
|
.set_nodelay(true)
|
|
.context("failed to set socket option")?;
|
|
|
|
let ctx = RequestContext::new(
|
|
session_id,
|
|
ConnectionInfo {
|
|
addr: peer_addr,
|
|
extra: None,
|
|
},
|
|
crate::metrics::Protocol::SniRouter,
|
|
);
|
|
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
|
|
}
|
|
.unwrap_or_else(|e| {
|
|
if let Some(FirstMessage(io_error)) = e.downcast_ref() {
|
|
// this is noisy. if we get EOF on the very first message that's likely
|
|
// just NLB doing a healthcheck.
|
|
if io_error.kind() == io::ErrorKind::UnexpectedEof {
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Acknowledge that the task has finished with an error.
|
|
error!("per-client task finished with an error: {e:#}");
|
|
})
|
|
.instrument(tracing::info_span!("handle_client", ?session_id)),
|
|
);
|
|
}
|
|
|
|
connections.close();
|
|
drop(listener);
|
|
|
|
connections.wait().await;
|
|
|
|
info!("all client connections have finished");
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
#[error(transparent)]
|
|
struct FirstMessage(io::Error);
|
|
|
|
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
|
ctx: &RequestContext,
|
|
raw_stream: S,
|
|
tls_config: Arc<rustls::ServerConfig>,
|
|
) -> anyhow::Result<TlsStream<S>> {
|
|
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream))
|
|
.await
|
|
.map_err(FirstMessage)?;
|
|
|
|
match msg {
|
|
FeStartupPacket::SslRequest { direct: None } => {
|
|
let raw = stream.accept_tls().await?;
|
|
|
|
Ok(raw
|
|
.upgrade(tls_config, !ctx.has_private_peer_addr())
|
|
.await?)
|
|
}
|
|
unexpected => {
|
|
info!(
|
|
?unexpected,
|
|
"unexpected startup packet, rejecting connection"
|
|
);
|
|
Err(stream.throw_error(TlsRequired, None).await)?
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_client(
|
|
ctx: RequestContext,
|
|
dest_suffix: Arc<String>,
|
|
tls_config: Arc<rustls::ServerConfig>,
|
|
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
|
stream: impl AsyncRead + AsyncWrite + Unpin,
|
|
) -> anyhow::Result<()> {
|
|
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
|
|
|
|
// Cut off first part of the SNI domain
|
|
// We receive required destination details in the format of
|
|
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
|
|
let sni = tls_stream
|
|
.get_ref()
|
|
.1
|
|
.server_name()
|
|
.ok_or(anyhow!("SNI missing"))?;
|
|
let dest: Vec<&str> = sni
|
|
.split_once('.')
|
|
.context("invalid SNI")?
|
|
.0
|
|
.splitn(3, "--")
|
|
.collect();
|
|
let port = dest[2].parse::<u16>().context("invalid port")?;
|
|
let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
|
|
|
|
info!("destination: {}", destination);
|
|
|
|
let mut client = tokio::net::TcpStream::connect(&destination).await?;
|
|
|
|
let client = if let Some(compute_tls_config) = compute_tls_config {
|
|
info!("upgrading TLS");
|
|
|
|
// send SslRequest
|
|
client
|
|
.write_all(b"\x00\x00\x00\x08\x04\xd2\x16\x2f")
|
|
.await?;
|
|
|
|
// wait for S/N respons
|
|
let mut resp = b'N';
|
|
client.read_exact(std::slice::from_mut(&mut resp)).await?;
|
|
|
|
// error if not S
|
|
ensure!(resp == b'S', "compute refused TLS");
|
|
|
|
// upgrade to TLS.
|
|
let domain = DnsName::try_from(destination)?;
|
|
let domain = rustls::pki_types::ServerName::DnsName(domain);
|
|
let client = TlsConnector::from(compute_tls_config)
|
|
.connect(domain, client)
|
|
.await?;
|
|
Connection::Tls(client)
|
|
} else {
|
|
Connection::Raw(client)
|
|
};
|
|
|
|
// doesn't yet matter as pg-sni-router doesn't report analytics logs
|
|
ctx.set_success();
|
|
ctx.log_connect();
|
|
|
|
// Starting from here we only proxy the client's traffic.
|
|
info!("performing the proxy pass...");
|
|
|
|
let res = match client {
|
|
Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
|
Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
|
};
|
|
|
|
match res {
|
|
Ok(_) => Ok(()),
|
|
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
|
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::large_enum_variant)]
|
|
enum Connection {
|
|
Raw(tokio::net::TcpStream),
|
|
Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
|
|
}
|