move authenticate outside of tokiomechanism

This commit is contained in:
Conrad Ludgate
2025-06-24 16:17:46 +01:00
parent 517a3d0d86
commit 16d9889a51
5 changed files with 107 additions and 100 deletions

View File

@@ -1,12 +1,13 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::config::{Host, SslMode};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::connect_tls::connect_tls;
@@ -46,13 +47,7 @@ where
{
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
let raw = connect_raw(stream, config).await?;
let socket_config = SocketConfig {
host_addr,
@@ -61,24 +56,46 @@ where
connect_timeout: config.connect_timeout,
};
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
Ok((client, connection))
Ok(raw.into_managed_conn(socket_config, config.ssl_mode))
}
impl<S, T> RawConnection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub fn into_managed_conn(
self,
socket_config: SocketConfig,
ssl_mode: SslMode,
) -> (Client, Connection<S, T>) {
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = self;
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
(client, connection)
}
}

View File

@@ -357,7 +357,7 @@ pub struct PostgresSettings {
pub struct ComputeConnection {
/// Socket connected to a compute node.
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub stream: MaybeRustlsStream,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
pub hostname: Host,

View File

@@ -11,8 +11,6 @@ use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
@@ -22,7 +20,6 @@ pub enum TlsError {
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
@@ -57,7 +54,6 @@ where
return Ok(MaybeTlsStream::Raw(stream));
}
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
let c = tls.make_tls_connect(host).map_err(std::io::Error::other)?;
Ok(MaybeTlsStream::Tls(c.connect(stream).boxed().await?))
}

View File

@@ -263,7 +263,12 @@ impl NeonControlPlaneClient {
None => SslMode::Disable,
};
let host = match body.server_name {
Some(host) => host.into(),
Some(host) => {
if rustls::pki_types::DnsName::try_from_str(&host).is_err() {
return Err(WakeComputeError::BadComputeAddress(host.into_boxed_str()));
}
host.into()
}
None => host.into(),
};

View File

@@ -7,6 +7,7 @@ use async_trait::async_trait;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::SocketConfig;
use postgres_client::config::SslMode;
use rand::rngs::OsRng;
use rustls::pki_types::{DnsName, ServerName};
@@ -23,6 +24,7 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute::{self, ComputeConnection};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
@@ -34,7 +36,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
@@ -184,20 +186,18 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let connection = crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
&TcpMechanism {
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await
.await?;
authenticate(ctx, &self.pool, &conn_info, keys.keys, connection, conn_id).await
}
// Wake up the destination if needed
@@ -373,6 +373,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connect to compute")]
ConnectError(#[from] compute::ConnectionError),
#[error("could not connect to postgres in compute")]
PostgresConnectionError(#[from] postgres_client::Error),
#[error("could not connect to local-proxy in compute")]
@@ -403,6 +405,7 @@ pub(crate) enum LocalProxyConnError {
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
@@ -419,6 +422,7 @@ impl ReportableError for HttpConnError {
impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectError(p) => p.to_string_client(),
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
@@ -437,6 +441,7 @@ impl UserFacingError for HttpConnError {
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::ConnectError(e) => e.could_retry(),
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ComputeCtl(_) => false,
@@ -492,65 +497,49 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
async fn authenticate(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: &ConnInfo,
keys: ComputeCredentialKeys,
compute: ComputeConnection,
conn_id: uuid::Uuid,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
// client config with stubbed connect info.
let mut config = postgres_client::Config::new(String::new(), 0);
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for TokioMechanism {
type Connection = Client<postgres_client::Client>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let mut config = node_info.conn_info.to_postgres_client_config();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
if let ComputeCredentialKeys::AuthKeys(auth_keys) = keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = config.authenticate(compute.stream).await?;
drop(pause);
let (client, connection) = connection.into_managed_conn(
SocketConfig {
host_addr: Some(compute.socket_addr.ip()),
host: postgres_client::config::Host::Tcp(compute.hostname.to_string()),
port: compute.socket_addr.port(),
connect_timeout: None,
},
compute.ssl_mode,
);
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
Ok(poll_client(
pool.clone(),
ctx,
conn_info.clone(),
client,
connection,
conn_id,
compute.aux,
))
}
struct HyperMechanism {