Files
neon/proxy/src/tls/postgres_rustls.rs
Conrad Ludgate 781bf4945d proxy: optimise future layout allocations (#12104)
A smaller version of #12066 that is somewhat easier to review.

Now that I've been using https://crates.io/crates/top-type-sizes I've
found a lot more of the low hanging fruit that can be tweaked to reduce
the memory usage.

Some context for the optimisations:

Rust's stack allocation in futures is quite naive. Stack variables, even
if moved, often still end up taking space in the future. Rearranging the
order in which variables are defined, and properly scoping them can go a
long way.

`async fn` and `async move {}` have a consequence that they always
duplicate the "upvars" (aka captures). All captures are permanently
allocated in the future, even if moved. We can be mindful when writing
futures to only capture as little as possible.

TlsStream is massive. Needs boxing so it doesn't contribute to the above
issue.

## Measurements from `top-type-sizes`:

### Before

```
10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8
6120 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```

### After

```
4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}}
4704 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```
2025-06-02 16:13:30 +00:00

159 lines
4.5 KiB
Rust

use std::convert::TryFrom;
use std::sync::Arc;
use postgres_client::tls::MakeTlsConnect;
use rustls::ClientConfig;
use rustls::pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
mod private {
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use postgres_client::tls::{ChannelBinding, TlsConnect};
use rustls::pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::TlsConnector;
use tokio_rustls::client::TlsStream;
use crate::tls::TlsServerEndPoint;
/// Future returned by `RustlsConnect`'s `TlsConnect::connect` implementation.
///
/// It drives the wrapped `tokio_rustls::Connect` future and, on success,
/// yields the resulting TLS stream wrapped in a `RustlsStream`.
pub struct TlsConnectFuture<S> {
/// The in-progress `tokio_rustls` connect future for the stream `S`.
inner: tokio_rustls::Connect<S>,
}
/// Polls the wrapped `tokio_rustls` connect future and boxes the resulting
/// TLS stream into a `RustlsStream` once it is ready.
impl<S> Future for TlsConnectFuture<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Output = io::Result<RustlsStream<S>>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let handshake = Pin::new(&mut self.inner);
        match handshake.poll(cx) {
            // Success: move the TLS stream onto the heap and wrap it.
            Poll::Ready(Ok(tls)) => Poll::Ready(Ok(RustlsStream(Box::new(tls)))),
            // Failure: hand the I/O error through untouched.
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Newtype around `RustlsConnectData` that implements `TlsConnect`; consumed
/// by `connect` to start a TLS session over a given stream.
pub struct RustlsConnect(pub RustlsConnectData);
/// The pieces needed to open one TLS connection: the server name and the
/// connector that performs the handshake.
pub struct RustlsConnectData {
/// Server name handed to `TlsConnector::connect` together with the stream.
pub hostname: ServerName<'static>,
/// Connector on which `connect` is called to start the TLS session.
pub connector: TlsConnector,
}
impl<S> TlsConnect<S> for RustlsConnect
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = RustlsStream<S>;
type Error = io::Error;
type Future = TlsConnectFuture<S>;
fn connect(self, stream: S) -> Self::Future {
TlsConnectFuture {
inner: self.0.connector.connect(self.0.hostname, stream),
}
}
}
/// Client-side TLS stream over `S`.
///
/// The inner `TlsStream` is kept behind a `Box`: it is a large type, and
/// boxing it keeps futures that hold a `RustlsStream` small.
pub struct RustlsStream<S>(Box<TlsStream<S>>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
let (_, session) = self.0.get_ref();
match session.peer_certificates() {
Some([cert, ..]) => TlsServerEndPoint::new(cert)
.ok()
.and_then(|cb| match cb {
TlsServerEndPoint::Sha256(hash) => Some(hash),
TlsServerEndPoint::Undefined => None,
})
.map_or_else(ChannelBinding::none, |hash| {
ChannelBinding::tls_server_end_point(hash.to_vec())
}),
_ => ChannelBinding::none(),
}
}
}
/// Read side: delegates to the boxed TLS stream.
impl<S> AsyncRead for RustlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<tokio::io::Result<()>> {
        // `RustlsStream` only holds a `Box`, so it is `Unpin` and the pin can
        // be dropped to reach the inner TLS stream.
        let tls = &mut *self.get_mut().0;
        Pin::new(tls).poll_read(cx, buf)
    }
}
impl<S> AsyncWrite for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<tokio::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<tokio::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<tokio::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
/// Cloning only clones the `Arc` handle to the shared `ClientConfig`.
#[derive(Clone)]
pub struct MakeRustlsConnect {
/// Shared `rustls` client configuration; each `make_tls_connect` call turns
/// a clone of this `Arc` into the connector for one connection.
pub config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
}
}
/// Produces a per-connection `RustlsConnect` from the shared client config.
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = private::RustlsStream<S>;
    type TlsConnect = private::RustlsConnect;
    type Error = rustls::pki_types::InvalidDnsNameError;

    /// Parses `hostname` into a server name and pairs it with a connector
    /// built from the shared config.
    ///
    /// # Errors
    ///
    /// Returns `InvalidDnsNameError` when `hostname` cannot be converted into
    /// a `ServerName`.
    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
        // Bail out before touching the config if the name is invalid.
        let server_name = ServerName::try_from(hostname)?.to_owned();
        let data = private::RustlsConnectData {
            hostname: server_name,
            connector: Arc::clone(&self.config).into(),
        };
        Ok(private::RustlsConnect(data))
    }
}