From 8daebb6ed4022e4c984a3ab166850de87d6563f8 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 24 Jul 2025 13:37:04 +0100 Subject: [PATCH] [proxy] remove TokioMechanism and HyperMechanism (#12672) Another go at #12341. LKB-2497 We now only need 1 connect mechanism (and 1 more for testing) which saves us some code and complexity. We should be able to remove the final connect mechanism when we create a separate worker task for pglb->compute connections - either via QUIC streams or via in-memory channels. This also now ensures that connect_once always returns a ConnectionError type - something simple enough we can probably define a serialisation for in pglb. * I've abstracted connect_to_compute to always use TcpMechanism and the ProxyConfig. * I've abstracted connect_to_compute_and_auth to perform authentication, managing any retries for stale computes * I had to introduce a separate `managed` function for taking ownership of the compute connection into the Client/Connection pair --- libs/proxy/tokio-postgres2/src/connect.rs | 32 +- libs/proxy/tokio-postgres2/src/lib.rs | 2 +- proxy/src/compute/mod.rs | 19 +- proxy/src/compute/tls.rs | 17 +- proxy/src/console_redirect_proxy.rs | 14 +- proxy/src/control_plane/mod.rs | 11 - proxy/src/proxy/connect_auth.rs | 82 +++++ proxy/src/proxy/connect_compute.rs | 72 +++-- proxy/src/proxy/mod.rs | 70 +---- proxy/src/proxy/retry.rs | 27 +- proxy/src/proxy/tests/mod.rs | 77 ++--- proxy/src/serverless/backend.rs | 360 +++++----------------- 12 files changed, 315 insertions(+), 468 deletions(-) create mode 100644 proxy/src/proxy/connect_auth.rs diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 41d95c5f84..ca6f69f049 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -7,7 +7,7 @@ use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::client::SocketConfig; -use crate::config::Host; +use crate::config::{Host, SslMode}; use crate::connect_raw::StartupStream; use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; @@ -45,14 +45,36 @@ where T: TlsConnect, { let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?; - let mut stream = config.tls_and_authenticate(socket, tls).await?; + let stream = config.tls_and_authenticate(socket, tls).await?; + managed( + stream, + host_addr, + host.clone(), + port, + config.ssl_mode, + config.connect_timeout, + ) + .await +} + +pub async fn managed( + mut stream: StartupStream, + host_addr: Option, + host: Host, + port: u16, + ssl_mode: SslMode, + connect_timeout: Option, +) -> Result<(Client, Connection), Error> +where + TlsStream: AsyncRead + AsyncWrite + Unpin, +{ let (process_id, secret_key) = wait_until_ready(&mut stream).await?; let socket_config = SocketConfig { host_addr, - host: host.clone(), + host, port, - connect_timeout: config.connect_timeout, + connect_timeout, }; let (client_tx, conn_rx) = mpsc::unbounded_channel(); @@ -61,7 +83,7 @@ where client_tx, client_rx, socket_config, - config.ssl_mode, + ssl_mode, process_id, secret_key, ); diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index a858ddca39..da2665095c 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -48,7 +48,7 @@ mod cancel_token; mod client; mod codec; pub mod config; -mod connect; +pub mod connect; pub mod connect_raw; mod connect_socket; mod connect_tls; diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 1e3631363e..ca784423ee 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -25,6 +25,7 @@ use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; +use crate::proxy::connect_compute::TlsNegotiation; use crate::proxy::neon_option; use crate::types::Host; @@ -84,6 +85,14 @@ pub(crate) enum ConnectionError { #[error("error acquiring resource permit: {0}")] TooManyConnectionAttempts(#[from] ApiLockError), + + #[cfg(test)] + #[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")] + TestError { + retryable: bool, + wakeable: bool, + kind: crate::error::ErrorKind, + }, } impl UserFacingError for ConnectionError { @@ -94,6 +103,8 @@ impl UserFacingError for ConnectionError { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(), + #[cfg(test)] + ConnectionError::TestError { .. } => self.to_string(), } } } @@ -104,6 +115,8 @@ impl ReportableError for ConnectionError { ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), + #[cfg(test)] + ConnectionError::TestError { kind, .. } => *kind, } } } @@ -256,6 +269,7 @@ impl ConnectInfo { async fn connect_raw( &self, config: &ComputeConfig, + tls: TlsNegotiation, ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { let timeout = config.timeout; @@ -298,7 +312,7 @@ impl ConnectInfo { match connect_once(&*addrs).await { Ok((sockaddr, stream)) => Ok(( sockaddr, - tls::connect_tls(stream, self.ssl_mode, config, host).await?, + tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?, )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); @@ -329,9 +343,10 @@ impl ConnectInfo { ctx: &RequestContext, aux: &MetricsAuxInfo, config: &ComputeConfig, + tls: TlsNegotiation, ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream) = self.connect_raw(config).await?; + let (socket_addr, stream) = self.connect_raw(config, tls).await?; drop(pause); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs index 000d75fca5..cc1c0d1658 100644 --- a/proxy/src/compute/tls.rs +++ b/proxy/src/compute/tls.rs @@ -7,6 +7,7 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use crate::pqproto::request_tls; +use crate::proxy::connect_compute::TlsNegotiation; use crate::proxy::retry::CouldRetry; #[derive(Debug, Error)] @@ -35,6 +36,7 @@ pub async fn connect_tls( mode: SslMode, tls: &T, host: &str, + negotiation: TlsNegotiation, ) -> Result, TlsError> where S: AsyncRead + AsyncWrite + Unpin + Send, @@ -49,12 +51,15 @@ where SslMode::Prefer | SslMode::Require => {} } - if !request_tls(&mut stream).await? { - if SslMode::Require == mode { - return Err(TlsError::Required); - } - - return Ok(MaybeTlsStream::Raw(stream)); + match negotiation { + // No TLS request needed + TlsNegotiation::Direct => {} + // TLS request successful + TlsNegotiation::Postgres if request_tls(&mut stream).await? => {} + // TLS request failed but is required + TlsNegotiation::Postgres if SslMode::Require == mode => return Err(TlsError::Required), + // TLS request failed but is not required + TlsNegotiation::Postgres => return Ok(MaybeTlsStream::Raw(stream)), } Ok(MaybeTlsStream::Tls( diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 639cd123e1..f947abebc0 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -16,8 +16,9 @@ use crate::pglb::ClientRequestError; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting}; +use crate::proxy::{ + ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting, +}; use crate::util::run_until_cancelled; pub async fn task_main( @@ -215,14 +216,11 @@ pub(crate) async fn handle_client( }; auth_info.set_startup_params(¶ms, true); - let mut node = connect_to_compute( + let mut node = connect_compute::connect_to_compute( ctx, - &TcpMechanism { - locks: &config.connect_compute_locks, - }, + config, &node_info, - config.wake_compute_retry_config, - &config.connect_to_compute, + connect_compute::TlsNegotiation::Postgres, ) .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 9bbd3f4fb7..5bfa24c92d 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -17,7 +17,6 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; use crate::cache::{Cached, TimedLru}; -use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt}; @@ -72,16 +71,6 @@ pub(crate) struct NodeInfo { pub(crate) aux: MetricsAuxInfo, } -impl NodeInfo { - pub(crate) async fn connect( - &self, - ctx: &RequestContext, - config: &ComputeConfig, - ) -> Result { - self.conn_info.connect(ctx, &self.aux, config).await - } -} - #[derive(Copy, Clone, Default, Debug)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, diff --git a/proxy/src/proxy/connect_auth.rs b/proxy/src/proxy/connect_auth.rs new file mode 100644 index 0000000000..5a1d1ae314 --- /dev/null +++ b/proxy/src/proxy/connect_auth.rs @@ -0,0 +1,82 @@ +use thiserror::Error; + +use crate::auth::Backend; +use crate::auth::backend::ComputeUserInfo; +use crate::cache::Cache; +use crate::compute::{AuthInfo, ComputeConnection, ConnectionError, PostgresError}; +use crate::config::ProxyConfig; +use crate::context::RequestContext; +use crate::control_plane::client::ControlPlaneClient; +use crate::error::{ReportableError, UserFacingError}; +use crate::proxy::connect_compute::{TlsNegotiation, connect_to_compute}; +use crate::proxy::retry::ShouldRetryWakeCompute; + +#[derive(Debug, Error)] +pub enum AuthError { + #[error(transparent)] + Auth(#[from] PostgresError), + #[error(transparent)] + Connect(#[from] ConnectionError), +} + +impl UserFacingError for AuthError { + fn to_string_client(&self) -> String { + match self { + AuthError::Auth(postgres_error) => postgres_error.to_string_client(), + AuthError::Connect(connection_error) => connection_error.to_string_client(), + } + } +} + +impl ReportableError for AuthError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + AuthError::Auth(postgres_error) => postgres_error.get_error_kind(), + AuthError::Connect(connection_error) => connection_error.get_error_kind(), + } + } +} + +/// Try to connect to the compute node, retrying if necessary. +#[tracing::instrument(skip_all)] +pub(crate) async fn connect_to_compute_and_auth( + ctx: &RequestContext, + config: &ProxyConfig, + user_info: &Backend<'_, ComputeUserInfo>, + auth_info: AuthInfo, + tls: TlsNegotiation, +) -> Result { + let mut attempt = 0; + + // NOTE: This is messy, but should hopefully be detangled with PGLB. + // We wanted to separate the concerns of **connect** to compute (a PGLB operation), + // from **authenticate** to compute (a NeonKeeper operation). + // + // This unfortunately removed retry handling for one error case where + // the compute was cached, and we connected, but the compute cache was actually stale + // and is associated with the wrong endpoint. We detect this when the **authentication** fails. + // As such, we retry once here if the `authenticate` function fails and the error is valid to retry. + loop { + attempt += 1; + let mut node = connect_to_compute(ctx, config, user_info, tls).await?; + + let res = auth_info.authenticate(ctx, &mut node).await; + match res { + Ok(()) => return Ok(node), + Err(e) => { + if attempt < 2 + && let Backend::ControlPlane(cplane, user_info) = user_info + && let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane + && e.should_retry_wake_compute() + { + tracing::warn!(error = ?e, "retrying wake compute"); + let key = user_info.endpoint_cache_key(); + cplane_proxy_v1.caches.node_info.invalidate(&key); + continue; + } + + return Err(e)?; + } + } + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index ce9774e3eb..1a4e5f77d2 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,18 +1,15 @@ -use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection}; -use crate::config::{ComputeConfig, RetryConfig}; +use crate::config::{ComputeConfig, ProxyConfig, RetryConfig}; use crate::context::RequestContext; -use crate::control_plane::errors::WakeComputeError; use crate::control_plane::locks::ApiLocks; use crate::control_plane::{self, NodeInfo}; -use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; +use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute}; use crate::types::Host; @@ -35,29 +32,32 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node node_info.invalidate() } -#[async_trait] pub(crate) trait ConnectMechanism { type Connection; - type ConnectError: ReportableError; - type Error: From; async fn connect_once( &self, ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, - ) -> Result; + ) -> Result; } -pub(crate) struct TcpMechanism { +struct TcpMechanism<'a> { /// connect_to_compute concurrency lock - pub(crate) locks: &'static ApiLocks, + locks: &'a ApiLocks, + tls: TlsNegotiation, } -#[async_trait] -impl ConnectMechanism for TcpMechanism { +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum TlsNegotiation { + /// TLS is assumed + Direct, + /// We must ask for TLS using the postgres SSLRequest message + Postgres, +} + +impl ConnectMechanism for TcpMechanism<'_> { type Connection = ComputeConnection; - type ConnectError = compute::ConnectionError; - type Error = compute::ConnectionError; #[tracing::instrument(skip_all, fields( pid = tracing::field::Empty, @@ -68,25 +68,47 @@ impl ConnectMechanism for TcpMechanism { ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, - ) -> Result { + ) -> Result { let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - permit.release_result(node_info.connect(ctx, config).await) + + permit.release_result( + node_info + .conn_info + .connect(ctx, &node_info.aux, config, self.tls) + .await, + ) } } /// Try to connect to the compute node, retrying if necessary. #[tracing::instrument(skip_all)] -pub(crate) async fn connect_to_compute( +pub(crate) async fn connect_to_compute( + ctx: &RequestContext, + config: &ProxyConfig, + user_info: &B, + tls: TlsNegotiation, +) -> Result { + connect_to_compute_inner( + ctx, + &TcpMechanism { + locks: &config.connect_compute_locks, + tls, + }, + user_info, + config.wake_compute_retry_config, + &config.connect_to_compute, + ) + .await +} + +/// Try to connect to the compute node, retrying if necessary. +pub(crate) async fn connect_to_compute_inner( ctx: &RequestContext, mechanism: &M, user_info: &B, wake_compute_retry_config: RetryConfig, compute: &ComputeConfig, -) -> Result -where - M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug, - M::Error: From, -{ +) -> Result { let mut num_retries = 0; let node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; @@ -120,7 +142,7 @@ where }, num_retries.into(), ); - return Err(err.into()); + return Err(err); } node_info } else { @@ -161,7 +183,7 @@ where }, num_retries.into(), ); - return Err(e.into()); + return Err(e); } warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 053726505d..b42457cd95 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod tests; +pub(crate) mod connect_auth; pub(crate) mod connect_compute; pub(crate) mod retry; pub(crate) mod wake_compute; @@ -23,17 +24,13 @@ use tokio::net::TcpStream; use tokio::sync::oneshot; use tracing::Instrument; -use crate::cache::Cache; use crate::cancellation::{CancelClosure, CancellationHandler}; use crate::compute::{ComputeConnection, PostgresError, RustlsStream}; use crate::config::ProxyConfig; use crate::context::RequestContext; -use crate::control_plane::client::ControlPlaneClient; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use crate::pglb::{ClientMode, ClientRequestError}; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::retry::ShouldRetryWakeCompute; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; @@ -95,61 +92,24 @@ pub(crate) async fn handle_client( let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); auth_info.set_startup_params(params, params_compat); - let mut node; - let mut attempt = 0; - let connect = TcpMechanism { - locks: &config.connect_compute_locks, - }; let backend = auth::Backend::ControlPlane(cplane, creds.info); - // NOTE: This is messy, but should hopefully be detangled with PGLB. - // We wanted to separate the concerns of **connect** to compute (a PGLB operation), - // from **authenticate** to compute (a NeonKeeper operation). - // - // This unfortunately removed retry handling for one error case where - // the compute was cached, and we connected, but the compute cache was actually stale - // and is associated with the wrong endpoint. We detect this when the **authentication** fails. - // As such, we retry once here if the `authenticate` function fails and the error is valid to retry. - loop { - attempt += 1; + // TODO: callback to pglb + let res = connect_auth::connect_to_compute_and_auth( + ctx, + config, + &backend, + auth_info, + connect_compute::TlsNegotiation::Postgres, + ) + .await; - // TODO: callback to pglb - let res = connect_to_compute( - ctx, - &connect, - &backend, - config.wake_compute_retry_config, - &config.connect_to_compute, - ) - .await; + let mut node = match res { + Ok(node) => node, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, + }; - match res { - Ok(n) => node = n, - Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?, - } - - let auth::Backend::ControlPlane(cplane, user_info) = &backend else { - unreachable!("ensured above"); - }; - - let res = auth_info.authenticate(ctx, &mut node).await; - match res { - Ok(()) => { - send_client_greeting(ctx, &config.greetings, client); - break; - } - Err(e) if attempt < 2 && e.should_retry_wake_compute() => { - tracing::warn!(error = ?e, "retrying wake compute"); - - #[allow(irrefutable_let_patterns)] - if let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane { - let key = user_info.endpoint_cache_key(); - cplane_proxy_v1.caches.node_info.invalidate(&key); - } - } - Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, - } - } + send_client_greeting(ctx, &config.greetings, client); let auth::Backend::ControlPlane(_, user_info) = backend else { unreachable!("ensured above"); diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index b06c3be72c..876d252517 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -31,18 +31,6 @@ impl CouldRetry for io::Error { } } -impl CouldRetry for postgres_client::error::DbError { - fn could_retry(&self) -> bool { - use postgres_client::error::SqlState; - matches!( - self.code(), - &SqlState::CONNECTION_FAILURE - | &SqlState::CONNECTION_EXCEPTION - | &SqlState::CONNECTION_DOES_NOT_EXIST - | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION, - ) - } -} impl ShouldRetryWakeCompute for postgres_client::error::DbError { fn should_retry_wake_compute(&self) -> bool { use postgres_client::error::SqlState; @@ -73,17 +61,6 @@ impl ShouldRetryWakeCompute for postgres_client::error::DbError { } } -impl CouldRetry for postgres_client::Error { - fn could_retry(&self) -> bool { - if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) { - io::Error::could_retry(io_err) - } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { - postgres_client::error::DbError::could_retry(db_err) - } else { - false - } - } -} impl ShouldRetryWakeCompute for postgres_client::Error { fn should_retry_wake_compute(&self) -> bool { if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { @@ -102,6 +79,8 @@ impl CouldRetry for compute::ConnectionError { compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), compute::ConnectionError::TooManyConnectionAttempts(_) => false, + #[cfg(test)] + compute::ConnectionError::TestError { retryable, .. } => *retryable, } } } @@ -110,6 +89,8 @@ impl ShouldRetryWakeCompute for compute::ConnectionError { match self { // the cache entry was not checked for validity compute::ConnectionError::TooManyConnectionAttempts(_) => false, + #[cfg(test)] + compute::ConnectionError::TestError { wakeable, .. } => *wakeable, _ => true, } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index f8bff450e1..d1084628b1 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -24,13 +24,13 @@ use crate::context::RequestContext; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; -use crate::error::{ErrorKind, ReportableError}; +use crate::error::ErrorKind; use crate::pglb::ERR_INSECURE_CONNECTION; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; -use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute}; -use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after}; +use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute_inner}; +use crate::proxy::retry::retry_after; use crate::stream::{PqStream, Stream}; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::server_config::CertResolver; @@ -430,71 +430,36 @@ impl TestConnectMechanism { #[derive(Debug)] struct TestConnection; -#[derive(Debug)] -struct TestConnectError { - retryable: bool, - wakeable: bool, - kind: crate::error::ErrorKind, -} - -impl ReportableError for TestConnectError { - fn get_error_kind(&self) -> crate::error::ErrorKind { - self.kind - } -} - -impl std::fmt::Display for TestConnectError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl std::error::Error for TestConnectError {} - -impl CouldRetry for TestConnectError { - fn could_retry(&self) -> bool { - self.retryable - } -} -impl ShouldRetryWakeCompute for TestConnectError { - fn should_retry_wake_compute(&self) -> bool { - self.wakeable - } -} - -#[async_trait] impl ConnectMechanism for TestConnectMechanism { type Connection = TestConnection; - type ConnectError = TestConnectError; - type Error = anyhow::Error; async fn connect_once( &self, _ctx: &RequestContext, _node_info: &control_plane::CachedNodeInfo, _config: &ComputeConfig, - ) -> Result { + ) -> Result { let mut counter = self.counter.lock().unwrap(); let action = self.sequence[*counter]; *counter += 1; match action { ConnectAction::Connect => Ok(TestConnection), - ConnectAction::Retry => Err(TestConnectError { + ConnectAction::Retry => Err(compute::ConnectionError::TestError { retryable: true, wakeable: true, kind: ErrorKind::Compute, }), - ConnectAction::RetryNoWake => Err(TestConnectError { + ConnectAction::RetryNoWake => Err(compute::ConnectionError::TestError { retryable: true, wakeable: false, kind: ErrorKind::Compute, }), - ConnectAction::Fail => Err(TestConnectError { + ConnectAction::Fail => Err(compute::ConnectionError::TestError { retryable: false, wakeable: true, kind: ErrorKind::Compute, }), - ConnectAction::FailNoWake => Err(TestConnectError { + ConnectAction::FailNoWake => Err(compute::ConnectionError::TestError { retryable: false, wakeable: false, kind: ErrorKind::Compute, @@ -620,7 +585,7 @@ async fn connect_to_compute_success() { let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -634,7 +599,7 @@ async fn connect_to_compute_retry() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -649,7 +614,7 @@ async fn connect_to_compute_non_retry_1() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -664,7 +629,7 @@ async fn connect_to_compute_non_retry_2() { let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -686,7 +651,7 @@ async fn connect_to_compute_non_retry_3() { backoff_factor: 2.0, }; let config = config(); - connect_to_compute( + connect_to_compute_inner( &ctx, &mechanism, &user_info, @@ -707,7 +672,7 @@ async fn wake_retry() { let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -722,7 +687,7 @@ async fn wake_non_retry() { let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -741,7 +706,7 @@ async fn fail_but_wake_invalidates_cache() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg) .await .unwrap(); @@ -762,7 +727,7 @@ async fn fail_no_wake_skips_cache_invalidation() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg) .await .unwrap(); @@ -783,7 +748,7 @@ async fn retry_but_wake_invalidates_cache() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap(); mechanism.verify(); @@ -806,7 +771,7 @@ async fn retry_no_wake_skips_invalidation() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap_err(); mechanism.verify(); @@ -829,7 +794,7 @@ async fn retry_no_wake_error_fast() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap_err(); mechanism.verify(); @@ -852,7 +817,7 @@ async fn retry_cold_wake_skips_invalidation() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap(); mechanism.verify(); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 31df7eb9f1..0987b6927f 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,17 +1,11 @@ -use std::io; -use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; -use async_trait::async_trait; use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; -use postgres_client::config::SslMode; +use postgres_client::maybe_tls_stream::MaybeTlsStream; use rand_core::OsRng; -use rustls::pki_types::{DnsName, ServerName}; -use tokio::net::{TcpStream, lookup_host}; -use tokio_rustls::TlsConnector; use tracing::field::display; use tracing::{debug, info}; @@ -21,23 +15,22 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool}; use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; -use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo}; +use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, AuthError}; +use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; -use crate::config::{ComputeConfig, ProxyConfig}; +use crate::config::ProxyConfig; use crate::context::RequestContext; -use crate::control_plane::CachedNodeInfo; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; -use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::proxy::connect_compute::ConnectMechanism; -use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; +use crate::pqproto::StartupMessageParams; +use crate::proxy::{connect_auth, connect_compute}; use crate::rate_limiter::EndpointRateLimiter; -use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; +use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX}; pub(crate) struct PoolingBackend { pub(crate) http_conn_pool: @@ -186,20 +179,42 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); let backend = self.auth_backend.as_ref().map(|()| keys.info); - crate::proxy::connect_compute::connect_to_compute( + + let mut params = StartupMessageParams::default(); + params.insert("database", &conn_info.dbname); + params.insert("user", &conn_info.user_info.user); + + let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys); + auth_info.set_startup_params(¶ms, true); + + let node = connect_auth::connect_to_compute_and_auth( ctx, - &TokioMechanism { - conn_id, - conn_info, - pool: self.pool.clone(), - locks: &self.config.connect_compute_locks, - keys: keys.keys, - }, + self.config, &backend, - self.config.wake_compute_retry_config, - &self.config.connect_to_compute, + auth_info, + connect_compute::TlsNegotiation::Postgres, ) - .await + .await?; + + let (client, connection) = postgres_client::connect::managed( + node.stream, + Some(node.socket_addr.ip()), + postgres_client::config::Host::Tcp(node.hostname.to_string()), + node.socket_addr.port(), + node.ssl_mode, + Some(self.config.connect_to_compute.timeout), + ) + .await?; + + Ok(poll_client( + self.pool.clone(), + ctx, + conn_info, + client, + connection, + conn_id, + node.aux, + )) } // Wake up the destination if needed @@ -228,19 +243,38 @@ impl PoolingBackend { )), options: conn_info.user_info.options.clone(), }); - crate::proxy::connect_compute::connect_to_compute( + + let node = connect_compute::connect_to_compute( ctx, - &HyperMechanism { - conn_id, - conn_info, - pool: self.http_conn_pool.clone(), - locks: &self.config.connect_compute_locks, - }, + self.config, &backend, - self.config.wake_compute_retry_config, - &self.config.connect_to_compute, + connect_compute::TlsNegotiation::Direct, ) - .await + .await?; + + let stream = match node.stream.into_framed().into_inner() { + MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW, + MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW, + }; + + let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .timer(TokioTimer::new()) + .keep_alive_interval(Duration::from_secs(20)) + .keep_alive_while_idle(true) + .keep_alive_timeout(Duration::from_secs(5)) + .handshake(TokioIo::new(stream)) + .await + .map_err(LocalProxyConnError::H2)?; + + Ok(poll_http2_client( + self.http_conn_pool.clone(), + ctx, + &conn_info, + client, + connection, + conn_id, + node.aux.clone(), + )) } /// Connect to postgres over localhost. @@ -380,6 +414,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) { pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), + #[error("could not connect to compute")] + ConnectError(#[from] compute::ConnectionError), #[error("could not connect to postgres in compute")] PostgresConnectionError(#[from] postgres_client::Error), #[error("could not connect to local-proxy in compute")] @@ -399,10 +435,19 @@ pub(crate) enum HttpConnError { TooManyConnectionAttempts(#[from] ApiLockError), } +impl From for HttpConnError { + fn from(value: connect_auth::AuthError) -> Self { + match value { + connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => { + Self::PostgresConnectionError(error) + } + connect_auth::AuthError::Connect(error) => Self::ConnectError(error), + } + } +} + #[derive(Debug, thiserror::Error)] pub(crate) enum LocalProxyConnError { - #[error("error with connection to local-proxy")] - Io(#[source] std::io::Error), #[error("could not establish h2 connection")] H2(#[from] hyper::Error), } @@ -410,6 +455,7 @@ pub(crate) enum LocalProxyConnError { impl ReportableError for HttpConnError { fn get_error_kind(&self) -> ErrorKind { match self { + HttpConnError::ConnectError(_) => ErrorKind::Compute, HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::PostgresConnectionError(p) => { if p.as_db_error().is_some() { @@ -434,6 +480,7 @@ impl ReportableError for HttpConnError { impl UserFacingError for HttpConnError { fn to_string_client(&self) -> String { match self { + HttpConnError::ConnectError(p) => p.to_string_client(), HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(), HttpConnError::LocalProxyConnectionError(p) => p.to_string(), @@ -449,36 +496,9 @@ impl UserFacingError for HttpConnError { } } -impl CouldRetry for HttpConnError { - fn could_retry(&self) -> bool { - match self { - HttpConnError::PostgresConnectionError(e) => e.could_retry(), - HttpConnError::LocalProxyConnectionError(e) => e.could_retry(), - HttpConnError::ComputeCtl(_) => false, - HttpConnError::ConnectionClosedAbruptly(_) => false, - HttpConnError::JwtPayloadError(_) => false, - HttpConnError::GetAuthInfo(_) => false, - HttpConnError::AuthError(_) => false, - HttpConnError::WakeCompute(_) => false, - HttpConnError::TooManyConnectionAttempts(_) => false, - } - } -} -impl ShouldRetryWakeCompute for HttpConnError { - fn should_retry_wake_compute(&self) -> bool { - match self { - HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(), - // we never checked cache validity - HttpConnError::TooManyConnectionAttempts(_) => false, - _ => true, - } - } -} - impl ReportableError for LocalProxyConnError { fn get_error_kind(&self) -> ErrorKind { match self { - LocalProxyConnError::Io(_) => ErrorKind::Compute, LocalProxyConnError::H2(_) => ErrorKind::Compute, } } @@ -489,215 +509,3 @@ impl UserFacingError for LocalProxyConnError { "Could not establish HTTP connection to the database".to_string() } } - -impl CouldRetry for LocalProxyConnError { - fn could_retry(&self) -> bool { - match self { - LocalProxyConnError::Io(_) => false, - LocalProxyConnError::H2(_) => false, - } - } -} -impl ShouldRetryWakeCompute for LocalProxyConnError { - fn should_retry_wake_compute(&self) -> bool { - match self { - LocalProxyConnError::Io(_) => false, - LocalProxyConnError::H2(_) => false, - } - } -} - -struct TokioMechanism { - pool: Arc>>, - conn_info: ConnInfo, - conn_id: uuid::Uuid, - keys: ComputeCredentialKeys, - - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for TokioMechanism { - type Connection = Client; - type ConnectError = HttpConnError; - type Error = HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - compute_config: &ComputeConfig, - ) -> Result { - let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - - let mut config = node_info.conn_info.to_postgres_client_config(); - let config = config - .user(&self.conn_info.user_info.user) - .dbname(&self.conn_info.dbname) - .connect_timeout(compute_config.timeout); - - if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys { - config.auth_keys(auth_keys); - } - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(compute_config).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); - tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_client( - self.pool.clone(), - ctx, - self.conn_info.clone(), - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) - } -} - -struct HyperMechanism { - pool: Arc>>, - conn_info: ConnInfo, - conn_id: uuid::Uuid, - - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for HyperMechanism { - type Connection = http_conn_pool::Client; - type ConnectError = HttpConnError; - type Error = HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - config: &ComputeConfig, - ) -> Result { - let host_addr = node_info.conn_info.host_addr; - let host = &node_info.conn_info.host; - let permit = self.locks.get_permit(host).await?; - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - - let tls = if node_info.conn_info.ssl_mode == SslMode::Disable { - None - } else { - Some(&config.tls) - }; - - let port = node_info.conn_info.port; - let res = connect_http2(host_addr, host, port, config.timeout, tls).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_http2_client( - self.pool.clone(), - ctx, - &self.conn_info, - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) - } -} - -async fn connect_http2( - host_addr: Option, - host: &str, - port: u16, - timeout: Duration, - tls: Option<&Arc>, -) -> Result< - ( - http_conn_pool::LocalProxyClient, - http_conn_pool::LocalProxyConnection, - ), - LocalProxyConnError, -> { - let addrs = match host_addr { - Some(addr) => vec![SocketAddr::new(addr, port)], - None => lookup_host((host, port)) - .await - .map_err(LocalProxyConnError::Io)? - .collect(), - }; - let mut last_err = None; - - let mut addrs = addrs.into_iter(); - let stream = loop { - let Some(addr) = addrs.next() else { - return Err(last_err.unwrap_or_else(|| { - LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })); - }; - - match tokio::time::timeout(timeout, TcpStream::connect(addr)).await { - Ok(Ok(stream)) => { - stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?; - break stream; - } - Ok(Err(e)) => { - last_err = Some(LocalProxyConnError::Io(e)); - } - Err(e) => { - last_err = Some(LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::TimedOut, - e, - ))); - } - } - }; - - let stream = if let Some(tls) = tls { - let host = DnsName::try_from(host) - .map_err(io::Error::other) - .map_err(LocalProxyConnError::Io)? - .to_owned(); - let stream = TlsConnector::from(tls.clone()) - .connect(ServerName::DnsName(host), stream) - .await - .map_err(LocalProxyConnError::Io)?; - Box::pin(stream) as AsyncRW - } else { - Box::pin(stream) as AsyncRW - }; - - let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) - .timer(TokioTimer::new()) - .keep_alive_interval(Duration::from_secs(20)) - .keep_alive_while_idle(true) - .keep_alive_timeout(Duration::from_secs(5)) - .handshake(TokioIo::new(stream)) - .await?; - - Ok((client, connection)) -}