mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 10:22:56 +00:00
A smaller version of #12066 that is somewhat easier to review.

Now that I've been using https://crates.io/crates/top-type-sizes, I've found a lot more of the low-hanging fruit that can be tweaked to reduce memory usage.

Some context for the optimisations:

Rust's stack allocation in futures is quite naive. Stack variables, even if moved, often still end up taking space in the future. Rearranging the order in which variables are defined, and properly scoping them, can go a long way.

`async fn` and `async move {}` have the consequence that they always duplicate the "upvars" (aka captures). All captures are permanently allocated in the future, even if moved. When writing futures, we should be mindful to capture as little as possible.

`TlsStream` is massive. It needs boxing so that it doesn't contribute to the above issue.

## Measurements from `top-type-sizes`:

### Before

```
10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8
6120 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```

### After

```
4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}}
4704 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```
159 lines
4.5 KiB
Rust
159 lines
4.5 KiB
Rust
use std::convert::TryFrom;
|
|
use std::sync::Arc;
|
|
|
|
use postgres_client::tls::MakeTlsConnect;
|
|
use rustls::ClientConfig;
|
|
use rustls::pki_types::ServerName;
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
|
|
mod private {
|
|
use std::future::Future;
|
|
use std::io;
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
|
|
use postgres_client::tls::{ChannelBinding, TlsConnect};
|
|
use rustls::pki_types::ServerName;
|
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
use tokio_rustls::TlsConnector;
|
|
use tokio_rustls::client::TlsStream;
|
|
|
|
use crate::tls::TlsServerEndPoint;
|
|
|
|
pub struct TlsConnectFuture<S> {
|
|
inner: tokio_rustls::Connect<S>,
|
|
}
|
|
|
|
impl<S> Future for TlsConnectFuture<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
type Output = io::Result<RustlsStream<S>>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
Pin::new(&mut self.inner)
|
|
.poll(cx)
|
|
.map_ok(|s| RustlsStream(Box::new(s)))
|
|
}
|
|
}
|
|
|
|
/// `TlsConnect` implementation for rustls; a thin wrapper around the
/// hostname and connector stored in [`RustlsConnectData`].
pub struct RustlsConnect(pub RustlsConnectData);
|
|
|
|
pub struct RustlsConnectData {
|
|
pub hostname: ServerName<'static>,
|
|
pub connector: TlsConnector,
|
|
}
|
|
|
|
impl<S> TlsConnect<S> for RustlsConnect
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
|
{
|
|
type Stream = RustlsStream<S>;
|
|
type Error = io::Error;
|
|
type Future = TlsConnectFuture<S>;
|
|
|
|
fn connect(self, stream: S) -> Self::Future {
|
|
TlsConnectFuture {
|
|
inner: self.0.connector.connect(self.0.hostname, stream),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// An established rustls client stream over `S`.
///
/// The inner `TlsStream` lives behind a `Box`, so this wrapper itself is
/// only pointer-sized.
pub struct RustlsStream<S>(Box<TlsStream<S>>);
|
|
|
|
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
fn channel_binding(&self) -> ChannelBinding {
|
|
let (_, session) = self.0.get_ref();
|
|
match session.peer_certificates() {
|
|
Some([cert, ..]) => TlsServerEndPoint::new(cert)
|
|
.ok()
|
|
.and_then(|cb| match cb {
|
|
TlsServerEndPoint::Sha256(hash) => Some(hash),
|
|
TlsServerEndPoint::Undefined => None,
|
|
})
|
|
.map_or_else(ChannelBinding::none, |hash| {
|
|
ChannelBinding::tls_server_end_point(hash.to_vec())
|
|
}),
|
|
_ => ChannelBinding::none(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S> AsyncRead for RustlsStream<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<tokio::io::Result<()>> {
|
|
Pin::new(&mut self.0).poll_read(cx, buf)
|
|
}
|
|
}
|
|
|
|
impl<S> AsyncWrite for RustlsStream<S>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<tokio::io::Result<usize>> {
|
|
Pin::new(&mut self.0).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_flush(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<tokio::io::Result<()>> {
|
|
Pin::new(&mut self.0).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<tokio::io::Result<()>> {
|
|
Pin::new(&mut self.0).poll_shutdown(cx)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A `MakeTlsConnect` implementation using `rustls`.
|
|
///
|
|
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
|
#[derive(Clone)]
|
|
pub struct MakeRustlsConnect {
|
|
pub config: Arc<ClientConfig>,
|
|
}
|
|
|
|
impl MakeRustlsConnect {
|
|
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
|
#[must_use]
|
|
pub fn new(config: Arc<ClientConfig>) -> Self {
|
|
Self { config }
|
|
}
|
|
}
|
|
|
|
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
|
{
|
|
type Stream = private::RustlsStream<S>;
|
|
type TlsConnect = private::RustlsConnect;
|
|
type Error = rustls::pki_types::InvalidDnsNameError;
|
|
|
|
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
|
ServerName::try_from(hostname).map(|dns_name| {
|
|
private::RustlsConnect(private::RustlsConnectData {
|
|
hostname: dns_name.to_owned(),
|
|
connector: Arc::clone(&self.config).into(),
|
|
})
|
|
})
|
|
}
|
|
}
|