Split handle_client and pass async callback for connect_once

This commit is contained in:
Folke Behrens
2025-06-10 19:41:07 +02:00
parent 0957c8ea69
commit c90b082222
4 changed files with 165 additions and 93 deletions

View File

@@ -103,6 +103,8 @@ pub enum Auth {
}
/// A config for authenticating to the compute node.
// TODO: avoid Clone
#[derive(Clone)]
pub(crate) struct AuthInfo {
/// None for local-proxy, as we use trust-based localhost auth.
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.

View File

@@ -12,6 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use crate::auth;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
@@ -21,12 +22,11 @@ pub use crate::pglb::copy_bidirectional::ErrorSource;
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute_pglb};
use crate::proxy::{NeonOptions, prepare_client_connection};
use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism};
use crate::proxy::handle_connect_request;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::util::run_until_cancelled;
use crate::{auth, compute};
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
@@ -193,7 +193,7 @@ impl ClientMode {
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
pub(crate) fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
@@ -246,7 +246,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
client: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
@@ -266,12 +266,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
let (client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(client, params) => (client, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
@@ -300,86 +300,38 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return Err(stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
}
};
let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute_pglb(
ctx,
&TcpMechanism {
user_info: creds.info.clone(),
auth: auth_info,
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info),
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
let node = match res {
Ok(node) => node,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
let (node, client, session) = handle_connect_request(
config,
auth_backend,
ctx,
cancellation_handler,
client,
&mode,
endpoint_rate_limiter,
&params,
common_names,
async |config, ctx, node_info, auth_info, creds, compute_config| {
TcpMechanism {
auth: auth_info.clone(),
locks: &config.connect_compute_locks,
user_info: creds.info.clone(),
}
.connect_once(ctx, node_info, compute_config)
.await
},
)
.await?;
Ok(Some(ProxyPassthrough {
client: stream,
client,
aux: node.aux.clone(),
private_link_id,
compute: node,

View File

@@ -2,13 +2,13 @@ use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
@@ -185,23 +185,32 @@ where
}
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute_pglb<M: ConnectMechanism, B: WakeComputeBackend>(
pub(crate) async fn connect_to_compute_pglb<
F: AsyncFn(
&'static ProxyConfig,
&RequestContext,
&CachedNodeInfo,
&AuthInfo,
&ComputeCredentials,
&ComputeConfig,
) -> Result<PostgresConnection, compute::ConnectionError>,
B: WakeComputeBackend,
>(
config: &'static ProxyConfig,
ctx: &RequestContext,
mechanism: &M,
connect_compute_fn: F,
user_info: &B,
auth_info: &AuthInfo,
creds: &ComputeCredentials,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
) -> Result<PostgresConnection, compute::ConnectionError> {
let mut num_retries = 0;
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
// try once
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
let err = match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -246,7 +255,7 @@ where
debug!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism.connect_once(ctx, &node_info, compute).await {
match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(

View File

@@ -5,18 +5,127 @@ pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::collections::HashSet;
use std::sync::Arc;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, format_smolstr};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::Instrument;
use crate::compute;
use crate::auth::backend::ComputeCredentials;
use crate::cancellation::{CancellationHandler, Session};
use crate::compute::{AuthInfo, PostgresConnection};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::{ClientMode, ClientRequestError};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::stream::PqStream;
use crate::proxy::connect_compute::connect_to_compute_pglb;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::{auth, compute};
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_connect_request<
S: AsyncRead + AsyncWrite + Unpin + Send,
F: AsyncFn(
&'static ProxyConfig,
&RequestContext,
&CachedNodeInfo,
&AuthInfo,
&ComputeCredentials,
&ComputeConfig,
) -> Result<PostgresConnection, compute::ConnectionError>,
>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
mut client: PqStream<Stream<S>>,
mode: &ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
params: &StartupMessageParams,
common_names: Option<&HashSet<String>>,
connect_compute_fn: F,
) -> Result<(PostgresConnection, Stream<S>, Session), ClientRequestError> {
// TODO: to pglb
let hostname = mode.hostname(client.get_ref());
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
&mut client,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return Err(client
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
}
};
let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute_pglb(
config,
ctx,
connect_compute_fn,
&auth::Backend::ControlPlane(cplane, creds.info),
&auth_info,
&creds,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
let node = match res {
Ok(node) => node,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut client);
let client = client.flush_and_into_inner().await?;
Ok((node, client, session))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
pub(crate) fn prepare_client_connection(