diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index e7ce867111..a211d0b65a 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -547,6 +547,7 @@ mod tests { use postgres_protocol::message::backend::Message as PgMessage; use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use tokio_util::task::TaskTracker; use super::jwt::JwkCache; use super::{AuthRateLimiter, auth_quirks}; @@ -697,7 +698,7 @@ mod tests { #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token()); let ctx = RequestContext::test(); let api = Auth { @@ -779,7 +780,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token()); let ctx = RequestContext::test(); let api = Auth { @@ -833,7 +834,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token()); let ctx = RequestContext::test(); let api = Auth { diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..03ea34cc1f 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -18,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::TlsConnector; use tokio_util::sync::CancellationToken; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::{Instrument, error, info}; use utils::project_git_version; use utils::sentry_init::init_sentry; @@ -226,7 +227,8 @@ pub(super) async fn task_main( let dest_suffix = Arc::clone(&dest_suffix); let compute_tls_config = compute_tls_config.clone(); - connections.spawn( + let tracker = connections.token(); + tokio::spawn( async move { socket .set_nodelay(true) @@ -249,6 +251,7 @@ pub(super) async fn task_main( compute_tls_config, tls_server_end_point, socket, + tracker, ) .await } @@ -274,10 +277,11 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod async fn ssl_handshake( ctx: &RequestContext, raw_stream: S, + tracker: TaskTrackerToken, tls_config: Arc, tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { - let mut stream = PqStream::new(Stream::from_raw(raw_stream)); +) -> anyhow::Result<(Stream, TaskTrackerToken)> { + let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker); let msg = stream.read_startup_packet().await?; use pq_proto::FeStartupPacket::SslRequest; @@ -291,7 +295,7 @@ async fn ssl_handshake( // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let (raw, read_buf) = stream.into_inner(); + let (raw, read_buf, tracker) = stream.into_inner(); // TODO: Normally, client doesn't send any data before // server says TLS handshake is ok and read_buf is empty. // However, you could imagine pipelining of postgres @@ -302,13 +306,16 @@ async fn ssl_handshake( bail!("data is sent before server replied with EncryptionResponse"); } - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(( + Stream::Tls { + tls: Box::new( + raw.upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?, + ), + tls_server_end_point, + }, + tracker, + )) } unexpected => { info!( @@ -329,8 +336,10 @@ async fn handle_client( compute_tls_config: Option>, tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, + tracker: TaskTrackerToken, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let (mut tls_stream, _tracker) = + ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e3184e20d1..cd7542093f 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::{Instrument, debug, error, info}; use crate::auth::backend::ConsoleRedirectBackend; @@ -35,7 +36,6 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -49,11 +49,11 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); - let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); - connections.spawn(async move { + let tracker = connections.token(); + tokio::spawn(async move { let (socket, peer_addr) = match read_proxy_protocol(socket).await { Err(e) => { error!("per-client task finished with an error: {e:#}"); @@ -110,7 +110,7 @@ pub async fn task_main( cancellation_handler, socket, conn_gauge, - cancellations, + tracker, ) .instrument(ctx.span()) .boxed() @@ -148,12 +148,10 @@ pub async fn task_main( } connections.close(); - cancellations.close(); drop(listener); // Drain connections connections.wait().await; - cancellations.wait().await; Ok(()) } @@ -166,7 +164,7 @@ pub(crate) async fn handle_client( cancellation_handler: Arc, stream: S, conn_gauge: NumClientConnectionsGuard<'static>, - cancellations: tokio_util::task::task_tracker::TaskTracker, + tracker: TaskTrackerToken, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -182,20 +180,21 @@ pub(crate) async fn handle_client( let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(ctx, stream, tls, record_handshake_error); + let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error); let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) .await?? { HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(cancel_key_data) => { + HandshakeData::Cancel(cancel_key_data, tracker) => { // spawn a task to cancel the session, but don't wait for it - cancellations.spawn({ - let cancellation_handler_clone = Arc::clone(&cancellation_handler); + tokio::spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); let ctx = ctx.clone(); let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); cancel_span.follows_from(tracing::Span::current()); async move { + let _tracker = tracker; cancellation_handler_clone .cancel_session( cancel_key_data, @@ -205,8 +204,10 @@ pub(crate) async fn handle_client( backend.get_api(), ) .await - .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); - }.instrument(cancel_span) + .inspect_err(|e| debug!(error = ?e, "cancel_session failed")) + .ok(); + } + .instrument(cancel_span) }); return Ok(None); @@ -252,7 +253,7 @@ pub(crate) async fn handle_client( // PqStream input buffer. Normally there is none, but our serverless npm // driver in pipeline mode sends startup, password and first query // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); + let (stream, read_buf, tracker) = stream.into_inner(); node.stream.write_all(&read_buf).await?; Ok(Some(ProxyPassthrough { @@ -264,5 +265,6 @@ pub(crate) async fn handle_client( cancel: session, _req: request_gauge, _conn: conn_gauge, + _tracker: tracker, })) } diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 54c02f2c15..79f8d06f61 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -5,6 +5,7 @@ use pq_proto::{ }; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::{debug, info, warn}; use crate::auth::endpoint_sni; @@ -51,7 +52,7 @@ impl ReportableError for HandshakeError { pub(crate) enum HandshakeData { Startup(PqStream>, StartupMessageParams), - Cancel(CancelKeyData), + Cancel(CancelKeyData, TaskTrackerToken), } /// Establish a (most probably, secure) connection with the client. @@ -62,6 +63,7 @@ pub(crate) enum HandshakeData { pub(crate) async fn handshake( ctx: &RequestContext, stream: S, + tracker: TaskTrackerToken, mut tls: Option<&TlsConfig>, record_handshake_error: bool, ) -> Result, HandshakeError> { @@ -71,7 +73,7 @@ pub(crate) async fn handshake( const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); - let mut stream = PqStream::new(Stream::from_raw(stream)); + let mut stream = PqStream::new(Stream::from_raw(stream), tracker); loop { let msg = stream.read_startup_packet().await?; match msg { @@ -157,15 +159,13 @@ pub(crate) async fn handshake( let (_, tls_server_end_point) = tls.cert_resolver.resolve(conn_info.server_name()); - stream = PqStream { - framed: Framed { - stream: Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }, - read_buf, - write_buf, + stream.framed = Framed { + stream: Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, }, + read_buf, + write_buf, }; } } @@ -248,7 +248,7 @@ pub(crate) async fn handshake( } FeStartupPacket::CancelRequest(cancel_key_data) => { info!(session_type = "cancellation", "successful handshake"); - break Ok(HandshakeData::Cancel(cancel_key_data)); + break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker)); } } } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 3423538c92..dab43b9289 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -20,6 +20,7 @@ use smol_str::{SmolStr, ToSmolStr, format_smolstr}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::{Instrument, debug, error, info, warn}; use self::connect_compute::{TcpMechanism, connect_to_compute}; @@ -70,7 +71,6 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -84,12 +84,12 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); - let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); - connections.spawn(async move { + let tracker = connections.token(); + tokio::spawn(async move { let (socket, conn_info) = match read_proxy_protocol(socket).await { Err(e) => { warn!("per-client task finished with an error: {e:#}"); @@ -148,7 +148,7 @@ pub async fn task_main( ClientMode::Tcp, endpoint_rate_limiter2, conn_gauge, - cancellations, + tracker, ) .instrument(ctx.span()) .boxed() @@ -186,12 +186,10 @@ pub async fn task_main( } connections.close(); - cancellations.close(); drop(listener); // Drain connections connections.wait().await; - cancellations.wait().await; Ok(()) } @@ -267,7 +265,7 @@ pub(crate) async fn handle_client( mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, - cancellations: tokio_util::task::task_tracker::TaskTracker, + tracker: TaskTrackerToken, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -283,20 +281,29 @@ pub(crate) async fn handle_client( let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); + let do_handshake = handshake( + ctx, + stream, + tracker, + mode.handshake_tls(tls), + record_handshake_error, + ); let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) .await?? { HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(cancel_key_data) => { + HandshakeData::Cancel(cancel_key_data, tracker) => { // spawn a task to cancel the session, but don't wait for it - cancellations.spawn({ + tokio::spawn({ let cancellation_handler_clone = Arc::clone(&cancellation_handler); let ctx = ctx.clone(); let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); cancel_span.follows_from(tracing::Span::current()); async move { + // ensure the proxy doesn't shutdown until we complete this task. + let _tracker = tracker; + cancellation_handler_clone .cancel_session( cancel_key_data, @@ -306,8 +313,10 @@ pub(crate) async fn handle_client( auth_backend.get_api(), ) .await - .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); - }.instrument(cancel_span) + .inspect_err(|e| debug!(error = ?e, "cancel_session failed")) + .ok(); + } + .instrument(cancel_span) }); return Ok(None); @@ -391,7 +400,7 @@ pub(crate) async fn handle_client( // PqStream input buffer. Normally there is none, but our serverless npm // driver in pipeline mode sends startup, password and first query // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); + let (stream, read_buf, tracker) = stream.into_inner(); node.stream.write_all(&read_buf).await?; let private_link_id = match ctx.extra() { @@ -409,6 +418,7 @@ pub(crate) async fn handle_client( cancel: session, _req: request_gauge, _conn: conn_gauge, + _tracker: tracker, })) } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 8f9bd2de2d..6ece89994a 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,5 +1,6 @@ use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::debug; use utils::measured_stream::MeasuredStream; @@ -71,6 +72,8 @@ pub(crate) struct ProxyPassthrough { pub(crate) _req: NumConnectionRequestsGuard<'static>, pub(crate) _conn: NumClientConnectionsGuard<'static>, + /// ensures proxy stays online while this is set. + pub(crate) _tracker: TaskTrackerToken, } impl ProxyPassthrough { diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..1c4dc5c3db 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -38,6 +38,7 @@ async fn proxy_mitm( let (end_client, startup) = match handshake( &RequestContext::test(), client1, + TaskTracker::new().token(), Some(&server_config1), false, ) @@ -45,7 +46,7 @@ async fn proxy_mitm( .unwrap() { HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(_) => panic!("cancellation not supported"), + HandshakeData::Cancel(_, _) => panic!("cancellation not supported"), }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index be6426a63c..f28982df60 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -15,6 +15,7 @@ use rstest::rstest; use rustls::crypto::ring; use rustls::pki_types; use tokio::io::DuplexStream; +use tokio_util::task::TaskTracker; use tracing_test::traced_test; use super::connect_compute::ConnectMechanism; @@ -178,10 +179,12 @@ async fn dummy_proxy( auth: impl TestAuth + Send, ) -> anyhow::Result<()> { let (client, _) = read_proxy_protocol(client).await?; - let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? { - HandshakeData::Startup(stream, _) => stream, - HandshakeData::Cancel(_) => bail!("cancellation not supported"), - }; + let t = TaskTracker::new().token(); + let mut stream = + match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? { + HandshakeData::Startup(stream, _) => stream, + HandshakeData::Cancel(_, _) => bail!("cancellation not supported"), + }; auth.authenticate(&mut stream).await?; diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 2a7069b1c2..364490e70d 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -124,7 +124,6 @@ pub async fn task_main( let connections = tokio_util::task::task_tracker::TaskTracker::new(); connections.close(); // allows `connections.wait to complete` - let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await { let (conn, peer_addr) = res.context("could not accept TCP stream")?; if let Err(e) = conn.set_nodelay(true) { @@ -153,7 +152,6 @@ pub async fn task_main( let connections2 = connections.clone(); let cancellation_handler = cancellation_handler.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - let cancellations = cancellations.clone(); connections.spawn( async move { let conn_token2 = conn_token.clone(); @@ -182,7 +180,6 @@ pub async fn task_main( config, backend, connections2, - cancellations, cancellation_handler, endpoint_rate_limiter, conn_token, @@ -306,7 +303,6 @@ async fn connection_handler( config: &'static ProxyConfig, backend: Arc, connections: TaskTracker, - cancellations: TaskTracker, cancellation_handler: Arc, endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, @@ -347,7 +343,6 @@ async fn connection_handler( // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. // By spawning the future, we ensure it never gets cancelled until it decides to. - let cancellations = cancellations.clone(); let handler = connections.spawn( request_handler( req, @@ -359,7 +354,6 @@ async fn connection_handler( conn_info2.clone(), http_request_token, endpoint_rate_limiter.clone(), - cancellations, ) .in_current_span() .map_ok_or_else(api_error_into_response, |r| r), @@ -407,7 +401,6 @@ async fn request_handler( // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, - cancellations: TaskTracker, ) -> Result>, ApiError> { let host = request .headers() @@ -441,8 +434,8 @@ async fn request_handler( let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) .map_err(|e| ApiError::BadRequest(e.into()))?; - let cancellations = cancellations.clone(); - ws_connections.spawn( + let tracker = ws_connections.token(); + tokio::spawn( async move { if let Err(e) = websocket::serve_websocket( config, @@ -452,7 +445,7 @@ async fn request_handler( cancellation_handler, endpoint_rate_limiter, host, - cancellations, + tracker, ) .await { diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 8648a94869..d0d207ea40 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -10,6 +10,7 @@ use hyper::upgrade::OnUpgrade; use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::warn; use crate::cancellation::CancellationHandler; @@ -132,7 +133,7 @@ pub(crate) async fn serve_websocket( cancellation_handler: Arc, endpoint_rate_limiter: Arc, hostname: Option, - cancellations: tokio_util::task::task_tracker::TaskTracker, + tracker: TaskTrackerToken, ) -> anyhow::Result<()> { let websocket = websocket.await?; let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket)); @@ -151,7 +152,7 @@ pub(crate) async fn serve_websocket( ClientMode::Websockets { hostname }, endpoint_rate_limiter, conn_gauge, - cancellations, + tracker, )) .await; diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 360550b0ac..053db76713 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::debug; use crate::control_plane::messages::ColdStartInfo; @@ -24,19 +25,22 @@ use crate::tls::TlsServerEndPoint; /// to pass random malformed bytes through the connection). pub struct PqStream { pub(crate) framed: Framed, + pub(crate) tracker: TaskTrackerToken, } impl PqStream { /// Construct a new libpq protocol wrapper. - pub fn new(stream: S) -> Self { + pub fn new(stream: S, tracker: TaskTrackerToken) -> Self { Self { framed: Framed::new(stream), + tracker, } } /// Extract the underlying stream and read buffer. - pub fn into_inner(self) -> (S, BytesMut) { - self.framed.into_inner() + pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) { + let (stream, read) = self.framed.into_inner(); + (stream, read, self.tracker) } /// Get a shared reference to the underlying stream.