diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs index 2c69719128..db238ff2ea 100644 --- a/libs/proxy/tokio-postgres2/src/tls.rs +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -39,14 +39,14 @@ pub trait MakeTlsConnect { /// The stream type created by the `TlsConnect` implementation. type Stream: TlsStream + Unpin; /// The `TlsConnect` implementation created by this type. - type TlsConnect<'a>: TlsConnect where Self: 'a; + type TlsConnect: TlsConnect; /// The error type returned by the `TlsConnect` implementation. type Error: Into>; /// Creates a new `TlsConnect`or. /// /// The domain name is provided for certificate verification and SNI. - fn make_tls_connect<'a>(&'a self, domain: &str) -> Result, Self::Error>; + fn make_tls_connect(self, domain: &str) -> Result; } /// An asynchronous function wrapping a stream in a TLS session. @@ -81,10 +81,10 @@ pub struct NoTls; impl MakeTlsConnect for NoTls { type Stream = NoTlsStream; - type TlsConnect<'a> = NoTls; + type TlsConnect = NoTls; type Error = NoTlsError; - fn make_tls_connect(&self, _: &str) -> Result { + fn make_tls_connect(self, _: &str) -> Result { Ok(NoTls) } } diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 38b42e15a3..1f040e7f5d 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -285,7 +285,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig let compute_config = ComputeConfig { retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?, - tls: Arc::new(client_config), + tls: Arc::new(client_config).into(), timeout: Duration::from_secs(2), }; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 1dace2ec8f..b6bbe80391 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -650,7 +650,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let compute_config = ComputeConfig { retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?, - tls: Arc::new(client_config), + tls: Arc::new(client_config).into(), timeout: Duration::from_secs(2), }; diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index e9cb486183..6b248e8ae4 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -271,9 +271,8 @@ impl CancelClosure { ) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; - let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); let tls = >::make_tls_connect( - &mk_tls, + crate::postgres_rustls::MakeRustlsConnect::new(&compute_config.tls), &self.hostname, ) .map_err(|e| { @@ -349,7 +348,7 @@ mod tests { ComputeConfig { retry, - tls: Arc::new(client_config), + tls: Arc::new(client_config).into(), timeout: Duration::from_secs(2), } } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index aecc9b0fa6..5565666231 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -221,7 +221,7 @@ impl ConnCfg { } } -type RustlsStream = >::Stream; +type RustlsStream = crate::postgres_rustls::RustlsStream; pub(crate) struct PostgresConnection { /// Socket connected to a compute node. @@ -251,9 +251,9 @@ impl ConnCfg { let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?; drop(pause); - let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); let tls = >::make_tls_connect( - &mk_tls, host, + crate::postgres_rustls::MakeRustlsConnect::new(&config.tls), + host, )?; // connect_raw() will not use TLS if sslmode is "disable" diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 351ff6d12d..5133ce9e1c 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -10,6 +10,7 @@ use remote_storage::RemoteStorageConfig; use rustls::crypto::ring::{self, sign}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use sha2::{Digest, Sha256}; +use tokio_rustls::TlsConnector; use tracing::{error, info}; use x509_parser::oid_registry; @@ -37,7 +38,7 @@ pub struct ProxyConfig { pub struct ComputeConfig { pub retry: RetryConfig, - pub tls: Arc, + pub tls: TlsConnector, pub timeout: Duration, } diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/postgres_rustls/mod.rs index fd414556af..b7739f066e 100644 --- a/proxy/src/postgres_rustls/mod.rs +++ b/proxy/src/postgres_rustls/mod.rs @@ -1,9 +1,8 @@ use std::convert::TryFrom; -use std::sync::Arc; use postgres_client::tls::MakeTlsConnect; +pub use private::RustlsStream; use rustls::pki_types::ServerName; -use rustls::ClientConfig; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsConnector; @@ -123,32 +122,30 @@ mod private { /// A `MakeTlsConnect` implementation using `rustls`. /// /// That way you can connect to PostgreSQL using `rustls` as the TLS stack. -pub struct MakeRustlsConnect { - pub config: TlsConnector, +pub struct MakeRustlsConnect<'a> { + pub connector: &'a TlsConnector, } -impl MakeRustlsConnect { - /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. +impl<'a> MakeRustlsConnect<'a> { + /// Creates a new `MakeRustlsConnect` from the provided `TlsConnector`. #[must_use] - pub fn new(config: Arc) -> Self { - Self { - config: config.into(), - } + pub fn new(connector: &'a TlsConnector) -> Self { + Self { connector } } } -impl MakeTlsConnect for MakeRustlsConnect +impl<'a, S> MakeTlsConnect for MakeRustlsConnect<'a> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Stream = private::RustlsStream; - type TlsConnect<'a> = private::RustlsConnectData<'a>; + type TlsConnect = private::RustlsConnectData<'a>; type Error = rustls::pki_types::InvalidDnsNameError; - fn make_tls_connect(&self, hostname: &str) -> Result, Self::Error> { + fn make_tls_connect(self, hostname: &str) -> Result { ServerName::try_from(hostname).map(|dns_name| private::RustlsConnectData { hostname: dns_name.to_owned(), - connector: &self.config, + connector: self.connector, }) } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3988482a25..7f473cb420 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -15,6 +15,7 @@ use rstest::rstest; use rustls::crypto::ring; use rustls::{pki_types, RootCertStore}; use tokio::io::DuplexStream; +use tokio_rustls::TlsConnector; use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; @@ -67,15 +68,16 @@ fn generate_certs( } struct ClientConfig<'a> { - config: MakeRustlsConnect, + config: TlsConnector, hostname: &'a str, } -type TlsConnect<'a, S> = >::TlsConnect<'a>; +type TlsConnect<'a, S> = as MakeTlsConnect>::TlsConnect; impl ClientConfig<'_> { fn make_tls_connect(&self) -> anyhow::Result> { - let tls = MakeTlsConnect::::make_tls_connect(&self.config, self.hostname)?; + let mk = MakeRustlsConnect::new(&self.config); + let tls = MakeTlsConnect::::make_tls_connect(mk, self.hostname)?; Ok(tls) } } @@ -122,7 +124,7 @@ fn generate_tls_config<'a>( let config = Arc::new(config); ClientConfig { - config: MakeRustlsConnect::new(config), + config: TlsConnector::from(config), hostname, } }; @@ -597,7 +599,7 @@ fn config() -> ComputeConfig { ComputeConfig { retry, - tls: Arc::new(client_config), + tls: Arc::new(client_config).into(), timeout: Duration::from_secs(2), } }