mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-04 12:02:55 +00:00
completely rewrite pq_proto (#12085)
libs/pqproto is designed for safekeeper/pageserver with maximum throughput. proxy only needs it for handshakes/authentication where throughput is not a concern but memory efficiency is. For this reason, we switch to using read_exact and only allocating as much memory as we need to. All reads return a `&'a [u8]` instead of a `Bytes` because accidental sharing of bytes can cause fragmentation. Returning the reference enforces all callers only hold onto the bytes they absolutely need. For example, before this change, `pqproto` was allocating 8KiB for the initial read `BytesMut`, and proxy was holding the `Bytes` in the `StartupMessageParams` for the entire connection through to passthrough.
This commit is contained in:
@@ -4,8 +4,9 @@
|
||||
//! This allows connecting to pods/services running in the same Kubernetes cluster from
|
||||
//! the outside. Similar to an ingress controller for HTTPS.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, anyhow, bail, ensure};
|
||||
use clap::Arg;
|
||||
@@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, error, info};
|
||||
use utils::project_git_version;
|
||||
@@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled};
|
||||
use crate::proxy::{
|
||||
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
|
||||
};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
@@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.parse()?;
|
||||
|
||||
// Configure TLS
|
||||
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
|
||||
let tls_config = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
@@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
))
|
||||
@@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest,
|
||||
tls_config,
|
||||
Some(compute_tls_config),
|
||||
tls_server_end_point,
|
||||
proxy_listener_compute_tls,
|
||||
cancellation_token.clone(),
|
||||
))
|
||||
@@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
pub(super) fn parse_tls(
|
||||
key_path: &Path,
|
||||
cert_path: &Path,
|
||||
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
|
||||
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
|
||||
@@ -187,10 +189,6 @@ pub(super) fn parse_tls(
|
||||
})?
|
||||
};
|
||||
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config =
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
@@ -199,14 +197,13 @@ pub(super) fn parse_tls(
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into();
|
||||
|
||||
Ok((tls_config, tls_server_end_point))
|
||||
Ok(tls_config)
|
||||
}
|
||||
|
||||
pub(super) async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -242,15 +239,7 @@ pub(super) async fn task_main(
|
||||
crate::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
handle_client(
|
||||
ctx,
|
||||
dest_suffix,
|
||||
tls_config,
|
||||
compute_tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
)
|
||||
.await
|
||||
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -269,55 +258,26 @@ pub(super) async fn task_main(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use pq_proto::FeStartupPacket::SslRequest;
|
||||
|
||||
) -> anyhow::Result<TlsStream<S>> {
|
||||
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
|
||||
match msg {
|
||||
SslRequest { direct: false } => {
|
||||
stream
|
||||
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
|
||||
.await?;
|
||||
FeStartupPacket::SslRequest { direct: None } => {
|
||||
let raw = stream.accept_tls().await?;
|
||||
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empty.
|
||||
// However, you could imagine pipelining of postgres
|
||||
// SSLRequest + TLS ClientHello in one hunk similar to
|
||||
// pipelining in our node js driver. We should probably
|
||||
// support that by chaining read_buf with the stream.
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
})
|
||||
Ok(raw
|
||||
.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?)
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
?unexpected,
|
||||
"unexpected startup packet, rejecting connection"
|
||||
);
|
||||
stream
|
||||
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None)
|
||||
.await?
|
||||
Err(stream.throw_error(TlsRequired, None).await)?
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,15 +287,18 @@ async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
|
||||
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
|
||||
let sni = tls_stream
|
||||
.get_ref()
|
||||
.1
|
||||
.server_name()
|
||||
.ok_or(anyhow!("SNI missing"))?;
|
||||
let dest: Vec<&str> = sni
|
||||
.split_once('.')
|
||||
.context("invalid SNI")?
|
||||
|
||||
@@ -476,8 +476,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let key_path = args.tls_key.expect("already asserted it is set");
|
||||
let cert_path = args.tls_cert.expect("already asserted it is set");
|
||||
|
||||
let (tls_config, tls_server_end_point) =
|
||||
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
|
||||
let dest = Arc::new(dest);
|
||||
|
||||
@@ -485,7 +484,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
listen,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
@@ -494,7 +492,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest,
|
||||
tls_config,
|
||||
Some(config.connect_to_compute.tls.clone()),
|
||||
tls_server_end_point,
|
||||
listen_tls,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
|
||||
Reference in New Issue
Block a user