[proxy] separate compute connect from compute authentication (#12145)

## Problem

PGLB/Neonkeeper needs to separate the concerns of connecting to compute and
authenticating to compute.

Additionally, the code within `connect_to_compute` is rather messy,
spending effort on recovering the authentication info after
`wake_compute`.

## Summary of changes

Split `ConnCfg` into `ConnectInfo` and `AuthInfo`. `wake_compute` now returns
only `ConnectInfo`, while `AuthInfo` is determined separately during the
`handshake`/`authenticate` step.
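
For orientation, here is a minimal self-contained model of the split. It is a
sketch only: the field and type names follow the diff, but the real structs
carry more state (SCRAM keys, server params, the crate's `Host` type, etc.).

```rust
use std::net::IpAddr;

#[derive(Clone, Copy, Debug)]
enum SslMode {
    Disable,
    Require,
}

/// What `wake_compute` hands back: just enough to reach the compute node.
#[derive(Clone, Debug)]
struct ConnectInfo {
    host_addr: Option<IpAddr>,
    host: String,
    port: u16,
    ssl_mode: SslMode,
}

/// What the handshake/authenticate step produces: how to log in once connected.
#[derive(Debug)]
struct AuthInfo {
    password: Option<Vec<u8>>,
    server_params: Vec<(String, String)>,
}

fn main() {
    // Waking the compute yields only the connection target...
    let conn = ConnectInfo {
        host_addr: None,
        host: "ep-example.compute.internal".to_owned(),
        port: 5432,
        ssl_mode: SslMode::Require,
    };
    // ...while the credentials are derived from client authentication, independently.
    let auth = AuthInfo {
        password: Some(b"secret".to_vec()),
        server_params: vec![("client_encoding".to_owned(), "UTF8".to_owned())],
    };
    println!("{conn:?} {auth:?}");
}
```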

Additionally, `ConnectInfo::connect_raw` is now in charge of establishing the
TLS connection, and `postgres_client::Config::connect_raw` is configured with
`NoTls`, which forces it to skip its own TLS negotiation on the
already-established stream. This should just work.
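
For reference, a self-contained sketch of the TLS decision `connect_raw` now
makes, mirroring the new `compute::tls::connect_tls` helper further down in
the diff. The stream types are stand-ins, not the real
tokio/rustls/postgres_client APIs.

```rust
#[derive(Clone, Copy, PartialEq)]
enum SslMode {
    Disable,
    Prefer,
    Require,
}

struct Stream;            // placeholder for a connected TCP socket
struct TlsStream(Stream); // placeholder for a rustls-wrapped socket

enum MaybeTlsStream {
    Raw(Stream),
    Tls(TlsStream),
}

#[derive(Debug)]
enum TlsError {
    /// The server refused TLS but sslmode=require.
    Required,
}

// Skip TLS entirely for `Disable`; otherwise send an SSLRequest and upgrade
// only if the server accepts, erroring when `Require` cannot be satisfied.
fn negotiate(
    stream: Stream,
    mode: SslMode,
    server_accepts_tls: bool,
) -> Result<MaybeTlsStream, TlsError> {
    if mode == SslMode::Disable {
        return Ok(MaybeTlsStream::Raw(stream));
    }
    if !server_accepts_tls {
        if mode == SslMode::Require {
            return Err(TlsError::Required);
        }
        return Ok(MaybeTlsStream::Raw(stream));
    }
    Ok(MaybeTlsStream::Tls(TlsStream(stream)))
}

fn main() {
    // The resulting stream is handed to postgres_client together with NoTls,
    // so the client library goes straight to the startup exchange and never
    // re-negotiates TLS on an already-encrypted stream.
    let _stream = negotiate(Stream, SslMode::Require, true).expect("TLS offered");
}
```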

Commit `4d99b6ff4d` (parent `590301df08`), authored by Conrad Ludgate on
2025-06-06 11:29:55 +01:00 and committed by GitHub.
24 changed files with 382 additions and 356 deletions.


@@ -18,11 +18,6 @@ pub(super) async fn authenticate(
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
debug!("auth endpoint chooses MD5");
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");


@@ -6,10 +6,9 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::compute::AuthInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
@@ -98,15 +97,11 @@ impl ConsoleRedirectBackend {
ctx: &RequestContext,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(
ConsoleRedirectNodeInfo,
ComputeUserInfo,
Option<Vec<IpPattern>>,
)> {
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(|(node_info, user_info, ip_allowlist)| {
(ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist)
.map(|(node_info, auth_info, user_info)| {
(ConsoleRedirectNodeInfo(node_info), auth_info, user_info)
})
}
}
@@ -121,10 +116,6 @@ impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.0.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
}
}
async fn authenticate(
@@ -132,7 +123,7 @@ async fn authenticate(
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<(NodeInfo, ComputeUserInfo, Option<Vec<IpPattern>>)> {
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
// registering waiter can fail if we get unlucky with rng.
@@ -192,10 +183,24 @@ async fn authenticate(
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
config.dbname(&db_info.dbname).user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
let ssl_mode = if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
SslMode::Require
} else {
SslMode::Disable
};
let conn_info = compute::ConnectInfo {
host: db_info.host.into(),
port: db_info.port,
ssl_mode,
host_addr: None,
};
let auth_info =
AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref());
let user: RoleName = db_info.user.into();
let user_info = ComputeUserInfo {
@@ -209,26 +214,12 @@ async fn authenticate(
ctx.set_project(db_info.aux.clone());
info!("woken up a compute node");
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok((
NodeInfo {
config,
conn_info,
aux: db_info.aux,
},
auth_info,
user_info,
db_info.allowed_ips,
))
}


@@ -1,11 +1,12 @@
use std::net::SocketAddr;
use arc_swap::ArcSwapOption;
use postgres_client::config::SslMode;
use tokio::sync::Semaphore;
use super::jwt::{AuthRule, FetchAuthRules};
use crate::auth::backend::jwt::FetchAuthRulesError;
use crate::compute::ConnCfg;
use crate::compute::ConnectInfo;
use crate::compute_ctl::ComputeCtlApi;
use crate::context::RequestContext;
use crate::control_plane::NodeInfo;
@@ -29,7 +30,12 @@ impl LocalBackend {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
conn_info: ConnectInfo {
host_addr: Some(postgres_addr.ip()),
host: postgres_addr.ip().to_string().into(),
port: postgres_addr.port(),
ssl_mode: SslMode::Disable,
},
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),


@@ -168,8 +168,6 @@ impl ComputeUserInfo {
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
@@ -419,13 +417,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}
#[cfg(test)]


@@ -169,13 +169,6 @@ pub(crate) async fn validate_password_and_exchange(
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_owned(),
)))
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;


@@ -24,7 +24,6 @@ use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::RedisKVClient;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type IpSubnetKey = IpNet;
@@ -497,10 +496,8 @@ impl CancelClosure {
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let mut mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
compute_config,
&self.hostname,
)
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;


@@ -1,21 +1,24 @@
mod tls;
use std::fmt::Debug;
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use std::net::{IpAddr, SocketAddr};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tracing::{debug, error, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::compute::tls::TlsError;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
@@ -25,7 +28,6 @@ use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -38,10 +40,7 @@ pub(crate) enum ConnectionError {
Postgres(#[from] postgres_client::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
CouldNotConnect(#[from] io::Error),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
TlsError(#[from] TlsError),
#[error("{COULD_NOT_CONNECT}: {0}")]
WakeComputeError(#[from] WakeComputeError),
@@ -73,7 +72,7 @@ impl UserFacingError for ConnectionError {
ConnectionError::TooManyConnectionAttempts(_) => {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
_ => COULD_NOT_CONNECT.to_owned(),
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
}
}
}
@@ -85,7 +84,6 @@ impl ReportableError for ConnectionError {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -96,34 +94,85 @@ impl ReportableError for ConnectionError {
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `postgres_client` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone)]
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
pub enum Auth {
/// Only used during console-redirect.
Password(Vec<u8>),
/// Used by sql-over-http, ws, tcp.
Scram(Box<ScramKeys>),
}
/// A config for authenticating to the compute node.
pub(crate) struct AuthInfo {
/// None for local-proxy, as we use trust-based localhost auth.
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
/// Might be None for console-redirect, but that's only a consequence of testing environments ATM.
auth: Option<Auth>,
server_params: StartupMessageParams,
/// Console redirect sets user and database, we shouldn't re-use those from the params.
skip_db_user: bool,
}
/// Contains only the data needed to establish a secure connection to compute.
#[derive(Clone)]
pub struct ConnectInfo {
pub host_addr: Option<IpAddr>,
pub host: Host,
pub port: u16,
pub ssl_mode: SslMode,
}
/// Creation and initialization routines.
impl ConnCfg {
pub(crate) fn new(host: String, port: u16) -> Self {
Self(Box::new(postgres_client::Config::new(host, port)))
}
/// Reuse password or auth keys from the other config.
pub(crate) fn reuse_password(&mut self, other: Self) {
if let Some(password) = other.get_password() {
self.password(password);
}
if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
impl AuthInfo {
pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self {
let mut server_params = StartupMessageParams::default();
server_params.insert("database", db);
server_params.insert("user", user);
Self {
auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
server_params,
skip_db_user: true,
}
}
pub(crate) fn get_host(&self) -> Host {
match self.0.get_host() {
postgres_client::config::Host::Tcp(s) => s.into(),
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(*auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,
}
}
}
impl ConnectInfo {
pub fn to_postgres_client_config(&self) -> postgres_client::Config {
let mut config = postgres_client::Config::new(self.host.to_string(), self.port);
config.ssl_mode(self.ssl_mode);
if let Some(host_addr) = self.host_addr {
config.set_host_addr(host_addr);
}
config
}
}
impl AuthInfo {
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
match &self.auth {
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
Some(Auth::Password(pw)) => config.password(pw),
None => &mut config,
};
for (k, v) in self.server_params.iter() {
config.set_param(k, v);
}
config
}
/// Apply startup message params to the connection config.
pub(crate) fn set_startup_params(
@@ -132,27 +181,26 @@ impl ConnCfg {
arbitrary_params: bool,
) {
if !arbitrary_params {
self.set_param("client_encoding", "UTF8");
self.server_params.insert("client_encoding", "UTF8");
}
for (k, v) in params.iter() {
match k {
// Only set `user` if it's not present in the config.
// Console redirect auth flow takes username from the console's response.
"user" if self.user_is_set() => {}
"database" if self.db_is_set() => {}
"user" | "database" if self.skip_db_user => {}
"options" => {
if let Some(options) = filtered_options(v) {
self.set_param(k, &options);
self.server_params.insert(k, &options);
}
}
"user" | "database" | "application_name" | "replication" => {
self.set_param(k, v);
self.server_params.insert(k, v);
}
// if we allow arbitrary params, then we forward them through.
// this is a flag for a period of backwards compatibility
k if arbitrary_params => {
self.set_param(k, v);
self.server_params.insert(k, v);
}
_ => {}
}
@@ -160,25 +208,13 @@ impl ConnCfg {
}
}
impl std::ops::Deref for ConnCfg {
type Target = postgres_client::Config;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// For now, let's make it easier to setup the config.
impl std::ops::DerefMut for ConnCfg {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
use postgres_client::config::Host;
impl ConnectInfo {
/// Establish a raw TCP+TLS connection to the compute node.
async fn connect_raw(
&self,
config: &ComputeConfig,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
// wrap TcpStream::connect with timeout
let connect_with_timeout = |addrs| {
@@ -208,34 +244,32 @@ impl ConnCfg {
// We can't reuse connection establishing logic from `postgres_client` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let port = self.0.get_port();
let host = self.0.get_host();
let port = self.port;
let host = &*self.host;
let host = match host {
Host::Tcp(host) => host.as_str(),
};
let addrs = match self.0.get_host_addr() {
let addrs = match self.host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port)).await?.collect(),
};
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
Err(err)
Err(TlsError::Connection(err))
}
}
}
}
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
pub(crate) struct PostgresConnection {
/// Socket connected to a compute node.
pub(crate) stream:
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
/// PostgreSQL connection parameters.
pub(crate) params: std::collections::HashMap<String, String>,
/// Query cancellation token.
@@ -248,28 +282,23 @@ pub(crate) struct PostgresConnection {
_guage: NumDbConnectionsGuard<'static>,
}
impl ConnCfg {
impl ConnectInfo {
/// Connect to a corresponding compute node.
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
auth: &AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
// we setup SSL early in `ConnectInfo::connect_raw`.
tmp_config.ssl_mode(SslMode::Disable);
let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = self.0.connect_raw(stream, tls).await?;
let (socket_addr, stream) = self.connect_raw(config).await?;
let connection = tmp_config.connect_raw(stream, NoTls).await?;
drop(pause);
let RawConnection {
@@ -282,13 +311,14 @@ impl ConnCfg {
tracing::Span::current().record("pid", tracing::field::display(process_id));
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
let stream = stream.into_inner();
let MaybeTlsStream::Raw(stream) = stream.into_inner();
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.0.get_ssl_mode(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
self.ssl_mode,
ctx.get_proxy_latency(),
ctx.get_testodrome_id().unwrap_or_default(),
);
@@ -299,11 +329,11 @@ impl ConnCfg {
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.0.get_ssl_mode(),
ssl_mode: self.ssl_mode,
process_id,
secret_key,
},
host.to_string(),
self.host.to_string(),
user_info,
);

proxy/src/compute/tls.rs (new file, 63 lines)

@@ -0,0 +1,63 @@
use futures::FutureExt;
use postgres_client::config::SslMode;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::{MakeTlsConnect, TlsConnect};
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::pqproto::request_tls;
use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
Required,
}
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
}
}
}
pub async fn connect_tls<S, T>(
mut stream: S,
mode: SslMode,
tls: &T,
host: &str,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
T: MakeTlsConnect<
S,
Error = InvalidDnsNameError,
TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
>,
{
match mode {
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
return Ok(MaybeTlsStream::Raw(stream));
}
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
}


@@ -210,20 +210,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let (node_info, user_info, _ip_allowlist) = match backend
let (node_info, mut auth_info, user_info) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Ok(auth_result) => auth_result,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
auth_info.set_startup_params(&params, true);
let node = connect_to_compute(
ctx,
&TcpMechanism {
user_info,
params_compat: true,
params: &params,
auth: auth_info,
locks: &config.connect_compute_locks,
},
&node_info,


@@ -261,24 +261,18 @@ impl NeonControlPlaneClient {
Some(_) => SslMode::Require,
None => SslMode::Disable,
};
let host_name = match body.server_name {
Some(host) => host,
None => host.to_owned(),
let host = match body.server_name {
Some(host) => host.into(),
None => host.into(),
};
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new(host_name, port);
if let Some(addr) = host_addr {
config.set_host_addr(addr);
}
config.ssl_mode(ssl_mode);
let node = NodeInfo {
config,
conn_info: compute::ConnectInfo {
host_addr,
host,
port,
ssl_mode,
},
aux: body.aux,
};


@@ -6,6 +6,7 @@ use std::str::FromStr;
use std::sync::Arc;
use futures::TryFutureExt;
use postgres_client::config::SslMode;
use thiserror::Error;
use tokio_postgres::Client;
use tracing::{Instrument, error, info, info_span, warn};
@@ -14,6 +15,7 @@ use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::compute::ConnectInfo;
use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
@@ -24,9 +26,9 @@ use crate::control_plane::{
RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::scram;
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::url::ApiUrl;
use crate::{compute, scram};
#[derive(Debug, Error)]
enum MockApiError {
@@ -87,8 +89,7 @@ impl MockControlPlane {
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
} else {
warn!("user '{role}' does not exist");
None
@@ -170,25 +171,23 @@ impl MockControlPlane {
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let port = self.endpoint.port().unwrap_or(5432);
let mut config = match self.endpoint.host_str() {
None => {
let mut config = compute::ConnCfg::new("localhost".to_string(), port);
config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
config
}
Some(host) => {
let mut config = compute::ConnCfg::new(host.to_string(), port);
if let Ok(addr) = IpAddr::from_str(host) {
config.set_host_addr(addr);
}
config
}
let conn_info = match self.endpoint.host_str() {
None => ConnectInfo {
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
host: "localhost".into(),
port,
ssl_mode: SslMode::Disable,
},
Some(host) => ConnectInfo {
host_addr: IpAddr::from_str(host).ok(),
host: host.into(),
port,
ssl_mode: SslMode::Disable,
},
};
config.ssl_mode(postgres_client::config::SslMode::Disable);
let node = NodeInfo {
config,
conn_info,
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -266,12 +265,3 @@ impl super::ControlPlaneApi for MockControlPlane {
self.do_wake_compute().map_ok(Cached::new_uncached).await
}
}
fn parse_md5(input: &str) -> Option<[u8; 16]> {
let text = input.strip_prefix("md5")?;
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
Some(bytes)
}


@@ -11,8 +11,8 @@ pub(crate) mod errors;
use std::sync::Arc;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
@@ -39,10 +39,6 @@ pub mod mgmt;
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) enum AuthSecret {
#[cfg(any(test, feature = "testing"))]
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
@@ -63,13 +59,9 @@ pub(crate) struct AuthInfo {
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
#[derive(Clone)]
pub(crate) struct NodeInfo {
/// Compute node connection params.
/// It's sad that we have to clone this, but this will improve
/// once we migrate to a bespoke connection logic.
pub(crate) config: compute::ConnCfg,
pub(crate) conn_info: compute::ConnectInfo,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
@@ -79,26 +71,14 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
auth: &compute::AuthInfo,
config: &ComputeConfig,
user_info: ComputeUserInfo,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
.connect(ctx, self.aux.clone(), config, user_info)
self.conn_info
.connect(ctx, self.aux.clone(), auth, config, user_info)
.await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {
self.config.reuse_password(other.config);
}
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
};
}
}
#[derive(Copy, Clone, Default)]


@@ -2,8 +2,8 @@ use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection};
use crate::auth::backend::ComputeUserInfo;
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
@@ -13,7 +13,6 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pqproto::StartupMessageParams;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
@@ -48,8 +47,6 @@ pub(crate) trait ConnectMechanism {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
@@ -58,24 +55,17 @@ pub(crate) trait ComputeConnectBackend {
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
fn get_keys(&self) -> &ComputeCredentialKeys;
}
pub(crate) struct TcpMechanism<'a> {
pub(crate) params_compat: bool,
/// KV-dictionary with PostgreSQL connection params.
pub(crate) params: &'a StartupMessageParams,
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
pub(crate) user_info: ComputeUserInfo,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism<'_> {
impl ConnectMechanism for TcpMechanism {
type Connection = PostgresConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
@@ -90,13 +80,12 @@ impl ConnectMechanism for TcpMechanism<'_> {
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(self.params, self.params_compat);
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(
node_info
.connect(ctx, &self.auth, config, self.user_info.clone())
.await,
)
}
}
@@ -114,12 +103,9 @@ where
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let mut node_info =
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.set_keys(user_info.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
@@ -155,14 +141,9 @@ where
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
debug!("compute node's state has likely changed; requesting a wake-up");
let old_node_info = invalidate_cache(node_info);
invalidate_cache(node_info);
// TODO: increment num_retries?
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);
node_info
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
};
// now that we have a new node, try connect to it repeatedly.


@@ -8,7 +8,7 @@ use std::io::{self, Cursor};
use bytes::{Buf, BufMut};
use itertools::Itertools;
use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
@@ -53,6 +53,28 @@ impl fmt::Debug for ProtocolVersion {
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, which is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
/// read the type from the stream using zerocopy.
///
/// not cancel safe.
@@ -66,32 +88,38 @@ macro_rules! read {
}};
}
/// Returns true if TLS is supported.
///
/// This is not cancel safe.
pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let payload = StartupHeader {
len: 8.into(),
version: NEGOTIATE_SSL_CODE,
};
stream.write_all(payload.as_bytes()).await?;
stream.flush().await?;
// we expect back either `S` or `N` as a single byte.
let mut res = *b"0";
stream.read_exact(&mut res).await?;
debug_assert!(
res == *b"S" || res == *b"N",
"unexpected SSL negotiation response: {}",
char::from(res[0]),
);
// S for SSL.
Ok(res == *b"S")
}
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
where
S: AsyncRead + Unpin,
{
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, which is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
let header = read!(stream => StartupHeader);
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
@@ -564,9 +592,8 @@ mod tests {
use tokio::io::{AsyncWriteExt, duplex};
use zerocopy::IntoBytes;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
use super::ProtocolVersion;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
#[tokio::test]
async fn reject_large_startup() {


@@ -358,21 +358,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
};
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
params_compat,
params: &params,
user_info: creds.info.clone(),
auth: auth_info,
locks: &config.connect_compute_locks,
},
&user_info,

View File

@@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::Postgres(err) => err.could_retry(),
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
_ => false,
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
}
}
}


@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -29,7 +29,6 @@ use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
@@ -72,13 +71,14 @@ struct ClientConfig<'a> {
hostname: &'a str,
}
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
type TlsConnect<S> = <ComputeConfig as MakeTlsConnect<S>>::TlsConnect;
impl ClientConfig<'_> {
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
Ok(crate::tls::postgres_rustls::make_tls_connect(
&self.config,
self.hostname,
)?)
}
}
@@ -497,8 +497,6 @@ impl ConnectMechanism for TestConnectMechanism {
x => panic!("expecting action {x:?}, connect is called instead"),
}
}
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
}
impl TestControlPlaneClient for TestConnectMechanism {
@@ -557,7 +555,12 @@ impl TestControlPlaneClient for TestConnectMechanism {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new("test".to_owned(), 5432),
conn_info: compute::ConnectInfo {
host: "test".into(),
port: 5432,
ssl_mode: SslMode::Disable,
host_addr: None,
},
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
@@ -581,7 +584,10 @@ fn helper_create_connect_info(
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
},
)
}


@@ -23,7 +23,6 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
@@ -305,12 +304,13 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = local_backend.node_info.clone();
let (key, jwk) = create_random_jwk();
let config = node_info
.config
let mut config = local_backend
.node_info
.conn_info
.to_postgres_client_config();
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
@@ -322,7 +322,7 @@ impl PoolingBackend {
);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(postgres_client::NoTls).await?;
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
drop(pause);
let pid = client.get_process_id();
@@ -336,7 +336,7 @@ impl PoolingBackend {
connection,
key,
conn_id,
node_info.aux.clone(),
local_backend.node_info.aux.clone(),
);
{
@@ -512,19 +512,16 @@ impl ConnectMechanism for TokioMechanism {
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let mut config = (*node_info.config).clone();
let mut config = node_info.conn_info.to_postgres_client_config();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
let mk_tls =
crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(mk_tls).await;
let res = config.connect(compute_config).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -548,8 +545,6 @@ impl ConnectMechanism for TokioMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
@@ -573,20 +568,20 @@ impl ConnectMechanism for HyperMechanism {
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.config.get_host_addr();
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.config.get_port();
let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
@@ -609,8 +604,6 @@ impl ConnectMechanism for HyperMechanism {
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(


@@ -23,12 +23,12 @@ use super::conn_pool_lib::{
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
GlobalConnPool,
};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {


@@ -2,10 +2,11 @@ use std::convert::TryFrom;
use std::sync::Arc;
use postgres_client::tls::MakeTlsConnect;
use rustls::ClientConfig;
use rustls::pki_types::ServerName;
use rustls::pki_types::{InvalidDnsNameError, ServerName};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::config::ComputeConfig;
mod private {
use std::future::Future;
use std::io;
@@ -123,36 +124,27 @@ mod private {
}
}
/// A `MakeTlsConnect` implementation using `rustls`.
///
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
pub config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
}
}
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
impl<S> MakeTlsConnect<S> for ComputeConfig
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = private::RustlsStream<S>;
type TlsConnect = private::RustlsConnect;
type Error = rustls::pki_types::InvalidDnsNameError;
type Error = InvalidDnsNameError;
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
})
})
fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
make_tls_connect(&self.tls, hostname)
}
}
pub fn make_tls_connect(
tls: &Arc<rustls::ClientConfig>,
hostname: &str,
) -> Result<private::RustlsConnect, InvalidDnsNameError> {
ServerName::try_from(hostname).map(|dns_name| {
private::RustlsConnect(private::RustlsConnectData {
hostname: dns_name.to_owned(),
connector: tls.clone().into(),
})
})
}