From 3f6a971e4f69b86c4ce4841269325a708a5bf0e3 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Mon, 9 Jun 2025 17:28:17 +0200 Subject: [PATCH] WIP --- proxy/src/binary/proxy.rs | 2 +- proxy/src/compute/mod.rs | 2 + proxy/src/console_redirect_proxy.rs | 5 +- proxy/src/pglb/mod.rs | 357 ++++++++++++++++++++++ proxy/src/proxy/mod.rs | 443 +++++++++------------------- proxy/src/proxy/tests/mod.rs | 3 +- proxy/src/serverless/websocket.rs | 2 +- 7 files changed, 511 insertions(+), 303 deletions(-) diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 757c1e988b..b40871113d 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -410,7 +410,7 @@ pub async fn run() -> anyhow::Result<()> { match auth_backend { Either::Left(auth_backend) => { if let Some(proxy_listener) = proxy_listener { - client_tasks.spawn(crate::proxy::task_main( + client_tasks.spawn(crate::pglb::task_main( config, auth_backend, proxy_listener, diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 0dacd15547..818d46c043 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -103,6 +103,8 @@ pub enum Auth { } /// A config for authenticating to the compute node. +// XXX: clone +#[derive(Clone)] pub(crate) struct AuthInfo { /// None for local-proxy, as we use trust-based localhost auth. /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect. 
diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 8ea24fbffb..22e5ae1d3c 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -14,10 +14,9 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard}; use crate::pglb::connect_compute::TcpMechanism; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; +use crate::pglb::{ClientRequestError, ErrorSource}; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::{ - ClientRequestError, ErrorSource, connect_to_compute, prepare_client_connection, -}; +use crate::proxy::{connect_to_compute, prepare_client_connection}; use crate::util::run_until_cancelled; pub async fn task_main( diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index 4b107142a7..5b00a47c30 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -3,3 +3,360 @@ pub mod copy_bidirectional; pub mod handshake; pub mod inprocess; pub mod passthrough; + +use std::sync::Arc; + +use futures::FutureExt; +use itertools::Itertools; +use once_cell::sync::OnceCell; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use smol_str::{SmolStr, ToSmolStr, format_smolstr}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time; +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, debug, error, info, warn}; + +use crate::cancellation::{self, CancellationHandler}; +use crate::compute::COULD_NOT_CONNECT; +use crate::config::{ComputeConfig, ProxyConfig, ProxyProtocolV2, RetryConfig, TlsConfig}; +use crate::context::RequestContext; +use crate::control_plane::NodeInfo; +use crate::control_plane::errors::WakeComputeError; +use crate::error::{ReportableError, UserFacingError}; +use crate::metrics::{ + ConnectOutcome, ConnectionFailureKind, Metrics, NumClientConnectionsGuard, RetriesMetricGroup, + RetryType, +}; +use 
crate::pglb::connect_compute::{ComputeConnectBackend, ConnectMechanism, TcpMechanism}; +pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; +use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; +use crate::pglb::passthrough::ProxyPassthrough; +use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; +use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; +use crate::proxy::handle_connect_request; +use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; +use crate::proxy::wake_compute::wake_compute; +use crate::rate_limiter::EndpointRateLimiter; +use crate::stream::{PqStream, Stream}; +use crate::types::EndpointCacheKey; +use crate::util::run_until_cancelled; +use crate::{auth, compute, control_plane}; + +const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; + +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + +pub async fn task_main( + config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, ()>, + listener: tokio::net::TcpListener, + cancellation_token: CancellationToken, + cancellation_handler: Arc, + endpoint_rate_limiter: Arc, +) -> anyhow::Result<()> { + scopeguard::defer! { + info!("proxy has shut down"); + } + + // When set for the server socket, the keepalive setting + // will be inherited by all accepted client sockets. 
+ socket2::SockRef::from(&listener).set_keepalive(true)?; + + let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); + + while let Some(accept_result) = + run_until_cancelled(listener.accept(), &cancellation_token).await + { + let (socket, peer_addr) = accept_result?; + + let conn_gauge = Metrics::get() + .proxy + .client_connections + .guard(crate::metrics::Protocol::Tcp); + + let session_id = uuid::Uuid::new_v4(); + let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); + + debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); + let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); + + connections.spawn(async move { + let (socket, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } + } + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. 
+ ProxyProtocolV2::Rejected => ( + socket, + ConnectionInfo { + addr: peer_addr, + extra: None, + }, + ), + }; + + match socket.set_nodelay(true) { + Ok(()) => {} + Err(e) => { + error!( + "per-client task finished with an error: failed to set socket option: {e:#}" + ); + return; + } + } + + let ctx = RequestContext::new( + session_id, + conn_info, + crate::metrics::Protocol::Tcp, + &config.region, + ); + + let res = handle_client( + config, + auth_backend, + &ctx, + cancellation_handler, + socket, + ClientMode::Tcp, + endpoint_rate_limiter2, + conn_gauge, + cancellations, + ) + .instrument(ctx.span()) + .boxed() + .await; + + match res { + Err(e) => { + ctx.set_error_kind(e.get_error_kind()); + warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); + } + Ok(None) => { + ctx.set_success(); + } + Ok(Some(p)) => { + ctx.set_success(); + let _disconnect = ctx.log_connect(); + match p.proxy_pass(&config.connect_to_compute).await { + Ok(()) => {} + Err(ErrorSource::Client(e)) => { + warn!( + ?session_id, + "per-client task finished with an IO error from the client: {e:#}" + ); + } + Err(ErrorSource::Compute(e)) => { + error!( + ?session_id, + "per-client task finished with an IO error from the compute: {e:#}" + ); + } + } + } + } + }); + } + + connections.close(); + cancellations.close(); + drop(listener); + + // Drain connections + connections.wait().await; + cancellations.wait().await; + + Ok(()) +} + +pub(crate) enum ClientMode { + Tcp, + Websockets { hostname: Option }, +} + +/// Abstracts the logic of handling TCP vs WS clients +impl ClientMode { + pub(crate) fn allow_cleartext(&self) -> bool { + match self { + ClientMode::Tcp => false, + ClientMode::Websockets { .. 
} => true, + } + } + + pub(crate) fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { + match self { + ClientMode::Tcp => s.sni_hostname(), + ClientMode::Websockets { hostname } => hostname.as_deref(), + } + } + + pub(crate) fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> { + match self { + ClientMode::Tcp => tls, + // TLS is None here if using websockets, because the connection is already encrypted. + ClientMode::Websockets { .. } => None, + } + } +} + +#[derive(Debug, Error)] +// almost all errors should be reported to the user, but there's a few cases where we cannot +// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons +// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, +// we cannot be sure the client even understands our error message +// 3. PrepareClient: The client disconnected, so we can't tell them anyway... +pub(crate) enum ClientRequestError { + #[error("{0}")] + Cancellation(#[from] cancellation::CancelError), + #[error("{0}")] + Handshake(#[from] HandshakeError), + #[error("{0}")] + HandshakeTimeout(#[from] tokio::time::error::Elapsed), + #[error("{0}")] + PrepareClient(#[from] std::io::Error), + #[error("{0}")] + ReportedError(#[from] crate::stream::ReportedError), +} + +impl ReportableError for ClientRequestError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + ClientRequestError::Cancellation(e) => e.get_error_kind(), + ClientRequestError::Handshake(e) => e.get_error_kind(), + ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit, + ClientRequestError::ReportedError(e) => e.get_error_kind(), + ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect, + } + } +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn handle_client( + config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, 
()>, + ctx: &RequestContext, + cancellation_handler: Arc, + client: S, + mode: ClientMode, + endpoint_rate_limiter: Arc, + conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, +) -> Result>, ClientRequestError> { + debug!( + protocol = %ctx.protocol(), + "handling interactive connection from client" + ); + + let metrics = &Metrics::get().proxy; + let proto = ctx.protocol(); + let request_gauge = metrics.connection_requests.guard(proto); + + let tls = config.tls_config.load(); + let tls = tls.as_deref(); + + let record_handshake_error = !ctx.has_private_peer_addr(); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error); + + let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) + .await?? + { + HandshakeData::Startup(client, params) => (client, params), + HandshakeData::Cancel(cancel_key_data) => { + // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let ctx = ctx.clone(); + let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); + cancel_span.follows_from(tracing::Span::current()); + async move { + cancellation_handler_clone + .cancel_session( + cancel_key_data, + ctx, + config.authentication_config.ip_allowlist_check_enabled, + config.authentication_config.is_vpc_acccess_proxy, + auth_backend.get_api(), + ) + .await + .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); + }.instrument(cancel_span) + }); + + return Ok(None); + } + }; + drop(pause); + + ctx.set_db_options(params.clone()); + + let hostname = mode.hostname(client.get_ref()); + + let common_names = tls.map(|tls| &tls.common_names); + + let private_link_id = match ctx.extra() { + Some(ConnectionInfoExtra::Aws { 
vpce_id }) => Some(vpce_id.clone()), + Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), + None => None, + }; + + let client = handle_connect_request( + config, + auth_backend, + ctx, + cancellation_handler, + client, + mode, + endpoint_rate_limiter, + &params, + hostname, + common_names, + async |ctx, node_info, auth_info, creds, compute_config| { + let mech = &TcpMechanism { + user_info: creds.info.clone(), + auth: auth_info.clone(), + locks: &config.connect_compute_locks, + }; + + mech.connect_once(ctx, node_info, compute_config).await + }, + ) + .await?; + + Ok(Some(ProxyPassthrough { + client, + aux: node.aux.clone(), + private_link_id, + compute: node, + session_id: ctx.session_id(), + cancel: session, + _req: request_gauge, + _conn: conn_gauge, + })) +} diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 8e55ce50d2..ac50c0b587 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -4,43 +4,39 @@ mod tests; pub(crate) mod retry; pub(crate) mod wake_compute; +use std::collections::HashSet; use std::sync::Arc; -use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; use regex::Regex; use serde::{Deserialize, Serialize}; -use smol_str::{SmolStr, ToSmolStr, format_smolstr}; +use smol_str::{SmolStr, format_smolstr}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time; -use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; -use crate::cancellation::{self, CancellationHandler}; -use crate::compute::COULD_NOT_CONNECT; -use crate::config::{ComputeConfig, ProxyConfig, ProxyProtocolV2, RetryConfig, TlsConfig}; +use crate::auth::backend::ComputeCredentials; +use crate::cancellation::CancellationHandler; +use crate::compute::{AuthInfo, COULD_NOT_CONNECT}; +use crate::config::{ComputeConfig, ProxyConfig, RetryConfig}; use crate::context::RequestContext; -use crate::control_plane::NodeInfo; use
crate::control_plane::errors::WakeComputeError; +use crate::control_plane::{CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{ - ConnectOutcome, ConnectionFailureKind, Metrics, NumClientConnectionsGuard, RetriesMetricGroup, - RetryType, + ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::pglb::connect_compute::{ComputeConnectBackend, ConnectMechanism, TcpMechanism}; +use crate::pglb::connect_compute::{ComputeConnectBackend, ConnectMechanism}; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; -use crate::pglb::passthrough::ProxyPassthrough; +use crate::pglb::{ClientMode, ClientRequestError}; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; -use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; -use crate::util::run_until_cancelled; use crate::{auth, compute, control_plane}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; @@ -57,268 +53,29 @@ impl ReportableError for TlsRequired { impl UserFacingError for TlsRequired {} -pub async fn task_main( - config: &'static ProxyConfig, - auth_backend: &'static auth::Backend<'static, ()>, - listener: tokio::net::TcpListener, - cancellation_token: CancellationToken, - cancellation_handler: Arc, - endpoint_rate_limiter: Arc, -) -> anyhow::Result<()> { - scopeguard::defer! { - info!("proxy has shut down"); - } - - // When set for the server socket, the keepalive setting - // will be inherited by all accepted client sockets. 
- socket2::SockRef::from(&listener).set_keepalive(true)?; - - let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); - - while let Some(accept_result) = - run_until_cancelled(listener.accept(), &cancellation_token).await - { - let (socket, peer_addr) = accept_result?; - - let conn_gauge = Metrics::get() - .proxy - .client_connections - .guard(crate::metrics::Protocol::Tcp); - - let session_id = uuid::Uuid::new_v4(); - let cancellation_handler = Arc::clone(&cancellation_handler); - let cancellations = cancellations.clone(); - - debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); - let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); - - connections.spawn(async move { - let (socket, conn_info) = match config.proxy_protocol_v2 { - ProxyProtocolV2::Required => { - match read_proxy_protocol(socket).await { - Err(e) => { - warn!("per-client task finished with an error: {e:#}"); - return; - } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - } - } - // ignore the header - it cannot be confused for a postgres or http connection so will - // error later. 
- ProxyProtocolV2::Rejected => ( - socket, - ConnectionInfo { - addr: peer_addr, - extra: None, - }, - ), - }; - - match socket.set_nodelay(true) { - Ok(()) => {} - Err(e) => { - error!( - "per-client task finished with an error: failed to set socket option: {e:#}" - ); - return; - } - } - - let ctx = RequestContext::new( - session_id, - conn_info, - crate::metrics::Protocol::Tcp, - &config.region, - ); - - let res = handle_client( - config, - auth_backend, - &ctx, - cancellation_handler, - socket, - ClientMode::Tcp, - endpoint_rate_limiter2, - conn_gauge, - cancellations, - ) - .instrument(ctx.span()) - .boxed() - .await; - - match res { - Err(e) => { - ctx.set_error_kind(e.get_error_kind()); - warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); - } - Ok(None) => { - ctx.set_success(); - } - Ok(Some(p)) => { - ctx.set_success(); - let _disconnect = ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { - Ok(()) => {} - Err(ErrorSource::Client(e)) => { - warn!( - ?session_id, - "per-client task finished with an IO error from the client: {e:#}" - ); - } - Err(ErrorSource::Compute(e)) => { - error!( - ?session_id, - "per-client task finished with an IO error from the compute: {e:#}" - ); - } - } - } - } - }); - } - - connections.close(); - cancellations.close(); - drop(listener); - - // Drain connections - connections.wait().await; - cancellations.wait().await; - - Ok(()) -} - -pub(crate) enum ClientMode { - Tcp, - Websockets { hostname: Option }, -} - -/// Abstracts the logic of handling TCP vs WS clients -impl ClientMode { - pub(crate) fn allow_cleartext(&self) -> bool { - match self { - ClientMode::Tcp => false, - ClientMode::Websockets { .. 
} => true, - } - } - - fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { - match self { - ClientMode::Tcp => s.sni_hostname(), - ClientMode::Websockets { hostname } => hostname.as_deref(), - } - } - - fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> { - match self { - ClientMode::Tcp => tls, - // TLS is None here if using websockets, because the connection is already encrypted. - ClientMode::Websockets { .. } => None, - } - } -} - -#[derive(Debug, Error)] -// almost all errors should be reported to the user, but there's a few cases where we cannot -// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons -// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, -// we cannot be sure the client even understands our error message -// 3. PrepareClient: The client disconnected, so we can't tell them anyway... -pub(crate) enum ClientRequestError { - #[error("{0}")] - Cancellation(#[from] cancellation::CancelError), - #[error("{0}")] - Handshake(#[from] HandshakeError), - #[error("{0}")] - HandshakeTimeout(#[from] tokio::time::error::Elapsed), - #[error("{0}")] - PrepareClient(#[from] std::io::Error), - #[error("{0}")] - ReportedError(#[from] crate::stream::ReportedError), -} - -impl ReportableError for ClientRequestError { - fn get_error_kind(&self) -> crate::error::ErrorKind { - match self { - ClientRequestError::Cancellation(e) => e.get_error_kind(), - ClientRequestError::Handshake(e) => e.get_error_kind(), - ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit, - ClientRequestError::ReportedError(e) => e.get_error_kind(), - ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect, - } - } -} - #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_connect_request< + S: AsyncRead + AsyncWrite + Unpin + Send, + C: AsyncFnMut( 
+ &RequestContext, + &CachedNodeInfo, + &AuthInfo, + &ComputeCredentials, + &ComputeConfig, + ) -> Result, +>( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, cancellation_handler: Arc, - stream: S, + mut client: PqStream>, mode: ClientMode, endpoint_rate_limiter: Arc, - conn_gauge: NumClientConnectionsGuard<'static>, - cancellations: tokio_util::task::task_tracker::TaskTracker, -) -> Result>, ClientRequestError> { - debug!( - protocol = %ctx.protocol(), - "handling interactive connection from client" - ); - - let metrics = &Metrics::get().proxy; - let proto = ctx.protocol(); - let request_gauge = metrics.connection_requests.guard(proto); - - let tls = config.tls_config.load(); - let tls = tls.as_deref(); - - let record_handshake_error = !ctx.has_private_peer_addr(); - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); - - let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) - .await?? 
- { - HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(cancel_key_data) => { - // spawn a task to cancel the session, but don't wait for it - cancellations.spawn({ - let cancellation_handler_clone = Arc::clone(&cancellation_handler); - let ctx = ctx.clone(); - let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); - cancel_span.follows_from(tracing::Span::current()); - async move { - cancellation_handler_clone - .cancel_session( - cancel_key_data, - ctx, - config.authentication_config.ip_allowlist_check_enabled, - config.authentication_config.is_vpc_acccess_proxy, - auth_backend.get_api(), - ) - .await - .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); - }.instrument(cancel_span) - }); - - return Ok(None); - } - }; - drop(pause); - - ctx.set_db_options(params.clone()); - - let hostname = mode.hostname(stream.get_ref()); - - let common_names = tls.map(|tls| &tls.common_names); - + params: &StartupMessageParams, + hostname: Option<&str>, + common_names: Option<&HashSet>, + connect_compute_fn: C, +) -> Result, ClientRequestError> { // Extract credentials which we're going to use for auth. 
let result = auth_backend .as_ref() @@ -327,14 +84,14 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); let user_info = match user_info .authenticate( ctx, - &mut stream, + &mut client, mode.allow_cleartext(), &config.authentication_config, endpoint_rate_limiter, @@ -347,7 +104,7 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return Err(stream + return Err(client .throw_error(e, Some(ctx)) .instrument(params_span) .await)?; @@ -362,14 +119,12 @@ pub(crate) async fn handle_client( let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys); auth_info.set_startup_params(&params, params_compat); - let res = connect_to_compute( + let res = connect_to_compute_pglb( ctx, - &TcpMechanism { - user_info: creds.info.clone(), - auth: auth_info, - locks: &config.connect_compute_locks, - }, + connect_compute_fn, &user_info, + &auth_info, + &creds, config.wake_compute_retry_config, &config.connect_to_compute, ) @@ -377,32 +132,17 @@ pub(crate) async fn handle_client( let node = match res { Ok(node) => node, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, }; let cancellation_handler_clone = Arc::clone(&cancellation_handler); let session = cancellation_handler_clone.get_key(); session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream); - let stream = stream.flush_and_into_inner().await?; + prepare_client_connection(&node, *session.key(), &mut client); + let client = client.flush_and_into_inner().await?; - let private_link_id = match ctx.extra() { - Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), - Some(ConnectionInfoExtra::Azure {
link_id }) => Some(link_id.to_smolstr()), - None => None, - }; - - Ok(Some(ProxyPassthrough { - client: stream, - aux: node.aux.clone(), - private_link_id, - compute: node, - session_id: ctx.session_id(), - cancel: session, - _req: request_gauge, - _conn: conn_gauge, - })) + Ok(client) } /// If we couldn't connect, a cached connection info might be to blame @@ -424,6 +164,115 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node node_info.invalidate() } +#[tracing::instrument(skip_all)] +pub(crate) async fn connect_to_compute_pglb< + C: AsyncFnMut( + &RequestContext, + &CachedNodeInfo, + &AuthInfo, + &ComputeCredentials, + &ComputeConfig, + ) -> Result, + B: ComputeConnectBackend, +>( + ctx: &RequestContext, + mut connect_compute_fn: C, + user_info: &B, + auth_info: &AuthInfo, + creds: &ComputeCredentials, + wake_compute_retry_config: RetryConfig, + compute: &ComputeConfig, +) -> Result { + let mut num_retries = 0; + let node_info = + wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; + + // try once + let err = match connect_compute_fn(ctx, &node_info, &auth_info, creds, compute).await { + Ok(res) => { + ctx.success(); + Metrics::get().proxy.retries_metric.observe( + RetriesMetricGroup { + outcome: ConnectOutcome::Success, + retry_type: RetryType::ConnectToCompute, + }, + num_retries.into(), + ); + return Ok(res); + } + Err(e) => e, + }; + + debug!(error = ?err, COULD_NOT_CONNECT); + + let node_info = if !node_info.cached() || !err.should_retry_wake_compute() { + // If we just received this from cplane and didn't get it from cache, we shouldn't retry. + // Do not need to retrieve a new node_info, just return the old one.
+ if should_retry(&err, num_retries, compute.retry) { + Metrics::get().proxy.retries_metric.observe( + RetriesMetricGroup { + outcome: ConnectOutcome::Failed, + retry_type: RetryType::ConnectToCompute, + }, + num_retries.into(), + ); + return Err(err.into()); + } + node_info + } else { + // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node + debug!("compute node's state has likely changed; requesting a wake-up"); + invalidate_cache(node_info); + // TODO: increment num_retries? + wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await? + }; + + // now that we have a new node, try connect to it repeatedly. + // this can error for a few reasons, for instance: + // * DNS connection settings haven't quite propagated yet + debug!("wake_compute success. attempting to connect"); + num_retries = 1; + loop { + match connect_compute_fn(ctx, &node_info, &auth_info, creds, compute).await { + Ok(res) => { + ctx.success(); + Metrics::get().proxy.retries_metric.observe( + RetriesMetricGroup { + outcome: ConnectOutcome::Success, + retry_type: RetryType::ConnectToCompute, + }, + num_retries.into(), + ); + // TODO: is this necessary? We have a metric. 
+ info!(?num_retries, "connected to compute node after"); + return Ok(res); + } + Err(e) => { + if !should_retry(&e, num_retries, compute.retry) { + // Don't log an error here, caller will print the error + Metrics::get().proxy.retries_metric.observe( + RetriesMetricGroup { + outcome: ConnectOutcome::Failed, + retry_type: RetryType::ConnectToCompute, + }, + num_retries.into(), + ); + return Err(e.into()); + } + + warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT); + } + } + + let wait_duration = retry_after(num_retries, compute.retry); + num_retries += 1; + + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout); + time::sleep(wait_duration).await; + drop(pause); + } +} + /// Try to connect to the compute node, retrying if necessary. #[tracing::instrument(skip_all)] pub(crate) async fn connect_to_compute( diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 028247a97d..e4f54db1ec 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -22,12 +22,13 @@ use super::*; use crate::auth::backend::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, }; -use crate::config::{ComputeConfig, RetryConfig}; +use crate::config::{ComputeConfig, RetryConfig, TlsConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; use crate::pglb::connect_compute::ConnectMechanism; +use crate::pglb::handshake::{HandshakeData, handshake}; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 8648a94869..96ac859080 100644 --- a/proxy/src/serverless/websocket.rs +++ 
b/proxy/src/serverless/websocket.rs @@ -17,7 +17,7 @@ use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::{ClientMode, ErrorSource, handle_client}; +use crate::pglb::{ClientMode, ErrorSource, handle_client}; use crate::rate_limiter::EndpointRateLimiter; pin_project! {