properly remove clones

This commit is contained in:
Conrad Ludgate
2024-12-18 12:06:57 +00:00
parent b79a1dd337
commit 90ce4f3002
8 changed files with 31 additions and 32 deletions

View File

@@ -39,14 +39,14 @@ pub trait MakeTlsConnect<S> {
/// The stream type created by the `TlsConnect` implementation.
type Stream: TlsStream + Unpin;
/// The `TlsConnect` implementation created by this type.
type TlsConnect<'a>: TlsConnect<S, Stream = Self::Stream> where Self: 'a;
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
/// The error type returned by the `TlsConnect` implementation.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect<'a>(&'a self, domain: &str) -> Result<Self::TlsConnect<'a>, Self::Error>;
fn make_tls_connect(self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
@@ -81,10 +81,10 @@ pub struct NoTls;
impl<S> MakeTlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type TlsConnect<'a> = NoTls;
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
fn make_tls_connect(self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}

View File

@@ -285,7 +285,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
let compute_config = ComputeConfig {
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
tls: Arc::new(client_config),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
};

View File

@@ -650,7 +650,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let compute_config = ComputeConfig {
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
tls: Arc::new(client_config),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
};

View File

@@ -271,9 +271,8 @@ impl CancelClosure {
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mk_tls,
crate::postgres_rustls::MakeRustlsConnect::new(&compute_config.tls),
&self.hostname,
)
.map_err(|e| {
@@ -349,7 +348,7 @@ mod tests {
ComputeConfig {
retry,
tls: Arc::new(client_config),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
}
}

View File

@@ -221,7 +221,7 @@ impl ConnCfg {
}
}
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
type RustlsStream = crate::postgres_rustls::RustlsStream<tokio::net::TcpStream>;
pub(crate) struct PostgresConnection {
/// Socket connected to a compute node.
@@ -251,9 +251,9 @@ impl ConnCfg {
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mk_tls, host,
crate::postgres_rustls::MakeRustlsConnect::new(&config.tls),
host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"

View File

@@ -10,6 +10,7 @@ use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use sha2::{Digest, Sha256};
use tokio_rustls::TlsConnector;
use tracing::{error, info};
use x509_parser::oid_registry;
@@ -37,7 +38,7 @@ pub struct ProxyConfig {
pub struct ComputeConfig {
pub retry: RetryConfig,
pub tls: Arc<rustls::ClientConfig>,
pub tls: TlsConnector,
pub timeout: Duration,
}

View File

@@ -1,9 +1,8 @@
use std::convert::TryFrom;
use std::sync::Arc;
use postgres_client::tls::MakeTlsConnect;
pub use private::RustlsStream;
use rustls::pki_types::ServerName;
use rustls::ClientConfig;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;
@@ -123,32 +122,30 @@ mod private {
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
pub struct MakeRustlsConnect {
pub config: TlsConnector,
pub struct MakeRustlsConnect<'a> {
pub connector: &'a TlsConnector,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
impl<'a> MakeRustlsConnect<'a> {
/// Creates a new `MakeRustlsConnect` from the provided `TlsConnector`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self {
config: config.into(),
}
pub fn new(connector: &'a TlsConnector) -> Self {
Self { connector }
}
}
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
impl<'a, S> MakeTlsConnect<S> for MakeRustlsConnect<'a>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect<'a> = private::RustlsConnectData<'a>;
type TlsConnect = private::RustlsConnectData<'a>;
type Error = rustls::pki_types::InvalidDnsNameError;
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect<'_>, Self::Error> {
fn make_tls_connect(self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: &self.config,
connector: self.connector,
})
}
}

View File

@@ -15,6 +15,7 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::{pki_types, RootCertStore};
use tokio::io::DuplexStream;
use tokio_rustls::TlsConnector;
use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
@@ -67,15 +68,16 @@ fn generate_certs(
}
struct ClientConfig<'a> {
config: MakeRustlsConnect,
config: TlsConnector,
hostname: &'a str,
}
type TlsConnect<'a, S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect<'a>;
type TlsConnect<'a, S> = <MakeRustlsConnect<'a> as MakeTlsConnect<S>>::TlsConnect;
impl ClientConfig<'_> {
fn make_tls_connect(&self) -> anyhow::Result<TlsConnect<DuplexStream>> {
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&self.config, self.hostname)?;
let mk = MakeRustlsConnect::new(&self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(mk, self.hostname)?;
Ok(tls)
}
}
@@ -122,7 +124,7 @@ fn generate_tls_config<'a>(
let config = Arc::new(config);
ClientConfig {
config: MakeRustlsConnect::new(config),
config: TlsConnector::from(config),
hostname,
}
};
@@ -597,7 +599,7 @@ fn config() -> ComputeConfig {
ComputeConfig {
retry,
tls: Arc::new(client_config),
tls: Arc::new(client_config).into(),
timeout: Duration::from_secs(2),
}
}