proxy: Move client connection accept and handshake to pglb (#12380)

* This must be a no-op.
* Move proxy::task_main to pglb::task_main.
* Move client accept, TLS and handshake to pglb.
* Keep auth and wake in proxy.
Folke Behrens
2025-06-27 17:20:23 +02:00
committed by GitHub
parent 4c7956fa56
commit 0ee15002fc
9 changed files with 388 additions and 327 deletions
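
A rough sketch of the resulting layering (hypothetical code, not taken from this commit; the placeholder types and bodies are assumptions, only the function names match the diff):

// Minimal sketch of the post-refactor split; type names are stand-ins,
// not the real crate types.
pub struct PqStream;          // stands in for crate::stream::PqStream<Stream<S>>
pub struct ComputeConnection; // stands in for crate::compute::ComputeConnection

mod pglb {
    use super::{ComputeConnection, PqStream};

    // task_main (now in pglb) runs the accept loop: keepalive/NODELAY setup,
    // proxy-protocol v2 parsing, then one handle_connection() task per client.
    pub async fn handle_connection(client: PqStream) -> Option<ComputeConnection> {
        // TLS upgrade and the startup handshake happen here; a cancel request
        // returns early without ever reaching the proxy layer.
        let node = super::proxy::handle_client(client).await?;
        // pglb keeps the client socket and later drives the passthrough.
        Some(node)
    }
}

mod proxy {
    use super::{ComputeConnection, PqStream};

    // handle_client keeps only auth and compute wake-up: parse credentials,
    // authenticate, connect to compute, then finish_client_init() writes
    // ParameterStatus/BackendKeyData/ReadyForQuery back to the client.
    pub async fn handle_client(_client: PqStream) -> Option<ComputeConnection> {
        None // placeholder; the real function returns Result<(ComputeConnection, ...), _>
    }
}

In the actual diff, proxy::handle_client returns the ComputeConnection plus a cancel-key shutdown handle, and pglb::handle_connection wraps them into ProxyPassthrough together with the connection gauges.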

View File

@@ -26,9 +26,10 @@ use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
+use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
-use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
+use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
use crate::stream::{PqStream, Stream};
use crate::util::run_until_cancelled;

View File

@@ -392,7 +392,7 @@ pub async fn run() -> anyhow::Result<()> {
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
-client_tasks.spawn(crate::proxy::task_main(
+client_tasks.spawn(crate::pglb::task_main(
config,
auth_backend,
proxy_listener,
View File

@@ -11,11 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
+use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
-use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
+use crate::proxy::{ErrorSource, finish_client_init};
use crate::util::run_until_cancelled;
pub async fn task_main(
@@ -232,7 +233,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let session = cancellation_handler.get_key();
-prepare_client_connection(&pg_settings, *session.key(), &mut stream);
+finish_client_init(&pg_settings, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let session_id = ctx.session_id();

View File

@@ -8,10 +8,10 @@ use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
+use crate::pglb::TlsRequired;
use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
-use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;

View File

@@ -2,3 +2,332 @@ pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;
pub mod passthrough;
use std::sync::Arc;
use futures::FutureExt;
use smol_str::ToSmolStr;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use crate::auth;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::ErrorSource;
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::handle_client;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::util::run_until_cancelled;
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_connection(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
pub fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
pub fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
client: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(client, params) => (client, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
let common_names = tls.map(|tls| &tls.common_names);
let (node, cancel_on_shutdown) = handle_client(
config,
auth_backend,
ctx,
cancellation_handler,
&mut client,
&mode,
endpoint_rate_limiter,
common_names,
&params,
)
.await?;
let client = client.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client,
compute: node.stream,
aux: node.aux,
private_link_id,
_cancel_on_shutdown: cancel_on_shutdown,
_req: request_gauge,
_conn: conn_gauge,
_db_conn: node.guage,
}))
}

View File

@@ -5,326 +5,64 @@ pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
+use std::collections::HashSet;
+use std::convert::Infallible;
use std::sync::Arc;
-use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use regex::Regex;
use serde::{Deserialize, Serialize};
-use smol_str::{SmolStr, ToSmolStr, format_smolstr};
+use smol_str::{SmolStr, format_smolstr};
-use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
-use tokio_util::sync::CancellationToken;
+use tokio::sync::oneshot;
-use tracing::{Instrument, debug, error, info, warn};
+use tracing::Instrument;
use crate::cache::Cache;
-use crate::cancellation::{self, CancellationHandler};
+use crate::cancellation::CancellationHandler;
-use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
+use crate::compute::ComputeConnection;
+use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
-use crate::error::{ReportableError, UserFacingError};
-use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
-use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
+use crate::pglb::{ClientMode, ClientRequestError};
-use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
-use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::retry::ShouldRetryWakeCompute;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
-use crate::util::run_until_cancelled;
use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub(crate) fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
-stream: S,
+client: &mut PqStream<Stream<S>>,
-mode: ClientMode,
+mode: &ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-conn_gauge: NumClientConnectionsGuard<'static>,
+common_names: Option<&HashSet<String>>,
-cancellations: tokio_util::task::task_tracker::TaskTracker,
+params: &StartupMessageParams,
-) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
+) -> Result<(ComputeConnection, oneshot::Sender<Infallible>), ClientRequestError> {
-debug!(
-protocol = %ctx.protocol(),
-"handling interactive connection from client"
-);
-let metrics = &Metrics::get().proxy;
-let proto = ctx.protocol();
-let request_gauge = metrics.connection_requests.guard(proto);
-let tls = config.tls_config.load();
-let tls = tls.as_deref();
-let record_handshake_error = !ctx.has_private_peer_addr();
-let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
-let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
-let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
-.await??
-{
-HandshakeData::Startup(stream, params) => (stream, params),
-HandshakeData::Cancel(cancel_key_data) => {
-// spawn a task to cancel the session, but don't wait for it
-cancellations.spawn({
-let cancellation_handler_clone = Arc::clone(&cancellation_handler);
-let ctx = ctx.clone();
-let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
-cancel_span.follows_from(tracing::Span::current());
-async move {
-cancellation_handler_clone
-.cancel_session(
-cancel_key_data,
-ctx,
-config.authentication_config.ip_allowlist_check_enabled,
-config.authentication_config.is_vpc_acccess_proxy,
-auth_backend.get_api(),
-)
-.await
-.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
-}.instrument(cancel_span)
-});
-return Ok(None);
-}
-};
-drop(pause);
-ctx.set_db_options(params.clone());
-let hostname = mode.hostname(stream.get_ref());
+let hostname = mode.hostname(client.get_ref());
-let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
-.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
+.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, params, hostname, common_names))
.transpose();
let user_info = match result {
Ok(user_info) => user_info,
-Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
+Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
-&mut stream,
+client,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
@@ -337,7 +75,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
-return Err(stream
+return Err(client
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
@@ -350,7 +88,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
-auth_info.set_startup_params(&params, params_compat);
+auth_info.set_startup_params(params, params_compat);
let mut node;
let mut attempt = 0;
@@ -370,6 +108,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let pg_settings = loop {
attempt += 1;
+// TODO: callback to pglb
let res = connect_to_compute(
ctx,
&connect,
@@ -381,7 +120,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
match res {
Ok(n) => node = n,
-Err(e) => return Err(stream.throw_error(e, Some(ctx)).await)?,
+Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?,
}
let auth::Backend::ControlPlane(cplane, user_info) = &backend else {
@@ -400,17 +139,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
cplane_proxy_v1.caches.node_info.invalidate(&key);
}
}
-Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
+Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
}
};
let session = cancellation_handler.get_key();
-prepare_client_connection(&pg_settings, *session.key(), &mut stream);
+finish_client_init(&pg_settings, *session.key(), client);
-let stream = stream.flush_and_into_inner().await?;
let session_id = ctx.session_id();
-let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
+let (cancel_on_shutdown, cancel) = oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
@@ -422,50 +160,32 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await;
});
-let private_link_id = match ctx.extra() {
-Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
-Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
-None => None,
-};
-Ok(Some(ProxyPassthrough {
-client: stream,
-compute: node.stream,
-aux: node.aux,
-private_link_id,
-_cancel_on_shutdown: cancel_on_shutdown,
-_req: request_gauge,
-_conn: conn_gauge,
-_db_conn: node.guage,
-}))
+Ok((node, cancel_on_shutdown))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
-pub(crate) fn prepare_client_connection(
+pub(crate) fn finish_client_init(
settings: &compute::PostgresSettings,
cancel_key_data: CancelKeyData,
-stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
+client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) {
// Forward all deferred notices to the client.
for notice in &settings.delayed_notice {
-stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
+client.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
// Forward all postgres connection params to the client.
for (name, value) in &settings.params {
-stream.write_message(BeMessage::ParameterStatus {
+client.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
-stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
+client.write_message(BeMessage::BackendKeyData(cancel_key_data));
-stream.write_message(BeMessage::ReadyForQuery);
+client.write_message(BeMessage::ReadyForQuery);
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
@@ -475,7 +195,7 @@ impl NeonOptions {
// proxy options:
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
-const PARAMS_COMPAT: &str = "proxy_params_compat";
+pub const PARAMS_COMPAT: &str = "proxy_params_compat";
// cplane options:

View File

@@ -14,6 +14,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
use super::*;
+use crate::config::TlsConfig;
+use crate::context::RequestContext;
+use crate::pglb::handshake::{HandshakeData, handshake};
enum Intercept {
None,

View File

@@ -3,6 +3,7 @@
mod mitm;
+use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail};
@@ -10,26 +11,31 @@ use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
-use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
-use tokio::io::DuplexStream;
+use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
use tracing_test::traced_test;
use super::retry::CouldRetry;
-use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
-use crate::config::{ComputeConfig, RetryConfig};
+use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
+use crate::context::RequestContext;
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
-use crate::error::ErrorKind;
+use crate::error::{ErrorKind, ReportableError};
-use crate::proxy::connect_compute::ConnectMechanism;
+use crate::pglb::ERR_INSECURE_CONNECTION;
+use crate::pglb::handshake::{HandshakeData, handshake};
+use crate::pqproto::BeMessage;
+use crate::proxy::NeonOptions;
+use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
+use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after};
+use crate::stream::{PqStream, Stream};
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
-use crate::{sasl, scram};
+use crate::{auth, compute, sasl, scram};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(

View File

@@ -17,7 +17,8 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
-use crate::proxy::{ClientMode, ErrorSource, handle_client};
+use crate::pglb::{ClientMode, handle_connection};
+use crate::proxy::ErrorSource;
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -142,7 +143,7 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
-let res = Box::pin(handle_client(
+let res = Box::pin(handle_connection(
config,
auth_backend,
&ctx,