From c90b082222c7055634c10ad6462ca6c041ffead3 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Tue, 10 Jun 2025 19:41:07 +0200 Subject: [PATCH] Split handle_client and pass async callback for connect_once --- proxy/src/compute/mod.rs | 2 + proxy/src/pglb/mod.rs | 110 ++++++++-------------------- proxy/src/proxy/connect_compute.rs | 33 ++++++--- proxy/src/proxy/mod.rs | 113 ++++++++++++++++++++++++++++- 4 files changed, 165 insertions(+), 93 deletions(-) diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index aae1fea07d..708e093553 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -103,6 +103,8 @@ pub enum Auth { } /// A config for authenticating to the compute node. +// TODO: avoid Clone +#[derive(Clone)] pub(crate) struct AuthInfo { /// None for local-proxy, as we use trust-based localhost auth. /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect. diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index ef652972f1..2cc0d6a5bc 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -12,6 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; +use crate::auth; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; @@ -21,12 +22,11 @@ pub use crate::pglb::copy_bidirectional::ErrorSource; use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute_pglb}; -use crate::proxy::{NeonOptions, prepare_client_connection}; +use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism}; +use crate::proxy::handle_connect_request; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::util::run_until_cancelled; -use crate::{auth, compute}; pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; @@ -193,7 +193,7 @@ impl ClientMode { } } - fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { + pub(crate) fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { match self { ClientMode::Tcp => s.sni_hostname(), ClientMode::Websockets { hostname } => hostname.as_deref(), @@ -246,7 +246,7 @@ pub(crate) async fn handle_client( auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, cancellation_handler: Arc, - stream: S, + client: S, mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, @@ -266,12 +266,12 @@ pub(crate) async fn handle_client( let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); + let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error); - let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) + let (client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) .await?? { - HandshakeData::Startup(stream, params) => (stream, params), + HandshakeData::Startup(client, params) => (client, params), HandshakeData::Cancel(cancel_key_data) => { // spawn a task to cancel the session, but don't wait for it cancellations.spawn({ @@ -300,86 +300,38 @@ pub(crate) async fn handle_client( ctx.set_db_options(params.clone()); - let hostname = mode.hostname(stream.get_ref()); - let common_names = tls.map(|tls| &tls.common_names); - // Extract credentials which we're going to use for auth. - let result = auth_backend - .as_ref() - .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) - .transpose(); - - let user_info = match result { - Ok(user_info) => user_info, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, - }; - - let user = user_info.get_user().to_owned(); - let user_info = match user_info - .authenticate( - ctx, - &mut stream, - mode.allow_cleartext(), - &config.authentication_config, - endpoint_rate_limiter, - ) - .await - { - Ok(auth_result) => auth_result, - Err(e) => { - let db = params.get("database"); - let app = params.get("application_name"); - let params_span = tracing::info_span!("", ?user, ?db, ?app); - - return Err(stream - .throw_error(e, Some(ctx)) - .instrument(params_span) - .await)?; - } - }; - - let (cplane, creds) = match user_info { - auth::Backend::ControlPlane(cplane, creds) => (cplane, creds), - auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"), - }; - let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some(); - let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); - auth_info.set_startup_params(¶ms, params_compat); - - let res = connect_to_compute_pglb( - ctx, - &TcpMechanism { - user_info: creds.info.clone(), - auth: auth_info, - locks: &config.connect_compute_locks, - }, - &auth::Backend::ControlPlane(cplane, creds.info), - config.wake_compute_retry_config, - &config.connect_to_compute, - ) - .await; - - let node = match res { - Ok(node) => node, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, - }; - - let cancellation_handler_clone = Arc::clone(&cancellation_handler); - let session = cancellation_handler_clone.get_key(); - - session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream); - let stream = stream.flush_and_into_inner().await?; - let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), None => None, }; + let (node, client, session) = handle_connect_request( + config, + auth_backend, + ctx, + cancellation_handler, + client, + &mode, + endpoint_rate_limiter, + ¶ms, + common_names, + async |config, ctx, node_info, auth_info, creds, compute_config| { + TcpMechanism { + auth: auth_info.clone(), + locks: &config.connect_compute_locks, + user_info: creds.info.clone(), + } + .connect_once(ctx, node_info, compute_config) + .await + }, + ) + .await?; + Ok(Some(ProxyPassthrough { - client: stream, + client, aux: node.aux.clone(), private_link_id, compute: node, diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 80387fe9a5..bb6eb9a825 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -2,13 +2,13 @@ use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; -use crate::auth::backend::ComputeUserInfo; +use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection}; -use crate::config::{ComputeConfig, RetryConfig}; +use crate::config::{ComputeConfig, ProxyConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::locks::ApiLocks; -use crate::control_plane::{self, NodeInfo}; +use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, @@ -185,23 +185,32 @@ where } #[tracing::instrument(skip_all)] -pub(crate) async fn connect_to_compute_pglb( +pub(crate) async fn connect_to_compute_pglb< + F: AsyncFn( + &'static ProxyConfig, + &RequestContext, + &CachedNodeInfo, + &AuthInfo, + &ComputeCredentials, + &ComputeConfig, + ) -> Result, + B: WakeComputeBackend, +>( + config: &'static ProxyConfig, ctx: &RequestContext, - mechanism: &M, + connect_compute_fn: F, user_info: &B, + auth_info: &AuthInfo, + creds: &ComputeCredentials, wake_compute_retry_config: RetryConfig, compute: &ComputeConfig, -) -> Result -where - M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug, - M::Error: From, -{ +) -> Result { let mut num_retries = 0; let node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; // try once - let err = match mechanism.connect_once(ctx, &node_info, compute).await { + let err = match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await { Ok(res) => { ctx.success(); Metrics::get().proxy.retries_metric.observe( @@ -246,7 +255,7 @@ where debug!("wake_compute success. attempting to connect"); num_retries = 1; loop { - match mechanism.connect_once(ctx, &node_info, compute).await { + match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await { Ok(res) => { ctx.success(); Metrics::get().proxy.retries_metric.observe( diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 24846124c2..3d60f934b7 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -5,18 +5,127 @@ pub(crate) mod connect_compute; pub(crate) mod retry; pub(crate) mod wake_compute; +use std::collections::HashSet; +use std::sync::Arc; + use itertools::Itertools; use once_cell::sync::OnceCell; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, format_smolstr}; use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::Instrument; -use crate::compute; +use crate::auth::backend::ComputeCredentials; +use crate::cancellation::{CancellationHandler, Session}; +use crate::compute::{AuthInfo, PostgresConnection}; +use crate::config::{ComputeConfig, ProxyConfig}; +use crate::context::RequestContext; +use crate::control_plane::CachedNodeInfo; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; +use crate::pglb::{ClientMode, ClientRequestError}; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; -use crate::stream::PqStream; +use crate::proxy::connect_compute::connect_to_compute_pglb; +use crate::rate_limiter::EndpointRateLimiter; +use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; +use crate::{auth, compute}; + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn handle_connect_request< + S: AsyncRead + AsyncWrite + Unpin + Send, + F: AsyncFn( + &'static ProxyConfig, + &RequestContext, + &CachedNodeInfo, + &AuthInfo, + &ComputeCredentials, + &ComputeConfig, + ) -> Result, +>( + config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, ()>, + ctx: &RequestContext, + cancellation_handler: Arc, + mut client: PqStream>, + mode: &ClientMode, + endpoint_rate_limiter: Arc, + params: &StartupMessageParams, + common_names: Option<&HashSet>, + connect_compute_fn: F, +) -> Result<(PostgresConnection, Stream, Session), ClientRequestError> { + // TODO: to pglb + let hostname = mode.hostname(client.get_ref()); + + // Extract credentials which we're going to use for auth. + let result = auth_backend + .as_ref() + .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) + .transpose(); + + let user_info = match result { + Ok(user_info) => user_info, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, + }; + + let user = user_info.get_user().to_owned(); + let user_info = match user_info + .authenticate( + ctx, + &mut client, + mode.allow_cleartext(), + &config.authentication_config, + endpoint_rate_limiter, + ) + .await + { + Ok(auth_result) => auth_result, + Err(e) => { + let db = params.get("database"); + let app = params.get("application_name"); + let params_span = tracing::info_span!("", ?user, ?db, ?app); + + return Err(client + .throw_error(e, Some(ctx)) + .instrument(params_span) + .await)?; + } + }; + + let (cplane, creds) = match user_info { + auth::Backend::ControlPlane(cplane, creds) => (cplane, creds), + auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"), + }; + let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some(); + let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); + auth_info.set_startup_params(¶ms, params_compat); + + let res = connect_to_compute_pglb( + config, + ctx, + connect_compute_fn, + &auth::Backend::ControlPlane(cplane, creds.info), + &auth_info, + &creds, + config.wake_compute_retry_config, + &config.connect_to_compute, + ) + .await; + + let node = match res { + Ok(node) => node, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, + }; + + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session = cancellation_handler_clone.get_key(); + + session.write_cancel_key(node.cancel_closure.clone())?; + prepare_client_connection(&node, *session.key(), &mut client); + let client = client.flush_and_into_inner().await?; + + Ok((node, client, session)) +} /// Finish client connection initialization: confirm auth success, send params, etc. pub(crate) fn prepare_client_connection(