[proxy] remove TokioMechanism and HyperMechanism (#12672)

Another go at #12341. LKB-2497

We now need only one connect mechanism (plus one more for testing), which
saves us some code and complexity. We should be able to remove the final
connect mechanism when we create a separate worker task for
pglb->compute connections, either via QUIC streams or via in-memory
channels.

This also ensures that `connect_once` always returns a `ConnectionError`,
a type simple enough that we can probably define a serialisation for it
in pglb.
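(No serialisation is added in this commit, but to illustrate the point: since
`ConnectionError` is now a closed set of variants, a pglb wire form could be
as simple as a flattened mirror enum. Hypothetical sketch, using only the
variants visible in this diff:)

// Hypothetical sketch; not part of this commit.
#[derive(serde::Serialize, serde::Deserialize)]
pub enum WireConnectionError {
    Tls(String),
    WakeCompute(String),
    TooManyConnectionAttempts(String),
    // ...remaining variants would flatten the same way
}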

* I've abstracted `connect_to_compute` to always use `TcpMechanism` and the
`ProxyConfig`.
* I've abstracted `connect_to_compute_and_auth` to perform authentication,
managing any retries for stale computes.
* I had to introduce a separate `managed` function for taking ownership
of the compute connection into the Client/Connection pair (see the sketch
below).
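For a sense of how these fit together, here is a rough sketch distilled from
the serverless-pool changes below. The `open_connection` wrapper and its boxed
error type are mine for illustration; the calls themselves mirror the diff:

use crate::auth::Backend;
use crate::auth::backend::ComputeUserInfo;
use crate::compute::AuthInfo;
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::proxy::{connect_auth, connect_compute};

// Hypothetical wrapper, for illustration only.
async fn open_connection(
    ctx: &RequestContext,
    config: &ProxyConfig,
    backend: &Backend<'_, ComputeUserInfo>,
    auth_info: AuthInfo,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Connect and authenticate. TcpMechanism and the compute locks now come
    // from ProxyConfig internally; a stale cached compute is retried once.
    let node = connect_auth::connect_to_compute_and_auth(
        ctx,
        config,
        backend,
        auth_info,
        connect_compute::TlsNegotiation::Postgres,
    )
    .await?;

    // Take ownership of the established stream as a Client/Connection pair.
    let (_client, _connection) = postgres_client::connect::managed(
        node.stream,
        Some(node.socket_addr.ip()),
        postgres_client::config::Host::Tcp(node.hostname.to_string()),
        node.socket_addr.port(),
        node.ssl_mode,
        Some(config.connect_to_compute.timeout),
    )
    .await?;

    Ok(())
}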
Author: Conrad Ludgate
Date: 2025-07-24 13:37:04 +01:00
Committed by: GitHub
Parent: ab14521ea5
Commit: 8daebb6ed4

12 changed files with 315 additions and 468 deletions


@@ -7,7 +7,7 @@ use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::config::Host;
use crate::config::{Host, SslMode};
use crate::connect_raw::StartupStream;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
@@ -45,14 +45,36 @@ where
T: TlsConnect<TcpStream>,
{
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
let mut stream = config.tls_and_authenticate(socket, tls).await?;
let stream = config.tls_and_authenticate(socket, tls).await?;
managed(
stream,
host_addr,
host.clone(),
port,
config.ssl_mode,
config.connect_timeout,
)
.await
}
pub async fn managed<TlsStream>(
mut stream: StartupStream<TcpStream, TlsStream>,
host_addr: Option<IpAddr>,
host: Host,
port: u16,
ssl_mode: SslMode,
connect_timeout: Option<std::time::Duration>,
) -> Result<(Client, Connection<TcpStream, TlsStream>), Error>
where
TlsStream: AsyncRead + AsyncWrite + Unpin,
{
let (process_id, secret_key) = wait_until_ready(&mut stream).await?;
let socket_config = SocketConfig {
host_addr,
host: host.clone(),
host,
port,
connect_timeout: config.connect_timeout,
connect_timeout,
};
let (client_tx, conn_rx) = mpsc::unbounded_channel();
@@ -61,7 +83,7 @@ where
client_tx,
client_rx,
socket_config,
config.ssl_mode,
ssl_mode,
process_id,
secret_key,
);


@@ -48,7 +48,7 @@ mod cancel_token;
mod client;
mod codec;
pub mod config;
mod connect;
pub mod connect;
pub mod connect_raw;
mod connect_socket;
mod connect_tls;


@@ -25,6 +25,7 @@ use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::connect_compute::TlsNegotiation;
use crate::proxy::neon_option;
use crate::types::Host;
@@ -84,6 +85,14 @@ pub(crate) enum ConnectionError {
#[error("error acquiring resource permit: {0}")]
TooManyConnectionAttempts(#[from] ApiLockError),
#[cfg(test)]
#[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")]
TestError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
},
}
impl UserFacingError for ConnectionError {
@@ -94,6 +103,8 @@ impl UserFacingError for ConnectionError {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
#[cfg(test)]
ConnectionError::TestError { .. } => self.to_string(),
}
}
}
@@ -104,6 +115,8 @@ impl ReportableError for ConnectionError {
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
#[cfg(test)]
ConnectionError::TestError { kind, .. } => *kind,
}
}
}
@@ -256,6 +269,7 @@ impl ConnectInfo {
async fn connect_raw(
&self,
config: &ComputeConfig,
tls: TlsNegotiation,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
@@ -298,7 +312,7 @@ impl ConnectInfo {
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
@@ -329,9 +343,10 @@ impl ConnectInfo {
ctx: &RequestContext,
aux: &MetricsAuxInfo,
config: &ComputeConfig,
tls: TlsNegotiation,
) -> Result<ComputeConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
let (socket_addr, stream) = self.connect_raw(config, tls).await?;
drop(pause);
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));


@@ -7,6 +7,7 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::pqproto::request_tls;
use crate::proxy::connect_compute::TlsNegotiation;
use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
@@ -35,6 +36,7 @@ pub async fn connect_tls<S, T>(
mode: SslMode,
tls: &T,
host: &str,
negotiation: TlsNegotiation,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
@@ -49,12 +51,15 @@ where
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
return Ok(MaybeTlsStream::Raw(stream));
match negotiation {
// No TLS request needed
TlsNegotiation::Direct => {}
// TLS request successful
TlsNegotiation::Postgres if request_tls(&mut stream).await? => {}
// TLS request failed but is required
TlsNegotiation::Postgres if SslMode::Require == mode => return Err(TlsError::Required),
// TLS request failed but is not required
TlsNegotiation::Postgres => return Ok(MaybeTlsStream::Raw(stream)),
}
Ok(MaybeTlsStream::Tls(


@@ -16,8 +16,9 @@ use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting};
use crate::proxy::{
ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting,
};
use crate::util::run_until_cancelled;
pub async fn task_main(
@@ -215,14 +216,11 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
};
auth_info.set_startup_params(&params, true);
let mut node = connect_to_compute(
let mut node = connect_compute::connect_to_compute(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
},
config,
&node_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
connect_compute::TlsNegotiation::Postgres,
)
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;


@@ -17,7 +17,6 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
@@ -72,16 +71,6 @@ pub(crate) struct NodeInfo {
pub(crate) aux: MetricsAuxInfo,
}
impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
config: &ComputeConfig,
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
self.conn_info.connect(ctx, &self.aux, config).await
}
}
#[derive(Copy, Clone, Default, Debug)]
pub(crate) struct AccessBlockerFlags {
pub public_access_blocked: bool,


@@ -0,0 +1,82 @@
use thiserror::Error;
use crate::auth::Backend;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cache;
use crate::compute::{AuthInfo, ComputeConnection, ConnectionError, PostgresError};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::error::{ReportableError, UserFacingError};
use crate::proxy::connect_compute::{TlsNegotiation, connect_to_compute};
use crate::proxy::retry::ShouldRetryWakeCompute;
#[derive(Debug, Error)]
pub enum AuthError {
#[error(transparent)]
Auth(#[from] PostgresError),
#[error(transparent)]
Connect(#[from] ConnectionError),
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
match self {
AuthError::Auth(postgres_error) => postgres_error.to_string_client(),
AuthError::Connect(connection_error) => connection_error.to_string_client(),
}
}
}
impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
AuthError::Auth(postgres_error) => postgres_error.get_error_kind(),
AuthError::Connect(connection_error) => connection_error.get_error_kind(),
}
}
}
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute_and_auth(
ctx: &RequestContext,
config: &ProxyConfig,
user_info: &Backend<'_, ComputeUserInfo>,
auth_info: AuthInfo,
tls: TlsNegotiation,
) -> Result<ComputeConnection, AuthError> {
let mut attempt = 0;
// NOTE: This is messy, but should hopefully be detangled with PGLB.
// We wanted to separate the concerns of **connect** to compute (a PGLB operation),
// from **authenticate** to compute (a NeonKeeper operation).
//
// This unfortunately removed retry handling for one error case where
// the compute was cached, and we connected, but the compute cache was actually stale
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
loop {
attempt += 1;
let mut node = connect_to_compute(ctx, config, user_info, tls).await?;
let res = auth_info.authenticate(ctx, &mut node).await;
match res {
Ok(()) => return Ok(node),
Err(e) => {
if attempt < 2
&& let Backend::ControlPlane(cplane, user_info) = user_info
&& let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane
&& e.should_retry_wake_compute()
{
tracing::warn!(error = ?e, "retrying wake compute");
let key = user_info.endpoint_cache_key();
cplane_proxy_v1.caches.node_info.invalidate(&key);
continue;
}
return Err(e)?;
}
}
}
}


@@ -1,18 +1,15 @@
use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::types::Host;
@@ -35,29 +32,32 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node
node_info.invalidate()
}
#[async_trait]
pub(crate) trait ConnectMechanism {
type Connection;
type ConnectError: ReportableError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
) -> Result<Self::Connection, compute::ConnectionError>;
}
pub(crate) struct TcpMechanism {
struct TcpMechanism<'a> {
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
locks: &'a ApiLocks<Host>,
tls: TlsNegotiation,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism {
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum TlsNegotiation {
/// TLS is assumed
Direct,
/// We must ask for TLS using the postgres SSLRequest message
Postgres,
}
impl ConnectMechanism for TcpMechanism<'_> {
type Connection = ComputeConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
@@ -68,25 +68,47 @@ impl ConnectMechanism for TcpMechanism {
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<ComputeConnection, Self::Error> {
) -> Result<ComputeConnection, compute::ConnectionError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(node_info.connect(ctx, config).await)
permit.release_result(
node_info
.conn_info
.connect(ctx, &node_info.aux, config, self.tls)
.await,
)
}
}
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
pub(crate) async fn connect_to_compute<B: WakeComputeBackend>(
ctx: &RequestContext,
config: &ProxyConfig,
user_info: &B,
tls: TlsNegotiation,
) -> Result<ComputeConnection, compute::ConnectionError> {
connect_to_compute_inner(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
tls,
},
user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await
}
/// Try to connect to the compute node, retrying if necessary.
pub(crate) async fn connect_to_compute_inner<M: ConnectMechanism, B: WakeComputeBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
) -> Result<M::Connection, compute::ConnectionError> {
let mut num_retries = 0;
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
@@ -120,7 +142,7 @@ where
},
num_retries.into(),
);
return Err(err.into());
return Err(err);
}
node_info
} else {
@@ -161,7 +183,7 @@ where
},
num_retries.into(),
);
return Err(e.into());
return Err(e);
}
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);


@@ -1,6 +1,7 @@
#[cfg(test)]
mod tests;
pub(crate) mod connect_auth;
pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
@@ -23,17 +24,13 @@ use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tracing::Instrument;
use crate::cache::Cache;
use crate::cancellation::{CancelClosure, CancellationHandler};
use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::{ClientMode, ClientRequestError};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::retry::ShouldRetryWakeCompute;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
@@ -95,61 +92,24 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(params, params_compat);
let mut node;
let mut attempt = 0;
let connect = TcpMechanism {
locks: &config.connect_compute_locks,
};
let backend = auth::Backend::ControlPlane(cplane, creds.info);
// NOTE: This is messy, but should hopefully be detangled with PGLB.
// We wanted to separate the concerns of **connect** to compute (a PGLB operation),
// from **authenticate** to compute (a NeonKeeper operation).
//
// This unfortunately removed retry handling for one error case where
// the compute was cached, and we connected, but the compute cache was actually stale
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
loop {
attempt += 1;
// TODO: callback to pglb
let res = connect_auth::connect_to_compute_and_auth(
ctx,
config,
&backend,
auth_info,
connect_compute::TlsNegotiation::Postgres,
)
.await;
// TODO: callback to pglb
let res = connect_to_compute(
ctx,
&connect,
&backend,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
let mut node = match res {
Ok(node) => node,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
match res {
Ok(n) => node = n,
Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?,
}
let auth::Backend::ControlPlane(cplane, user_info) = &backend else {
unreachable!("ensured above");
};
let res = auth_info.authenticate(ctx, &mut node).await;
match res {
Ok(()) => {
send_client_greeting(ctx, &config.greetings, client);
break;
}
Err(e) if attempt < 2 && e.should_retry_wake_compute() => {
tracing::warn!(error = ?e, "retrying wake compute");
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane {
let key = user_info.endpoint_cache_key();
cplane_proxy_v1.caches.node_info.invalidate(&key);
}
}
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
}
}
send_client_greeting(ctx, &config.greetings, client);
let auth::Backend::ControlPlane(_, user_info) = backend else {
unreachable!("ensured above");


@@ -31,18 +31,6 @@ impl CouldRetry for io::Error {
}
}
impl CouldRetry for postgres_client::error::DbError {
fn could_retry(&self) -> bool {
use postgres_client::error::SqlState;
matches!(
self.code(),
&SqlState::CONNECTION_FAILURE
| &SqlState::CONNECTION_EXCEPTION
| &SqlState::CONNECTION_DOES_NOT_EXIST
| &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
)
}
}
impl ShouldRetryWakeCompute for postgres_client::error::DbError {
fn should_retry_wake_compute(&self) -> bool {
use postgres_client::error::SqlState;
@@ -73,17 +61,6 @@ impl ShouldRetryWakeCompute for postgres_client::error::DbError {
}
}
impl CouldRetry for postgres_client::Error {
fn could_retry(&self) -> bool {
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
io::Error::could_retry(io_err)
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
postgres_client::error::DbError::could_retry(db_err)
} else {
false
}
}
}
impl ShouldRetryWakeCompute for postgres_client::Error {
fn should_retry_wake_compute(&self) -> bool {
if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
@@ -102,6 +79,8 @@ impl CouldRetry for compute::ConnectionError {
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { retryable, .. } => *retryable,
}
}
}
@@ -110,6 +89,8 @@ impl ShouldRetryWakeCompute for compute::ConnectionError {
match self {
// the cache entry was not checked for validity
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { wakeable, .. } => *wakeable,
_ => true,
}
}

View File

@@ -24,13 +24,13 @@ use crate::context::RequestContext;
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::{ErrorKind, ReportableError};
use crate::error::ErrorKind;
use crate::pglb::ERR_INSECURE_CONNECTION;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after};
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute_inner};
use crate::proxy::retry::retry_after;
use crate::stream::{PqStream, Stream};
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
@@ -430,71 +430,36 @@ impl TestConnectMechanism {
#[derive(Debug)]
struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
}
impl ReportableError for TestConnectError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
self.kind
}
}
impl std::fmt::Display for TestConnectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for TestConnectError {}
impl CouldRetry for TestConnectError {
fn could_retry(&self) -> bool {
self.retryable
}
}
impl ShouldRetryWakeCompute for TestConnectError {
fn should_retry_wake_compute(&self) -> bool {
self.wakeable
}
}
#[async_trait]
impl ConnectMechanism for TestConnectMechanism {
type Connection = TestConnection;
type ConnectError = TestConnectError;
type Error = anyhow::Error;
async fn connect_once(
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
) -> Result<Self::Connection, compute::ConnectionError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError {
ConnectAction::Retry => Err(compute::ConnectionError::TestError {
retryable: true,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::RetryNoWake => Err(TestConnectError {
ConnectAction::RetryNoWake => Err(compute::ConnectionError::TestError {
retryable: true,
wakeable: false,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(TestConnectError {
ConnectAction::Fail => Err(compute::ConnectionError::TestError {
retryable: false,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::FailNoWake => Err(TestConnectError {
ConnectAction::FailNoWake => Err(compute::ConnectionError::TestError {
retryable: false,
wakeable: false,
kind: ErrorKind::Compute,
@@ -620,7 +585,7 @@ async fn connect_to_compute_success() {
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -634,7 +599,7 @@ async fn connect_to_compute_retry() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -649,7 +614,7 @@ async fn connect_to_compute_non_retry_1() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -664,7 +629,7 @@ async fn connect_to_compute_non_retry_2() {
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -686,7 +651,7 @@ async fn connect_to_compute_non_retry_3() {
backoff_factor: 2.0,
};
let config = config();
connect_to_compute(
connect_to_compute_inner(
&ctx,
&mechanism,
&user_info,
@@ -707,7 +672,7 @@ async fn wake_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -722,7 +687,7 @@ async fn wake_non_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -741,7 +706,7 @@ async fn fail_but_wake_invalidates_cache() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
@@ -762,7 +727,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
@@ -783,7 +748,7 @@ async fn retry_but_wake_invalidates_cache() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
@@ -806,7 +771,7 @@ async fn retry_no_wake_skips_invalidation() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();
@@ -829,7 +794,7 @@ async fn retry_no_wake_error_fast() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();
@@ -852,7 +817,7 @@ async fn retry_cold_wake_skips_invalidation() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();


@@ -1,17 +1,11 @@
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::config::SslMode;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use rand_core::OsRng;
use rustls::pki_types::{DnsName, ServerName};
use tokio::net::{TcpStream, lookup_host};
use tokio_rustls::TlsConnector;
use tracing::field::display;
use tracing::{debug, info};
@@ -21,23 +15,22 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::pqproto::StartupMessageParams;
use crate::proxy::{connect_auth, connect_compute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool:
@@ -186,20 +179,42 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let mut params = StartupMessageParams::default();
params.insert("database", &conn_info.dbname);
params.insert("user", &conn_info.user_info.user);
let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys);
auth_info.set_startup_params(&params, true);
let node = connect_auth::connect_to_compute_and_auth(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
self.config,
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
auth_info,
connect_compute::TlsNegotiation::Postgres,
)
.await
.await?;
let (client, connection) = postgres_client::connect::managed(
node.stream,
Some(node.socket_addr.ip()),
postgres_client::config::Host::Tcp(node.hostname.to_string()),
node.socket_addr.port(),
node.ssl_mode,
Some(self.config.connect_to_compute.timeout),
)
.await?;
Ok(poll_client(
self.pool.clone(),
ctx,
conn_info,
client,
connection,
conn_id,
node.aux,
))
}
// Wake up the destination if needed
@@ -228,19 +243,38 @@ impl PoolingBackend {
)),
options: conn_info.user_info.options.clone(),
});
crate::proxy::connect_compute::connect_to_compute(
let node = connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
self.config,
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
connect_compute::TlsNegotiation::Direct,
)
.await
.await?;
let stream = match node.stream.into_framed().into_inner() {
MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW,
MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW,
};
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await
.map_err(LocalProxyConnError::H2)?;
Ok(poll_http2_client(
self.http_conn_pool.clone(),
ctx,
&conn_info,
client,
connection,
conn_id,
node.aux.clone(),
))
}
/// Connect to postgres over localhost.
@@ -380,6 +414,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connect to compute")]
ConnectError(#[from] compute::ConnectionError),
#[error("could not connect to postgres in compute")]
PostgresConnectionError(#[from] postgres_client::Error),
#[error("could not connect to local-proxy in compute")]
@@ -399,10 +435,19 @@ pub(crate) enum HttpConnError {
TooManyConnectionAttempts(#[from] ApiLockError),
}
impl From<connect_auth::AuthError> for HttpConnError {
fn from(value: connect_auth::AuthError) -> Self {
match value {
connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => {
Self::PostgresConnectionError(error)
}
connect_auth::AuthError::Connect(error) => Self::ConnectError(error),
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper::Error),
}
@@ -410,6 +455,7 @@ pub(crate) enum LocalProxyConnError {
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => {
if p.as_db_error().is_some() {
@@ -434,6 +480,7 @@ impl ReportableError for HttpConnError {
impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectError(p) => p.to_string_client(),
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
@@ -449,36 +496,9 @@ impl UserFacingError for HttpConnError {
}
}
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ComputeCtl(_) => false,
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
HttpConnError::WakeCompute(_) => false,
HttpConnError::TooManyConnectionAttempts(_) => false,
}
}
}
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
}
}
}
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
@@ -489,215 +509,3 @@ impl UserFacingError for LocalProxyConnError {
"Could not establish HTTP connection to the database".to_string()
}
}
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for TokioMechanism {
type Connection = Client<postgres_client::Client>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let mut config = node_info.conn_info.to_postgres_client_config();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<LocalProxyClient>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
}
async fn connect_http2(
host_addr: Option<IpAddr>,
host: &str,
port: u16,
timeout: Duration,
tls: Option<&Arc<rustls::ClientConfig>>,
) -> Result<
(
http_conn_pool::LocalProxyClient,
http_conn_pool::LocalProxyConnection,
),
LocalProxyConnError,
> {
let addrs = match host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?
.collect(),
};
let mut last_err = None;
let mut addrs = addrs.into_iter();
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
}
};
let stream = if let Some(tls) = tls {
let host = DnsName::try_from(host)
.map_err(io::Error::other)
.map_err(LocalProxyConnError::Io)?
.to_owned();
let stream = TlsConnector::from(tls.clone())
.connect(ServerName::DnsName(host), stream)
.await
.map_err(LocalProxyConnError::Io)?;
Box::pin(stream) as AsyncRW
} else {
Box::pin(stream) as AsyncRW
};
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
Ok((client, connection))
}