Files
neon/proxy/src/compute/tls.rs
Conrad Ludgate 4d99b6ff4d [proxy] separate compute connect from compute authentication (#12145)
## Problem

PGLB/Neonkeeper needs to separate the concerns of connecting to compute,
and authenticating to compute.

Additionally, the code within `connect_to_compute` is rather messy,
spending effort on recovering the authentication info after
wake_compute.

## Summary of changes

Split `ConnCfg` into `ConnectInfo` and `AuthInfo`. `wake_compute` only
returns `ConnectInfo` and `AuthInfo` is determined separately from the
`handshake`/`authenticate` process.

Additionally, `ConnectInfo::connect_raw` is in charge of establishing
the TLS connection, and the `postgres_client::Config::connect_raw` is
configured to use `NoTls` which will force it to skip the TLS
negotiation. This should just work.
2025-06-06 10:29:55 +00:00

64 lines
1.7 KiB
Rust

use futures::FutureExt;
use postgres_client::config::SslMode;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::pqproto::request_tls;
use crate::proxy::retry::CouldRetry;
/// Errors that `connect_tls` can return while upgrading a stream to TLS.
///
/// The `Dns` and `Connection` variants are `transparent`: their `Display`
/// output is forwarded to the wrapped error. Both also have a `From`
/// conversion (via `#[from]`), which is what lets `?` be used on the
/// underlying calls.
#[derive(Debug, Error)]
pub enum TlsError {
/// Wraps an `InvalidDnsNameError`. In `connect_tls` this is the error type
/// of `MakeTlsConnect::make_tls_connect(host)`.
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
/// Wraps a `std::io::Error`, e.g. from the `TlsConnect::connect` future
/// awaited in `connect_tls`.
#[error(transparent)]
Connection(#[from] std::io::Error),
/// `SslMode::Require` was requested, but `request_tls` reported that TLS
/// is not available on the stream.
#[error("TLS required but not provided")]
Required,
}
impl CouldRetry for TlsError {
    /// Reports whether the failed TLS setup is worth attempting again.
    ///
    /// * `Required` is always retryable.
    /// * `Connection` defers to the wrapped I/O error's own `could_retry`.
    /// * `Dns` is never retryable.
    fn could_retry(&self) -> bool {
        match self {
            // perhaps compute didn't realise it supports TLS?
            Self::Required => true,
            Self::Connection(io_err) => io_err.could_retry(),
            Self::Dns(_) => false,
        }
    }
}
/// Optionally upgrades `stream` to TLS according to `mode`.
///
/// * `SslMode::Disable`: the stream is returned as `MaybeTlsStream::Raw`
///   without calling `request_tls`.
/// * `SslMode::Prefer` / `SslMode::Require`: `request_tls` is awaited on the
///   stream. If it returns `true`, a connector is built for `host` via
///   `tls.make_tls_connect(host)` and the stream is wrapped as
///   `MaybeTlsStream::Tls`. If it returns `false`, `Prefer` falls back to
///   `MaybeTlsStream::Raw`, while `Require` fails.
///
/// # Errors
///
/// * [`TlsError::Required`] if `mode` is `Require` and `request_tls`
///   returned `false`.
/// * [`TlsError::Dns`] if `make_tls_connect(host)` fails.
/// * [`TlsError::Connection`] if `request_tls` or the TLS `connect` future
///   fails.
pub async fn connect_tls<S, T>(
    mut stream: S,
    mode: SslMode,
    tls: &T,
    host: &str,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
    T: MakeTlsConnect<
            S,
            Error = InvalidDnsNameError,
            TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
        >,
{
    // Decide up front whether a plaintext fallback is acceptable; with TLS
    // disabled there is nothing to negotiate at all.
    let tls_mandatory = match mode {
        SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
        SslMode::Prefer => false,
        SslMode::Require => true,
    };

    let tls_offered = request_tls(&mut stream).await?;

    match (tls_offered, tls_mandatory) {
        (true, _) => {
            // Kept as a single expression so the connector is consumed by
            // `connect` immediately and never held across the await.
            let secured = tls.make_tls_connect(host)?.connect(stream).boxed().await?;
            Ok(MaybeTlsStream::Tls(secured))
        }
        (false, true) => Err(TlsError::Required),
        (false, false) => Ok(MaybeTlsStream::Raw(stream)),
    }
}