proxy: refactor some error handling and shutdowns (#4684)

## Problem

It took me a while to understand the purpose of all the tasks spawned in
the main functions.

## Summary of changes

Utilising the type system and fewer macros, plus many more comments,
document the shutdown procedure of each task in detail.
This commit is contained in:
Conrad Ludgate
2023-07-13 11:03:37 +01:00
committed by GitHub
parent 444d6e337f
commit 0626e0bfd3
7 changed files with 80 additions and 52 deletions

View File

@@ -5,6 +5,7 @@
/// the outside. Similar to an ingress controller for HTTPS.
use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
@@ -109,20 +110,25 @@ async fn main() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();
let main = proxy::flatten_err(tokio::spawn(task_main(
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
proxy_listener,
cancellation_token.clone(),
)));
let signals_task = proxy::flatten_err(tokio::spawn(proxy::handle_signals(cancellation_token)));
));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
tokio::select! {
res = main => { res?; },
res = signals_task => { res?; },
}
// the signal task can't ever succeed.
// the main task can error, or can succeed on cancellation.
// we want to immediately exit on either of these cases
let signal = match futures::future::select(signals_task, main).await {
Either::Left((res, _)) => proxy::flatten_err(res)?,
Either::Right((res, _)) => return proxy::flatten_err(res),
};
Ok(())
// maintenance tasks return `Infallible` success values, this is an impossible value
// so this match statically ensures that there are no possibilities for that value
match signal {}
}
async fn task_main(

View File

@@ -1,3 +1,4 @@
use futures::future::Either;
use proxy::auth;
use proxy::console;
use proxy::http;
@@ -6,8 +7,10 @@ use proxy::metrics;
use anyhow::bail;
use clap::{self, Arg};
use proxy::config::{self, ProxyConfig};
use std::pin::pin;
use std::{borrow::Cow, net::SocketAddr};
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
@@ -43,43 +46,59 @@ async fn main() -> anyhow::Result<()> {
let proxy_listener = TcpListener::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
let mut client_tasks = vec![tokio::spawn(proxy::proxy::task_main(
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
))];
));
if let Some(wss_address) = args.get_one::<String>("wss") {
let wss_address: SocketAddr = wss_address.parse()?;
info!("Starting wss on {wss_address}");
let wss_listener = TcpListener::bind(wss_address).await?;
client_tasks.push(tokio::spawn(http::websocket::task_main(
client_tasks.spawn(http::websocket::task_main(
config,
wss_listener,
cancellation_token.clone(),
)));
));
}
let mut tasks = vec![
tokio::spawn(proxy::handle_signals(cancellation_token)),
tokio::spawn(http::server::task_main(http_listener)),
tokio::spawn(console::mgmt::task_main(mgmt_listener)),
];
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
maintenance_tasks.spawn(http::server::task_main(http_listener));
maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
if let Some(metrics_config) = &config.metric_collection {
tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
maintenance_tasks.spawn(metrics::task_main(metrics_config));
}
let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err));
let client_tasks =
futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
tokio::select! {
// We are only expecting an error from these forever tasks
res = tasks => { res?; },
res = client_tasks => { res?; },
}
Ok(())
let maintenance = loop {
// get one complete task
match futures::future::select(
pin!(maintenance_tasks.join_next()),
pin!(client_tasks.join_next()),
)
.await
{
// exit immediately on maintenance task completion
Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
// exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
// exit immediately on client task error
Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
// exit if all our client tasks have shut down gracefully
Either::Right((None, _)) => return Ok(()),
}
};
// maintenance tasks return Infallible success values, this is an impossible value
// so this match statically ensures that there are no possibilities for that value
match maintenance {}
}
/// ProxyConfig is created at proxy startup, and lives forever.

View File

@@ -6,7 +6,7 @@ use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::future;
use std::{convert::Infallible, future};
use tokio::net::{TcpListener, TcpStream};
use tracing::{error, info, info_span, Instrument};
@@ -31,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
/// Console management API listener task.
/// It spawns console response handlers needed for the link auth.
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
scopeguard::defer! {
info!("mgmt has shut down");
}

View File

@@ -1,6 +1,6 @@
use anyhow::anyhow;
use anyhow::{anyhow, bail};
use hyper::{Body, Request, Response, StatusCode};
use std::net::TcpListener;
use std::{convert::Infallible, net::TcpListener};
use tracing::info;
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
@@ -12,7 +12,7 @@ fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
endpoint::make_router().get("/v1/status", status_handler)
}
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible> {
scopeguard::defer! {
info!("http has shut down");
}
@@ -23,5 +23,5 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
.serve(service().map_err(|e| anyhow!(e))?)
.await?;
Ok(())
bail!("hyper server without shutdown handling cannot shutdown successfully");
}

View File

@@ -1,5 +1,6 @@
use std::convert::Infallible;
use anyhow::{bail, Context};
use futures::{Future, FutureExt};
use tokio::task::JoinError;
use tokio_util::sync::CancellationToken;
use tracing::warn;
@@ -23,7 +24,7 @@ pub mod url;
pub mod waiters;
/// Handle unix signals appropriately.
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
use tokio::signal::unix::{signal, SignalKind};
let mut hangup = signal(SignalKind::hangup())?;
@@ -50,8 +51,6 @@ pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
}
/// Flattens `Result<Result<T>>` into `Result<T>`.
pub async fn flatten_err(
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
) -> anyhow::Result<()> {
f.map(|r| r.context("join error").and_then(|x| x)).await
pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
r.context("join error").and_then(|x| x)
}

View File

@@ -4,7 +4,7 @@ use crate::{config::MetricCollectionConfig, http};
use chrono::{DateTime, Utc};
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
use serde::Serialize;
use std::{collections::HashMap, time::Duration};
use std::{collections::HashMap, convert::Infallible, time::Duration};
use tracing::{error, info, instrument, trace, warn};
const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
@@ -26,7 +26,7 @@ pub struct Ids {
pub branch_id: String,
}
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> {
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<Infallible> {
info!("metrics collector config: {config:?}");
scopeguard::defer! {
info!("metrics collector has shut down");

View File

@@ -162,7 +162,10 @@ pub async fn handle_ws_client(
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_names))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
match result {
Ok(creds) => creds,
Err(e) => stream.throw_error(e).await?,
}
};
let client = Client::new(stream, creds, &params, session_id, false);
@@ -201,7 +204,10 @@ async fn handle_client(
.map(|_| auth::ClientCredentials::parse(&params, sni, common_names))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
match result {
Ok(creds) => creds,
Err(e) => stream.throw_error(e).await?,
}
};
let allow_self_signed_compute = config.allow_self_signed_compute;
@@ -595,15 +601,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
application_name: params.get("application_name"),
};
let auth_result = async {
// `&mut stream` doesn't let us merge those 2 lines.
let res = creds
.authenticate(&extra, &mut stream, allow_cleartext)
.await;
async { res }.or_else(|e| stream.throw_error(e)).await
}
.await?;
let auth_result = match creds
.authenticate(&extra, &mut stream, allow_cleartext)
.await
{
Ok(auth_result) => auth_result,
Err(e) => return stream.throw_error(e).await,
};
let AuthSuccess {
reported_auth_ok,