diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs index cddbf16336..67e921211c 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_query.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -9,7 +9,7 @@ use std::io; pub(crate) async fn cancel_query( config: Option, ssl_mode: SslMode, - mut tls: T, + tls: T, process_id: i32, secret_key: i32, ) -> Result<(), Error> diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index e0cb69748d..dd53eb0d84 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -10,7 +10,7 @@ use tokio::net::TcpStream; use tokio::sync::mpsc; pub async fn connect( - mut tls: T, + tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs index dc8140719f..2c69719128 100644 --- a/libs/proxy/tokio-postgres2/src/tls.rs +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -39,14 +39,14 @@ pub trait MakeTlsConnect { /// The stream type created by the `TlsConnect` implementation. type Stream: TlsStream + Unpin; /// The `TlsConnect` implementation created by this type. - type TlsConnect: TlsConnect; + type TlsConnect<'a>: TlsConnect where Self: 'a; /// The error type returned by the `TlsConnect` implementation. type Error: Into>; /// Creates a new `TlsConnect`or. /// /// The domain name is provided for certificate verification and SNI. - fn make_tls_connect(&mut self, domain: &str) -> Result; + fn make_tls_connect<'a>(&'a self, domain: &str) -> Result, Self::Error>; } /// An asynchronous function wrapping a stream in a TLS session. @@ -81,10 +81,10 @@ pub struct NoTls; impl MakeTlsConnect for NoTls { type Stream = NoTlsStream; - type TlsConnect = NoTls; + type TlsConnect<'a> = NoTls; type Error = NoTlsError; - fn make_tls_connect(&mut self, _: &str) -> Result { + fn make_tls_connect(&self, _: &str) -> Result { Ok(NoTls) } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index e989f4bbd1..e9cb486183 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -271,9 +271,9 @@ impl CancelClosure { ) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; - let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); + let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); let tls = >::make_tls_connect( - &mut mk_tls, + &mk_tls, &self.hostname, ) .map_err(|e| { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 17588b9c34..aecc9b0fa6 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -251,10 +251,9 @@ impl ConnCfg { let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?; drop(pause); - let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); + let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); let tls = >::make_tls_connect( - &mut mk_tls, - host, + &mk_tls, host, )?; // connect_raw() will not use TLS if sslmode is "disable" diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/postgres_rustls/mod.rs index abf48d6f82..fd414556af 100644 --- a/proxy/src/postgres_rustls/mod.rs +++ b/proxy/src/postgres_rustls/mod.rs @@ -5,6 +5,7 @@ use postgres_client::tls::MakeTlsConnect; use rustls::pki_types::ServerName; use rustls::ClientConfig; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::TlsConnector; mod private { use std::future::Future; @@ -35,14 +36,12 @@ mod private { } } - pub struct RustlsConnect(pub RustlsConnectData); - - pub struct RustlsConnectData { + pub struct RustlsConnectData<'a> { pub hostname: ServerName<'static>, - pub connector: TlsConnector, + pub connector: &'a TlsConnector, } - impl TlsConnect for RustlsConnect + impl TlsConnect for RustlsConnectData<'_> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -52,7 +51,7 @@ mod private { fn connect(self, stream: S) -> Self::Future { TlsConnectFuture { - inner: self.0.connector.connect(self.0.hostname, stream), + inner: self.connector.connect(self.hostname, stream), } } } @@ -124,16 +123,17 @@ mod private { /// A `MakeTlsConnect` implementation using `rustls`. /// /// That way you can connect to PostgreSQL using `rustls` as the TLS stack. -#[derive(Clone)] pub struct MakeRustlsConnect { - pub config: Arc, + pub config: TlsConnector, } impl MakeRustlsConnect { /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. #[must_use] pub fn new(config: Arc) -> Self { - Self { config } + Self { + config: config.into(), + } } } @@ -142,15 +142,13 @@ where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Stream = private::RustlsStream; - type TlsConnect = private::RustlsConnect; + type TlsConnect<'a> = private::RustlsConnectData<'a>; type Error = rustls::pki_types::InvalidDnsNameError; - fn make_tls_connect(&mut self, hostname: &str) -> Result { - ServerName::try_from(hostname).map(|dns_name| { - private::RustlsConnect(private::RustlsConnectData { - hostname: dns_name.to_owned(), - connector: Arc::clone(&self.config).into(), - }) + fn make_tls_connect(&self, hostname: &str) -> Result, Self::Error> { + ServerName::try_from(hostname).map(|dns_name| private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: &self.config, }) } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 5dc13982bd..3988482a25 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -67,16 +67,15 @@ fn generate_certs( } struct ClientConfig<'a> { - config: Arc, + config: MakeRustlsConnect, hostname: &'a str, } -type TlsConnect = >::TlsConnect; +type TlsConnect<'a, S> = >::TlsConnect<'a>; impl ClientConfig<'_> { - fn make_tls_connect(self) -> anyhow::Result> { - let mut mk = MakeRustlsConnect::new(self.config); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; + fn make_tls_connect(&self) -> anyhow::Result> { + let tls = MakeTlsConnect::::make_tls_connect(&self.config, self.hostname)?; Ok(tls) } } @@ -122,7 +121,10 @@ fn generate_tls_config<'a>( .with_no_client_auth(); let config = Arc::new(config); - ClientConfig { config, hostname } + ClientConfig { + config: MakeRustlsConnect::new(config), + hostname, + } }; Ok((client_config, tls_config))