mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-29 19:10:38 +00:00
Another go at #12341 (LKB-2497).

We now only need 1 connect mechanism (and 1 more for testing), which saves us some code and complexity. We should be able to remove the final connect mechanism when we create a separate worker task for pglb->compute connections - either via QUIC streams or via in-memory channels.

This also now ensures that connect_once always returns a ConnectionError type - something simple enough that we can probably define a serialisation for it in pglb.

* I've abstracted connect_to_compute to always use TcpMechanism and the ProxyConfig.
* I've abstracted connect_to_compute_and_auth to perform authentication, managing any retries for stale computes.
* I had to introduce a separate `managed` function for taking ownership of the compute connection into the Client/Connection pair.
69 lines
2.0 KiB
Rust
69 lines
2.0 KiB
Rust
use futures::FutureExt;
|
|
use postgres_client::config::SslMode;
|
|
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
|
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
|
|
use rustls::pki_types::InvalidDnsNameError;
|
|
use thiserror::Error;
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
|
|
use crate::pqproto::request_tls;
|
|
use crate::proxy::connect_compute::TlsNegotiation;
|
|
use crate::proxy::retry::CouldRetry;
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum TlsError {
|
|
#[error(transparent)]
|
|
Dns(#[from] InvalidDnsNameError),
|
|
#[error(transparent)]
|
|
Connection(#[from] std::io::Error),
|
|
#[error("TLS required but not provided")]
|
|
Required,
|
|
}
|
|
|
|
impl CouldRetry for TlsError {
    /// Reports whether retrying the connection attempt might succeed.
    ///
    /// * `Connection` defers to the retry decision of the wrapped I/O error.
    /// * `Required` is always retryable.
    /// * `Dns` is never retryable.
    fn could_retry(&self) -> bool {
        match self {
            Self::Connection(io_err) => io_err.could_retry(),
            // The compute may simply not have realised yet that it supports TLS,
            // so a later attempt could be accepted.
            Self::Required => true,
            Self::Dns(_) => false,
        }
    }
}
|
|
|
|
pub async fn connect_tls<S, T>(
|
|
mut stream: S,
|
|
mode: SslMode,
|
|
tls: &T,
|
|
host: &str,
|
|
negotiation: TlsNegotiation,
|
|
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin + Send,
|
|
T: MakeTlsConnect<
|
|
S,
|
|
Error = InvalidDnsNameError,
|
|
TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
|
|
>,
|
|
{
|
|
match mode {
|
|
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
|
|
SslMode::Prefer | SslMode::Require => {}
|
|
}
|
|
|
|
match negotiation {
|
|
// No TLS request needed
|
|
TlsNegotiation::Direct => {}
|
|
// TLS request successful
|
|
TlsNegotiation::Postgres if request_tls(&mut stream).await? => {}
|
|
// TLS request failed but is required
|
|
TlsNegotiation::Postgres if SslMode::Require == mode => return Err(TlsError::Required),
|
|
// TLS request failed but is not required
|
|
TlsNegotiation::Postgres => return Ok(MaybeTlsStream::Raw(stream)),
|
|
}
|
|
|
|
Ok(MaybeTlsStream::Tls(
|
|
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
|
|
))
|
|
}
|