reduce cloning

This commit is contained in:
Conrad Ludgate
2024-12-18 11:57:10 +00:00
parent bbc799ce77
commit b79a1dd337
7 changed files with 32 additions and 33 deletions

View File

@@ -9,7 +9,7 @@ use std::io;
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
mut tls: T,
tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>

View File

@@ -10,7 +10,7 @@ use tokio::net::TcpStream;
use tokio::sync::mpsc;
pub async fn connect<T>(
mut tls: T,
tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where

View File

@@ -39,14 +39,14 @@ pub trait MakeTlsConnect<S> {
/// The stream type created by the `TlsConnect` implementation.
type Stream: TlsStream + Unpin;
/// The `TlsConnect` implementation created by this type.
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
type TlsConnect<'a>: TlsConnect<S, Stream = Self::Stream> where Self: 'a;
/// The error type returned by the `TlsConnect` implementation.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
fn make_tls_connect<'a>(&'a self, domain: &str) -> Result<Self::TlsConnect<'a>, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
@@ -81,10 +81,10 @@ pub struct NoTls;
impl<S> MakeTlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type TlsConnect = NoTls;
type TlsConnect<'a> = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}

View File

@@ -271,9 +271,9 @@ impl CancelClosure {
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&mk_tls,
&self.hostname,
)
.map_err(|e| {

View File

@@ -251,10 +251,9 @@ impl ConnCfg {
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
&mk_tls, host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"

View File

@@ -5,6 +5,7 @@ use postgres_client::tls::MakeTlsConnect;
use rustls::pki_types::ServerName;
use rustls::ClientConfig;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;
mod private {
use std::future::Future;
@@ -35,14 +36,12 @@ mod private {
}
}
pub struct RustlsConnect(pub RustlsConnectData);
pub struct RustlsConnectData {
pub struct RustlsConnectData<'a> {
pub hostname: ServerName<'static>,
pub connector: TlsConnector,
pub connector: &'a TlsConnector,
}
impl<S> TlsConnect<S> for RustlsConnect
impl<S> TlsConnect<S> for RustlsConnectData<'_>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@@ -52,7 +51,7 @@ mod private {
fn connect(self, stream: S) -> Self::Future {
TlsConnectFuture {
inner: self.0.connector.connect(self.0.hostname, stream),
inner: self.connector.connect(self.hostname, stream),
}
}
}
@@ -124,16 +123,17 @@ mod private {
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
pub config: TlsConnector,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
Self {
config: config.into(),
}
}
}
@@ -142,15 +142,13 @@ where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect = private::RustlsConnect;
type TlsConnect<'a> = private::RustlsConnectData<'a>;
type Error = rustls::pki_types::InvalidDnsNameError;
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
})
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect<'_>, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: &self.config,
})
}
}

View File

@@ -67,16 +67,15 @@ fn generate_certs(
}
struct ClientConfig<'a> {
config: Arc<rustls::ClientConfig>,
config: MakeRustlsConnect,
hostname: &'a str,
}
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
type TlsConnect<'a, S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect<'a>;
impl ClientConfig<'_> {
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
fn make_tls_connect(&self) -> anyhow::Result<TlsConnect<DuplexStream>> {
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&self.config, self.hostname)?;
Ok(tls)
}
}
@@ -122,7 +121,10 @@ fn generate_tls_config<'a>(
.with_no_client_auth();
let config = Arc::new(config);
ClientConfig { config, hostname }
ClientConfig {
config: MakeRustlsConnect::new(config),
hostname,
}
};
Ok((client_config, tls_config))