mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-19 14:10:37 +00:00
reduce cloning
This commit is contained in:
@@ -9,7 +9,7 @@ use std::io;
|
||||
pub(crate) async fn cancel_query<T>(
|
||||
config: Option<SocketConfig>,
|
||||
ssl_mode: SslMode,
|
||||
mut tls: T,
|
||||
tls: T,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> Result<(), Error>
|
||||
|
||||
@@ -10,7 +10,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub async fn connect<T>(
|
||||
mut tls: T,
|
||||
tls: T,
|
||||
config: &Config,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
|
||||
@@ -39,14 +39,14 @@ pub trait MakeTlsConnect<S> {
|
||||
/// The stream type created by the `TlsConnect` implementation.
|
||||
type Stream: TlsStream + Unpin;
|
||||
/// The `TlsConnect` implementation created by this type.
|
||||
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
|
||||
type TlsConnect<'a>: TlsConnect<S, Stream = Self::Stream> where Self: 'a;
|
||||
/// The error type returned by the `TlsConnect` implementation.
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
|
||||
/// Creates a new `TlsConnect`or.
|
||||
///
|
||||
/// The domain name is provided for certificate verification and SNI.
|
||||
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
|
||||
fn make_tls_connect<'a>(&'a self, domain: &str) -> Result<Self::TlsConnect<'a>, Self::Error>;
|
||||
}
|
||||
|
||||
/// An asynchronous function wrapping a stream in a TLS session.
|
||||
@@ -81,10 +81,10 @@ pub struct NoTls;
|
||||
|
||||
impl<S> MakeTlsConnect<S> for NoTls {
|
||||
type Stream = NoTlsStream;
|
||||
type TlsConnect = NoTls;
|
||||
type TlsConnect<'a> = NoTls;
|
||||
type Error = NoTlsError;
|
||||
|
||||
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
Ok(NoTls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,9 +271,9 @@ impl CancelClosure {
|
||||
) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
|
||||
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
|
||||
let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
&mk_tls,
|
||||
&self.hostname,
|
||||
)
|
||||
.map_err(|e| {
|
||||
|
||||
@@ -251,10 +251,9 @@ impl ConnCfg {
|
||||
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
|
||||
drop(pause);
|
||||
|
||||
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
|
||||
let mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
&mk_tls, host,
|
||||
)?;
|
||||
|
||||
// connect_raw() will not use TLS if sslmode is "disable"
|
||||
|
||||
@@ -5,6 +5,7 @@ use postgres_client::tls::MakeTlsConnect;
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::ClientConfig;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
@@ -35,14 +36,12 @@ mod private {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsConnect(pub RustlsConnectData);
|
||||
|
||||
pub struct RustlsConnectData {
|
||||
pub struct RustlsConnectData<'a> {
|
||||
pub hostname: ServerName<'static>,
|
||||
pub connector: TlsConnector,
|
||||
pub connector: &'a TlsConnector,
|
||||
}
|
||||
|
||||
impl<S> TlsConnect<S> for RustlsConnect
|
||||
impl<S> TlsConnect<S> for RustlsConnectData<'_>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
@@ -52,7 +51,7 @@ mod private {
|
||||
|
||||
fn connect(self, stream: S) -> Self::Future {
|
||||
TlsConnectFuture {
|
||||
inner: self.0.connector.connect(self.0.hostname, stream),
|
||||
inner: self.connector.connect(self.hostname, stream),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -124,16 +123,17 @@ mod private {
|
||||
/// A `MakeTlsConnect` implementation using `rustls`.
|
||||
///
|
||||
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
||||
#[derive(Clone)]
|
||||
pub struct MakeRustlsConnect {
|
||||
pub config: Arc<ClientConfig>,
|
||||
pub config: TlsConnector,
|
||||
}
|
||||
|
||||
impl MakeRustlsConnect {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
||||
#[must_use]
|
||||
pub fn new(config: Arc<ClientConfig>) -> Self {
|
||||
Self { config }
|
||||
Self {
|
||||
config: config.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,15 +142,13 @@ where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = private::RustlsStream<S>;
|
||||
type TlsConnect = private::RustlsConnect;
|
||||
type TlsConnect<'a> = private::RustlsConnectData<'a>;
|
||||
type Error = rustls::pki_types::InvalidDnsNameError;
|
||||
|
||||
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: Arc::clone(&self.config).into(),
|
||||
})
|
||||
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect<'_>, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: &self.config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,16 +67,15 @@ fn generate_certs(
|
||||
}
|
||||
|
||||
struct ClientConfig<'a> {
|
||||
config: Arc<rustls::ClientConfig>,
|
||||
config: MakeRustlsConnect,
|
||||
hostname: &'a str,
|
||||
}
|
||||
|
||||
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
|
||||
type TlsConnect<'a, S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect<'a>;
|
||||
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
fn make_tls_connect(&self) -> anyhow::Result<TlsConnect<DuplexStream>> {
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&self.config, self.hostname)?;
|
||||
Ok(tls)
|
||||
}
|
||||
}
|
||||
@@ -122,7 +121,10 @@ fn generate_tls_config<'a>(
|
||||
.with_no_client_auth();
|
||||
let config = Arc::new(config);
|
||||
|
||||
ClientConfig { config, hostname }
|
||||
ClientConfig {
|
||||
config: MakeRustlsConnect::new(config),
|
||||
hostname,
|
||||
}
|
||||
};
|
||||
|
||||
Ok((client_config, tls_config))
|
||||
|
||||
Reference in New Issue
Block a user