Compare commits

...

16 Commits

Author SHA1 Message Date
Conrad Ludgate  631139ceeb  turns out the boxing isn't necessary, we just needed to massage the stack usage properly  2025-05-30 08:47:44 +01:00
Conrad Ludgate  fd43058bd7  optimise passthrough calling convention to further reduce memory  2025-05-29 18:35:24 +01:00
Conrad Ludgate  cf07c5b5f9  dont box handle_client anymore and move spawning passthrough into handle_client so we don't need to move a heavy object in return position anymore  2025-05-29 18:20:29 +01:00
Conrad Ludgate  11bb84c38d  save 1000 bytes by removing instrument  2025-05-29 17:56:25 +01:00
Conrad Ludgate  219c72c24c  optimise proxy_pass memory size a little, also boxing requestcontext since it is large  2025-05-29 17:52:26 +01:00
Conrad Ludgate  0633cd6385  small changes to connect compute mechanism/backend handling  2025-05-29 16:21:55 +01:00
Conrad Ludgate  0cdb0c5704  reuse the same tracker token for websockets and http  2025-05-29 16:04:14 +01:00
Conrad Ludgate  eefac5d78b  box the connect to compute task  2025-05-29 15:58:28 +01:00
Conrad Ludgate  7d1c908b1b  box authenticate task  2025-05-29 15:55:17 +01:00
Conrad Ludgate  cfa2813446  remove unnecessary aux field from passthrough  2025-05-29 15:51:57 +01:00
Conrad Ludgate  034bdb1552  move more work inside handshake  2025-05-29 15:50:10 +01:00
Conrad Ludgate  8b1ffa1718  simplify cplane authentication  2025-05-29 15:46:40 +01:00
Conrad Ludgate  2d3ea77953  box the handshake task  2025-05-29 15:39:33 +01:00
Conrad Ludgate  3124729f53  spawn passthrough as a separate task to reduce influence from the handshake task  2025-05-29 15:21:54 +01:00
Conrad Ludgate  6463eb38be  manually handle task tracker tokens  2025-05-29 15:19:03 +01:00
Conrad Ludgate  ae506fd791  proxy: remove unused ip return value  2025-05-29 15:04:40 +01:00
17 changed files with 457 additions and 431 deletions
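Several of the commits above (box the handshake/authenticate/connect-to-compute tasks, then the final "massage the stack usage" change) revolve around how much state an async fn embeds into its caller's future. The following is a standalone sketch of that effect, not code from this diff; the helper names are made up for illustration:

    // Standalone sketch: locals that live across an .await are stored inside the
    // caller's future unless the inner future is boxed onto the heap.
    async fn handshake_like() {
        let buf = [0u8; 16 * 1024]; // lives inside this future's state machine
        tokio::task::yield_now().await;
        let _ = buf.len();
    }

    async fn parent_inline() {
        handshake_like().await; // ~16 KiB of state embedded in parent_inline's future
    }

    async fn parent_boxed() {
        Box::pin(handshake_like()).await; // only a heap pointer embedded here
    }

    #[tokio::main]
    async fn main() {
        println!("inline: {} bytes", std::mem::size_of_val(&parent_inline()));
        println!("boxed:  {} bytes", std::mem::size_of_val(&parent_boxed()));
    }

Boxing trades a heap allocation for a smaller parent future; the last commit's message suggests the same saving can sometimes be had by restructuring the code so the large state is not held across awaits at all.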

View File

@@ -17,9 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
};
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
@@ -137,16 +135,6 @@ impl<'a, T> Backend<'a, T> {
}
}
}
impl<'a, T, E> Backend<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
match self {
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
Self::Local(l) => Ok(Backend::Local(l)),
}
}
}
pub(crate) struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
@@ -284,7 +272,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -301,15 +289,12 @@ async fn auth_quirks(
debug!("fetching authentication info and allowlists");
// check allowed list
let allowed_ips = if config.ip_allowlist_check_enabled {
if config.ip_allowlist_check_enabled {
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
allowed_ips
} else {
Cached::new_uncached(Arc::new(vec![]))
};
}
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
@@ -368,7 +353,7 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Ok(keys) => Ok(keys),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
@@ -420,53 +405,39 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::Local(_) => "local",
}
}
impl ControlPlaneClient {
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
self,
&self,
ctx: &RequestContext,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
user_info: ComputeUserInfoMaybeEndpoint,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<ComputeCredentials> {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let (credentials, ip_allowlist) = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
}
};
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
// TODO: replace with some metric
info!("user successfully authenticated");
res
Ok(credentials)
}
}
@@ -536,6 +507,25 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
}
}
pub struct ControlPlaneWakeCompute<'a> {
pub cplane: &'a ControlPlaneClient,
pub creds: ComputeCredentials,
}
#[async_trait::async_trait]
impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
self.cplane.wake_compute(ctx, &self.creds.info).await
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&self.creds.keys
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::unimplemented, clippy::unwrap_used)]
@@ -552,6 +542,7 @@ mod tests {
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio_util::task::TaskTracker;
use super::jwt::JwkCache;
use super::{AuthRateLimiter, auth_quirks};
@@ -702,7 +693,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -784,7 +775,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -838,7 +829,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -887,7 +878,7 @@ mod tests {
.await
.unwrap();
assert_eq!(creds.0.info.endpoint, "my-endpoint");
assert_eq!(creds.info.endpoint, "my-endpoint");
handle.await.unwrap();
}

View File

@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
pub use backend::{Backend, ControlPlaneWakeCompute};
mod credentials;
pub(crate) use credentials::{

View File

@@ -18,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsConnector;
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, error, info};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
@@ -226,7 +227,8 @@ pub(super) async fn task_main(
let dest_suffix = Arc::clone(&dest_suffix);
let compute_tls_config = compute_tls_config.clone();
connections.spawn(
let tracker = connections.token();
tokio::spawn(
async move {
socket
.set_nodelay(true)
@@ -249,6 +251,7 @@ pub(super) async fn task_main(
compute_tls_config,
tls_server_end_point,
socket,
tracker,
)
.await
}
@@ -274,10 +277,11 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tracker: TaskTrackerToken,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
let msg = stream.read_startup_packet().await?;
use pq_proto::FeStartupPacket::SslRequest;
@@ -291,7 +295,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
let (raw, read_buf, tracker) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
@@ -302,13 +306,16 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok((
Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
},
tracker,
))
}
unexpected => {
info!(
@@ -329,8 +336,10 @@ async fn handle_client(
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
tracker: TaskTrackerToken,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let (mut tls_stream, _tracker) =
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
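The hunks above replace `connections.spawn(...)` with `connections.token()` plus a plain `tokio::spawn`, so the token can travel with the connection (later into `PqStream`) and graceful shutdown still waits for it. A minimal sketch of that pattern, assuming tokio and tokio_util as dependencies:

    use tokio_util::task::TaskTracker;

    #[tokio::main]
    async fn main() {
        let connections = TaskTracker::new();

        // Instead of connections.spawn(fut): take a token and spawn manually, so the
        // token can be handed onwards (e.g. to a follow-up passthrough task).
        let token = connections.token();
        tokio::spawn(async move {
            // ... handle one client connection ...
            drop(token); // releasing the token marks this unit of work as done
        });

        connections.close();
        connections.wait().await; // resolves once every outstanding token is dropped
    }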

View File

@@ -323,7 +323,7 @@ impl CancellationHandler {
}
}
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
pub(crate) fn get_key(self: Arc<Self>) -> Session {
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
@@ -340,7 +340,7 @@ impl CancellationHandler {
Session {
key,
redis_key,
cancellation_handler: Arc::clone(self),
cancellation_handler: self,
}
}

View File

@@ -1,8 +1,9 @@
use std::sync::Arc;
use futures::{FutureExt, TryFutureExt};
use futures::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info};
use crate::auth::backend::ConsoleRedirectBackend;
@@ -14,10 +15,8 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
use crate::proxy::passthrough::passthrough;
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
pub async fn task_main(
config: &'static ProxyConfig,
@@ -35,7 +34,6 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -49,11 +47,11 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
connections.spawn(async move {
let tracker = connections.token();
tokio::spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
@@ -103,99 +101,80 @@ pub async fn task_main(
&config.region,
);
let span = ctx.span();
let mut slot = Some(ctx);
let res = handle_client(
config,
backend,
&ctx,
&mut slot,
cancellation_handler,
socket,
conn_gauge,
cancellations,
tracker,
)
.instrument(ctx.span())
.boxed()
.instrument(span)
.await;
match res {
Err(e) => {
match (slot, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestContext,
ctx_slot: &mut Option<RequestContext>,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let request_gauge = metrics.connection_requests.guard(protocol);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let data = {
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
};
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
let (mut stream, params) = match data {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
HandshakeData::Cancel(cancel_key_data, tracker) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
tokio::spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx_slot.take().expect("context must be set");
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
let _tracker = tracker;
cancellation_handler_clone
.cancel_session(
cancel_key_data,
@@ -205,15 +184,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
.ok();
}
.instrument(cancel_span)
});
return Ok(None);
return Ok(());
}
};
drop(pause);
let ctx = ctx_slot.as_ref().expect("context must be set");
ctx.set_db_options(params.clone());
let (node_info, user_info, _ip_allowlist) = match backend
@@ -228,13 +209,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
TcpMechanism {
user_info,
params_compat: true,
params: &params,
locks: &config.connect_compute_locks,
},
&node_info,
node_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
@@ -252,17 +233,22 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
let (stream, read_buf, tracker) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id: None,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
ctx,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
Ok(())
}
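`handle_client` now receives `ctx_slot: &mut Option<RequestContext>` and takes the context out of the slot only when it hands work to a spawned task; the caller then inspects what is left in the slot to decide who reports the outcome. A simplified sketch of that ownership hand-off, with illustrative types and a thread standing in for a tokio task:

    struct RequestContext {
        id: u32,
    }

    fn handle(slot: &mut Option<RequestContext>) -> Result<(), String> {
        // borrow while the caller still owns the context
        let ctx = slot.as_ref().expect("context must be set");
        println!("handling request {}", ctx.id);

        // take ownership only when the work moves to a background task
        let ctx = slot.take().expect("context must be set");
        std::thread::spawn(move || println!("request {} finished elsewhere", ctx.id));
        Ok(())
    }

    fn main() {
        let mut slot = Some(RequestContext { id: 7 });
        let res = handle(&mut slot);
        match (slot, res) {
            (None, _) => {} // the callee took the context and will report the outcome
            (Some(_ctx), Ok(())) => println!("success"),
            (Some(_ctx), Err(e)) => eprintln!("error: {e}"),
        }
    }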

View File

@@ -38,7 +38,7 @@ pub struct RequestContext(
/// I would typically use a RefCell but that would break the `Send` requirements
/// so we need something with thread-safety. `TryLock` is a cheap alternative
/// that offers similar semantics to a `RefCell` but with synchronisation.
TryLock<RequestContextInner>,
TryLock<Box<RequestContextInner>>,
);
struct RequestContextInner {
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
impl Clone for RequestContext {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = RequestContextInner {
let new = Box::new(RequestContextInner {
conn_info: inner.conn_info.clone(),
session_id: inner.session_id,
protocol: inner.protocol,
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
disconnect_sender: None,
latency_timer: LatencyTimer::noop(inner.protocol),
disconnect_timestamp: inner.disconnect_timestamp,
};
});
Self(TryLock::new(new))
}
@@ -140,7 +140,7 @@ impl RequestContext {
role = tracing::field::Empty,
);
let inner = RequestContextInner {
let inner = Box::new(RequestContextInner {
conn_info,
session_id,
protocol,
@@ -168,7 +168,7 @@ impl RequestContext {
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
disconnect_timestamp: None,
};
});
Self(TryLock::new(inner))
}
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
}
}
pub struct DisconnectLogger(RequestContextInner);
pub struct DisconnectLogger(Box<RequestContextInner>);
impl Drop for DisconnectLogger {
fn drop(&mut self) {
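The change above wraps `RequestContextInner` in a `Box` inside the `TryLock`, so the outer `RequestContext` handle stays pointer-sized wherever it is embedded. A tiny sketch of the size effect, with illustrative types only:

    struct BigInner {
        _payload: [u8; 4096],
    }

    struct FatHandle(BigInner); // ~4 KiB moved on every embed or clone
    struct SlimHandle(Box<BigInner>); // pointer-sized

    fn main() {
        println!("fat:  {} bytes", std::mem::size_of::<FatHandle>());
        println!("slim: {} bytes", std::mem::size_of::<SlimHandle>());
    }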

View File

@@ -53,6 +53,25 @@ pub(crate) trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
impl<T: ConnectMechanism + Sync> ConnectMechanism for &T {
type Connection = T::Connection;
type ConnectError = T::ConnectError;
type Error = T::Error;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
T::connect_once(self, ctx, node_info, config).await
}
fn update_connect_config(&self, conf: &mut compute::ConnCfg) {
T::update_connect_config(self, conf);
}
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
@@ -105,8 +124,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,
mechanism: M,
backend: B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
@@ -116,9 +135,9 @@ where
{
let mut num_retries = 0;
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
node_info.set_keys(user_info.get_keys());
node_info.set_keys(backend.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -159,7 +178,7 @@ where
let old_node_info = invalidate_cache(node_info);
// TODO: increment num_retries?
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);
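The blanket `impl<T: ConnectMechanism + Sync> ConnectMechanism for &T` added above is what lets `connect_to_compute` take the mechanism and backend by value while existing callers can keep passing references. A simplified, synchronous sketch of the pattern with illustrative names:

    trait Mechanism {
        fn connect(&self) -> String;
    }

    // Blanket impl: a shared reference to any Mechanism is itself a Mechanism.
    impl<T: Mechanism + ?Sized> Mechanism for &T {
        fn connect(&self) -> String {
            T::connect(self)
        }
    }

    fn drive<M: Mechanism>(mechanism: M) -> String {
        mechanism.connect()
    }

    struct Tcp;
    impl Mechanism for Tcp {
        fn connect(&self) -> String {
            "tcp".to_owned()
        }
    }

    fn main() {
        let tcp = Tcp;
        assert_eq!(drive(&tcp), "tcp"); // by reference, via the blanket impl
        assert_eq!(drive(tcp), "tcp"); // or by value
    }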

View File

@@ -67,7 +67,6 @@ where
}
}
#[tracing::instrument(skip_all)]
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,

View File

@@ -5,6 +5,7 @@ use pq_proto::{
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
@@ -51,7 +52,7 @@ impl ReportableError for HandshakeError {
pub(crate) enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
Cancel(CancelKeyData, TaskTrackerToken),
}
/// Establish a (most probably, secure) connection with the client.
@@ -62,6 +63,7 @@ pub(crate) enum HandshakeData<S> {
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
stream: S,
tracker: TaskTrackerToken,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
@@ -71,7 +73,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
loop {
let msg = stream.read_startup_packet().await?;
match msg {
@@ -157,15 +159,13 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
stream.framed = Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
};
}
}
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));
break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
}
}
}

View File

@@ -10,26 +10,27 @@ pub(crate) mod wake_compute;
use std::sync::Arc;
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use futures::{FutureExt, TryFutureExt};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use passthrough::passthrough;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use smol_str::{SmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info, warn};
use self::connect_compute::{TcpMechanism, connect_to_compute};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -70,7 +71,6 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -84,12 +84,12 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let tracker = connections.token();
tokio::spawn(async move {
let (socket, conn_info) = match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
@@ -138,60 +138,41 @@ pub async fn task_main(
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span();
let mut ctx = Some(ctx);
let res = handle_client(
config,
auth_backend,
&ctx,
&mut ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
tracker,
)
.instrument(ctx.span())
.boxed()
.instrument(span)
.await;
match res {
Err(e) => {
match (ctx, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
@@ -258,46 +239,79 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
ctx_slot: &mut Option<RequestContext>,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let cplane = match auth_backend {
auth::Backend::ControlPlane(cplane, ()) => &**cplane,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let request_gauge = metrics.connection_requests.guard(protocol);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let handshake_result: Result<_, ClientRequestError> = async {
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let data = {
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
tokio::time::timeout(
config.handshake_timeout,
handshake(
ctx,
stream,
tracker,
mode.handshake_tls(tls),
record_handshake_error,
),
)
.await??
};
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
match data {
HandshakeData::Startup(mut stream, params) => {
ctx.set_db_options(params.clone());
let host = mode.hostname(stream.get_ref());
let cn = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, host, cn);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let session = cancellation_handler.get_key();
Ok(Some((stream, params, session, user_info)))
}
HandshakeData::Cancel(cancel_key_data, tracker) => {
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
// spawn a task to cancel the session, but don't wait for it
tokio::spawn(async move {
// ensure the proxy doesn't shutdown until we complete this task.
let _tracker = tracker;
cancellation_handler
.cancel_session(
cancel_key_data,
ctx,
@@ -305,111 +319,108 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.instrument(cancel_span)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
});
return Ok(None);
Ok(None)
}
}
}
.await;
let Some((mut stream, params, session, user_info)) = handshake_result? else {
return Ok(());
};
drop(pause);
let ctx = ctx_slot.as_ref().expect("context must be set");
ctx.set_db_options(params.clone());
let auth_result: Result<_, ClientRequestError> = async {
let user = user_info.user.clone();
let hostname = mode.hostname(stream.get_ref());
match cplane
.authenticate(
ctx,
&mut stream,
user_info,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{
Ok(auth_result) => Ok(auth_result),
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?
}
}
}
.await;
let common_names = tls.map(|tls| &tls.common_names);
let compute_creds = auth_result?;
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let connect_result: Result<_, ClientRequestError> = async {
let compute_user_info = compute_creds.info.clone();
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let user = user_info.get_user().to_owned();
let (user_info, _ip_allowlist) = match user_info
.authenticate(
let mut node = connect_to_compute(
ctx,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
TcpMechanism {
user_info: compute_user_info,
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
auth::ControlPlaneWakeCompute {
cplane,
creds: compute_creds,
},
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
return stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?;
}
};
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf, tracker) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let mut node = connect_to_compute(
Ok((node, stream, tracker))
}
.await;
let (node, stream, tracker) = connect_result?;
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
Ok(())
}
/// Finish client connection initialization: confirm auth success, send params, etc.

View File

@@ -1,5 +1,6 @@
use smol_str::SmolStr;
use smol_str::{SmolStr, ToSmolStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use utils::measured_stream::MeasuredStream;
@@ -7,13 +8,14 @@ use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::protocol2::ConnectionInfoExtra;
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
@@ -61,41 +63,53 @@ pub(crate) async fn proxy_pass(
Ok(())
}
pub(crate) struct ProxyPassthrough<S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) aux: MetricsAuxInfo,
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
let res = proxy_pass(
self.client,
self.compute.stream,
self.aux,
self.private_link_id,
)
.await;
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
res
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
ctx: RequestContext,
compute_config: &'static ComputeConfig,
client: Stream<S>,
compute: PostgresConnection,
cancel: cancellation::Session,
_req: NumConnectionRequestsGuard<'static>,
_conn: NumClientConnectionsGuard<'static>,
_tracker: TaskTrackerToken,
) {
let session_id = ctx.session_id();
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
let _disconnect = ctx.log_connect();
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
match res {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
tracing::warn!(
session_id = ?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
tracing::error!(
session_id = ?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
if let Err(err) = compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
}
// we don't need a result. If the queue is full, we just log the error
drop(cancel.remove_cancel_key());
}

View File

@@ -38,6 +38,7 @@ async fn proxy_mitm(
let (end_client, startup) = match handshake(
&RequestContext::test(),
client1,
TaskTracker::new().token(),
Some(&server_config1),
false,
)
@@ -45,7 +46,7 @@ async fn proxy_mitm(
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);

View File

@@ -15,6 +15,7 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio_util::task::TaskTracker;
use tracing_test::traced_test;
use super::connect_compute::ConnectMechanism;
@@ -178,10 +179,12 @@ async fn dummy_proxy(
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
let t = TaskTracker::new().token();
let mut stream =
match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;
@@ -622,7 +625,7 @@ async fn connect_to_compute_success() {
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -636,7 +639,7 @@ async fn connect_to_compute_retry() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -651,7 +654,7 @@ async fn connect_to_compute_non_retry_1() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -666,7 +669,7 @@ async fn connect_to_compute_non_retry_2() {
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -691,7 +694,7 @@ async fn connect_to_compute_non_retry_3() {
connect_to_compute(
&ctx,
&mechanism,
&user_info,
user_info,
wake_compute_retry_config,
&config,
)
@@ -709,7 +712,7 @@ async fn wake_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -724,7 +727,7 @@ async fn wake_non_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -743,7 +746,7 @@ async fn fail_but_wake_invalidates_cache() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
.await
.unwrap();
@@ -764,7 +767,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
.await
.unwrap();
@@ -785,7 +788,7 @@ async fn retry_but_wake_invalidates_cache() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
@@ -808,7 +811,7 @@ async fn retry_no_wake_skips_invalidation() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -224,13 +224,13 @@ impl PoolingBackend {
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
@@ -268,13 +268,13 @@ impl PoolingBackend {
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)

View File

@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, info, warn};
use crate::cancellation::CancellationHandler;
@@ -124,7 +124,6 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -150,11 +149,11 @@ pub async fn task_main(
let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellations = cancellations.clone();
connections.spawn(
let tracker = connections.token();
tokio::spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
@@ -181,8 +180,7 @@ pub async fn task_main(
Box::pin(connection_handler(
config,
backend,
connections2,
cancellations,
tracker,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
@@ -305,8 +303,7 @@ async fn connection_startup(
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellations: TaskTracker,
tracker: TaskTrackerToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
@@ -347,19 +344,17 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let cancellations = cancellations.clone();
let handler = connections.spawn(
let handler = tokio::spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
tracker.clone(),
cancellation_handler.clone(),
session_id,
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -400,14 +395,13 @@ async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
tracker: TaskTrackerToken,
cancellation_handler: Arc<CancellationHandler>,
session_id: uuid::Uuid,
conn_info: ConnectionInfo,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -441,10 +435,17 @@ async fn request_handler(
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
let cancellations = cancellations.clone();
ws_connections.spawn(
tokio::spawn(
async move {
if let Err(e) = websocket::serve_websocket(
let websocket = match websocket.await {
Err(e) => {
warn!("could not upgrade websocket connection: {e:#}");
return;
}
Ok(websocket) => websocket,
};
websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
@@ -452,12 +453,9 @@ async fn request_handler(
cancellation_handler,
endpoint_rate_limiter,
host,
cancellations,
tracker,
)
.await
{
warn!("error in websocket connection: {e:#}");
}
.await;
}
.instrument(span),
);

View File

@@ -2,14 +2,14 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::OnUpgrade;
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::warn;
use crate::cancellation::CancellationHandler;
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::proxy::{ClientMode, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -128,13 +128,12 @@ pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ctx: RequestContext,
websocket: OnUpgrade,
websocket: Upgraded,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
tracker: TaskTrackerToken,
) {
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
@@ -142,36 +141,28 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let res = Box::pin(handle_client(
let mut ctx_slot = Some(ctx);
let res = handle_client(
config,
auth_backend,
&ctx,
&mut ctx_slot,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
cancellations,
))
tracker,
)
.await;
match res {
Err(e) => {
match (ctx_slot, res) {
(None, _) => {}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
Err(e.into())
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
(Some(ctx), Ok(())) => {
ctx.set_success();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}
}
}

View File

@@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use crate::control_plane::messages::ColdStartInfo;
@@ -24,19 +25,22 @@ use crate::tls::TlsServerEndPoint;
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
pub(crate) framed: Framed<S>,
pub(crate) tracker: TaskTrackerToken,
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
Self {
framed: Framed::new(stream),
tracker,
}
}
/// Extract the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut) {
self.framed.into_inner()
pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
let (stream, read) = self.framed.into_inner();
(stream, read, self.tracker)
}
/// Get a shared reference to the underlying stream.