proxy simplify cancellation (#5916)

## Problem

The cancellation code was confusing and error prone (as seen before in
our memory leaks).

## Summary of changes

* Use the new `TaskTracker` primitve instead of JoinSet to gracefully
wait for tasks to shutdown.
* Updated libs/utils/completion to use `TaskTracker`
* Remove `tokio::select` in favour of `futures::future::select` in a
specialised `run_until_cancelled()` helper function
This commit is contained in:
Conrad Ludgate
2023-12-08 16:21:17 +00:00
committed by GitHub
parent f5b9af6ac7
commit e1a564ace2
9 changed files with 130 additions and 126 deletions

View File

@@ -8,6 +8,7 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::proxy::run_until_cancelled;
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
@@ -20,7 +21,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
use tracing::{error, info, warn, Instrument};
use tracing::{error, info, Instrument};
project_git_version!(GIT_VERSION);
@@ -151,63 +152,39 @@ async fn task_main(
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let mut connections = tokio::task::JoinSet::new();
let connections = tokio_util::task::task_tracker::TaskTracker::new();
loop {
tokio::select! {
accept_result = listener.accept() => {
let (socket, peer_addr) = accept_result?;
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let session_id = uuid::Uuid::new_v4();
let tls_config = Arc::clone(&tls_config);
let dest_suffix = Arc::clone(&dest_suffix);
let session_id = uuid::Uuid::new_v4();
let tls_config = Arc::clone(&tls_config);
let dest_suffix = Arc::clone(&dest_suffix);
connections.spawn(
async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
connections.spawn(
async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
.instrument(tracing::info_span!("handle_client", ?session_id))
);
info!(%peer_addr, "serving");
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
}
// Don't modify this unless you read https://docs.rs/tokio/latest/tokio/macro.select.html carefully.
// If this future completes and the pattern doesn't match, this branch is disabled for this call to `select!`.
// This only counts for this loop and it will be enabled again on next `select!`.
//
// Prior code had this as `Some(Err(e))` which _looks_ equivalent to the current setup, but it's not.
// When `connections.join_next()` returned `Some(Ok(()))` (which we expect), it would disable the join_next and it would
// not get called again, even if there are more connections to remove.
Some(res) = connections.join_next() => {
if let Err(e) = res {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
}
_ = cancellation_token.cancelled() => {
drop(listener);
break;
}
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
.instrument(tracing::info_span!("handle_client", ?session_id)),
);
}
// Drain connections
info!("waiting for all client connections to finish");
while let Some(res) = connections.join_next().await {
if let Err(e) = res {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
}
connections.close();
drop(listener);
connections.wait().await;
info!("all client connections have finished");
Ok(())
}

View File

@@ -277,6 +277,21 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match futures::future::select(
std::pin::pin!(f),
std::pin::pin!(cancellation_token.cancelled()),
)
.await
{
futures::future::Either::Left((f, _)) => Some(f),
futures::future::Either::Right(((), _)) => None,
}
}
pub async fn task_main(
config: &'static ProxyConfig,
listener: tokio::net::TcpListener,
@@ -290,71 +305,62 @@ pub async fn task_main(
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let mut connections = tokio::task::JoinSet::new();
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancel_map = Arc::new(CancelMap::default());
loop {
tokio::select! {
accept_result = listener.accept() => {
let (socket, peer_addr) = accept_result?;
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
connections.spawn(
async move {
info!("accepted postgres client connection");
let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
connections.spawn(
async move {
info!("accepted postgres client connection");
let mut socket = WithClientIp::new(socket);
let mut peer_addr = peer_addr;
if let Some(ip) = socket.wait_for_addr().await? {
peer_addr = ip;
tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
} else if config.require_client_ip {
bail!("missing required client IP");
}
socket
.inner
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr.ip()).await
}
.instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
error!(?session_id, "per-client task finished with an error: {e:#}");
}),
);
}
// Don't modify this unless you read https://docs.rs/tokio/latest/tokio/macro.select.html carefully.
// If this future completes and the pattern doesn't match, this branch is disabled for this call to `select!`.
// This only counts for this loop and it will be enabled again on next `select!`.
//
// Prior code had this as `Some(Err(e))` which _looks_ equivalent to the current setup, but it's not.
// When `connections.join_next()` returned `Some(Ok(()))` (which we expect), it would disable the join_next and it would
// not get called again, even if there are more connections to remove.
Some(res) = connections.join_next() => {
if let Err(e) = res {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
let mut socket = WithClientIp::new(socket);
let mut peer_addr = peer_addr;
if let Some(ip) = socket.wait_for_addr().await? {
peer_addr = ip;
tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
} else if config.require_client_ip {
bail!("missing required client IP");
}
socket
.inner
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(
config,
&cancel_map,
session_id,
socket,
ClientMode::Tcp,
peer_addr.ip(),
)
.await
}
_ = cancellation_token.cancelled() => {
drop(listener);
break;
}
}
.instrument(info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty
))
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
error!(?session_id, "per-client task finished with an error: {e:#}");
}),
);
}
connections.close();
drop(listener);
// Drain connections
while let Some(res) = connections.join_next().await {
if let Err(e) = res {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
}
connections.wait().await;
Ok(())
}

View File

@@ -10,6 +10,7 @@ use anyhow::bail;
use hyper::StatusCode;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
@@ -70,6 +71,9 @@ pub async fn task_main(
incoming: addr_incoming,
};
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}");
@@ -86,6 +90,7 @@ pub async fn task_main(
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let ws_connections = ws_connections.clone();
async move {
let peer_addr = match client_addr {
@@ -97,6 +102,7 @@ pub async fn task_main(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
let ws_connections = ws_connections.clone();
async move {
let cancel_map = Arc::new(CancelMap::default());
@@ -106,6 +112,7 @@ pub async fn task_main(
req,
config,
conn_pool,
ws_connections,
cancel_map,
session_id,
sni_name,
@@ -129,6 +136,9 @@ pub async fn task_main(
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
// await websocket connections
ws_connections.wait().await;
Ok(())
}
@@ -170,10 +180,12 @@ where
}
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
conn_pool: Arc<conn_pool::GlobalConnPool>,
ws_connections: TaskTracker,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
sni_hostname: Option<String>,
@@ -193,7 +205,7 @@ async fn request_handler(
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
.map_err(|e| ApiError::BadRequest(e.into()))?;
tokio::spawn(
ws_connections.spawn(
async move {
if let Err(e) = websocket::serve_websocket(
websocket,