Files
neon/proxy/src/postgres_rustls/mod.rs
Conrad Ludgate e61ec94fbc chore(proxy): vendor a subset of rust-postgres (#9930)
Our rust-postgres fork is getting messy. Mostly because proxy wants more
control over the raw protocol than tokio-postgres provides. As such,
it's diverging more and more. Storage and compute also make use of
rust-postgres, but in more normal usage, thus they don't need our crazy
changes.

Idea: 
* proxy maintains their subset
* other teams use a minimal patch set against upstream rust-postgres

Reviewing this code will be difficult. To implement it, I
1. Copied tokio-postgres, postgres-protocol and postgres-types from
00940fcdb5
2. Updated their package names with the `2` suffix to make them compile
in the workspace.
3. Updated proxy to use those packages
4. Copied in the code from tokio-postgres-rustls 0.13 (with some patches
applied https://github.com/jbg/tokio-postgres-rustls/pull/32
https://github.com/jbg/tokio-postgres-rustls/pull/33)
5. Removed as much dead code as I could find in the vendored libraries
6. Updated the tokio-postgres-rustls code to use our existing channel
binding implementation
2024-12-05 13:00:40 +02:00

159 lines
4.5 KiB
Rust

use std::convert::TryFrom;
use std::sync::Arc;
use rustls::pki_types::ServerName;
use rustls::ClientConfig;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::tls::MakeTlsConnect;
mod private {
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use rustls::pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use crate::config::TlsServerEndPoint;
pub struct TlsConnectFuture<S> {
inner: tokio_rustls::Connect<S>,
}
impl<S> Future for TlsConnectFuture<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
type Output = io::Result<RustlsStream<S>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
}
}
pub struct RustlsConnect(pub RustlsConnectData);
pub struct RustlsConnectData {
pub hostname: ServerName<'static>,
pub connector: TlsConnector,
}
impl<S> TlsConnect<S> for RustlsConnect
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = RustlsStream<S>;
type Error = io::Error;
type Future = TlsConnectFuture<S>;
fn connect(self, stream: S) -> Self::Future {
TlsConnectFuture {
inner: self.0.connector.connect(self.0.hostname, stream),
}
}
}
pub struct RustlsStream<S>(TlsStream<S>);
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
let (_, session) = self.0.get_ref();
match session.peer_certificates() {
Some([cert, ..]) => TlsServerEndPoint::new(cert)
.ok()
.and_then(|cb| match cb {
TlsServerEndPoint::Sha256(hash) => Some(hash),
TlsServerEndPoint::Undefined => None,
})
.map_or_else(ChannelBinding::none, |hash| {
ChannelBinding::tls_server_end_point(hash.to_vec())
}),
_ => ChannelBinding::none(),
}
}
}
impl<S> AsyncRead for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
impl<S> AsyncWrite for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<tokio::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<tokio::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<tokio::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: ClientConfig) -> Self {
Self {
config: Arc::new(config),
}
}
}
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect = private::RustlsConnect;
type Error = rustls::pki_types::InvalidDnsNameError;
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
})
})
}
}