diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs
index a5f50cc7c1..849af47cfc 100644
--- a/proxy/src/bin/pg_sni_router.rs
+++ b/proxy/src/bin/pg_sni_router.rs
@@ -5,6 +5,7 @@
 /// the outside. Similar to an ingress controller for HTTPS.
 use std::{net::SocketAddr, sync::Arc};
 
+use futures::future::Either;
 use tokio::net::TcpListener;
 
 use anyhow::{anyhow, bail, ensure, Context};
@@ -109,20 +110,25 @@ async fn main() -> anyhow::Result<()> {
 
     let cancellation_token = CancellationToken::new();
 
-    let main = proxy::flatten_err(tokio::spawn(task_main(
+    let main = tokio::spawn(task_main(
         Arc::new(destination),
         tls_config,
         proxy_listener,
         cancellation_token.clone(),
-    )));
-    let signals_task = proxy::flatten_err(tokio::spawn(proxy::handle_signals(cancellation_token)));
+    ));
+    let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
 
-    tokio::select! {
-        res = main => { res?; },
-        res = signals_task => { res?; },
-    }
+    // the signal task can't ever succeed.
+    // the main task can error, or can succeed on cancellation.
+    // we want to immediately exit on either of these cases
+    let signal = match futures::future::select(signals_task, main).await {
+        Either::Left((res, _)) => proxy::flatten_err(res)?,
+        Either::Right((res, _)) => return proxy::flatten_err(res),
+    };
 
-    Ok(())
+    // maintenance tasks return `Infallible` success values, this is an impossible value
+    // so this match statically ensures that there are no possibilities for that value
+    match signal {}
 }
 
 async fn task_main(
diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index 28e6e25317..fc8bc39742 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -1,3 +1,4 @@
+use futures::future::Either;
 use proxy::auth;
 use proxy::console;
 use proxy::http;
@@ -6,8 +7,10 @@ use proxy::metrics;
 use anyhow::bail;
 use clap::{self, Arg};
 use proxy::config::{self, ProxyConfig};
+use std::pin::pin;
 use std::{borrow::Cow, net::SocketAddr};
 use tokio::net::TcpListener;
+use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
 use tracing::info;
 use tracing::warn;
@@ -43,43 +46,59 @@ async fn main() -> anyhow::Result<()> {
     let proxy_listener = TcpListener::bind(proxy_address).await?;
     let cancellation_token = CancellationToken::new();
 
-    let mut client_tasks = vec![tokio::spawn(proxy::proxy::task_main(
+    // client facing tasks. these will exit on error or on cancellation
+    // cancellation returns Ok(())
+    let mut client_tasks = JoinSet::new();
+    client_tasks.spawn(proxy::proxy::task_main(
         config,
         proxy_listener,
         cancellation_token.clone(),
-    ))];
+    ));
 
     if let Some(wss_address) = args.get_one::<String>("wss") {
         let wss_address: SocketAddr = wss_address.parse()?;
         info!("Starting wss on {wss_address}");
         let wss_listener = TcpListener::bind(wss_address).await?;
 
-        client_tasks.push(tokio::spawn(http::websocket::task_main(
+        client_tasks.spawn(http::websocket::task_main(
             config,
             wss_listener,
             cancellation_token.clone(),
-        )));
+        ));
     }
 
-    let mut tasks = vec![
-        tokio::spawn(proxy::handle_signals(cancellation_token)),
-        tokio::spawn(http::server::task_main(http_listener)),
-        tokio::spawn(console::mgmt::task_main(mgmt_listener)),
-    ];
+    // maintenance tasks. these never return unless there's an error
+    let mut maintenance_tasks = JoinSet::new();
+    maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
+    maintenance_tasks.spawn(http::server::task_main(http_listener));
+    maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
 
     if let Some(metrics_config) = &config.metric_collection {
-        tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
+        maintenance_tasks.spawn(metrics::task_main(metrics_config));
     }
 
-    let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err));
-    let client_tasks =
-        futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
-    tokio::select! {
-        // We are only expecting an error from these forever tasks
-        res = tasks => { res?; },
-        res = client_tasks => { res?; },
-    }
-    Ok(())
+    let maintenance = loop {
+        // get one complete task
+        match futures::future::select(
+            pin!(maintenance_tasks.join_next()),
+            pin!(client_tasks.join_next()),
+        )
+        .await
+        {
+            // exit immediately on maintenance task completion
+            Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
+            // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
+            Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
+            // exit immediately on client task error
+            Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
+            // exit if all our client tasks have shut down gracefully
+            Either::Right((None, _)) => return Ok(()),
+        }
+    };
+
+    // maintenance tasks return Infallible success values, this is an impossible value
+    // so this match statically ensures that there are no possibilities for that value
+    match maintenance {}
 }
 
 /// ProxyConfig is created at proxy startup, and lives forever.
diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs
index 35d1ff59b7..f0e084b679 100644
--- a/proxy/src/console/mgmt.rs
+++ b/proxy/src/console/mgmt.rs
@@ -6,7 +6,7 @@ use anyhow::Context;
 use once_cell::sync::Lazy;
 use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
 use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
-use std::future;
+use std::{convert::Infallible, future};
 use tokio::net::{TcpListener, TcpStream};
 use tracing::{error, info, info_span, Instrument};
 
@@ -31,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
 
 /// Console management API listener task.
 /// It spawns console response handlers needed for the link auth.
-pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
+pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
     scopeguard::defer! {
         info!("mgmt has shut down");
     }
diff --git a/proxy/src/http/server.rs b/proxy/src/http/server.rs
index f35f4f9a62..6186ddde0d 100644
--- a/proxy/src/http/server.rs
+++ b/proxy/src/http/server.rs
@@ -1,6 +1,6 @@
-use anyhow::anyhow;
+use anyhow::{anyhow, bail};
 use hyper::{Body, Request, Response, StatusCode};
-use std::net::TcpListener;
+use std::{convert::Infallible, net::TcpListener};
 use tracing::info;
 
 use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
@@ -12,7 +12,7 @@ fn make_router() -> RouterBuilder<Body, ApiError> {
     endpoint::make_router().get("/v1/status", status_handler)
 }
 
-pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
+pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible> {
     scopeguard::defer! {
{ info!("http has shut down"); } @@ -23,5 +23,5 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> { .serve(service().map_err(|e| anyhow!(e))?) .await?; - Ok(()) + bail!("hyper server without shutdown handling cannot shutdown successfully"); } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 148ee67d90..1e1e216bb7 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -1,5 +1,6 @@ +use std::convert::Infallible; + use anyhow::{bail, Context}; -use futures::{Future, FutureExt}; use tokio::task::JoinError; use tokio_util::sync::CancellationToken; use tracing::warn; @@ -23,7 +24,7 @@ pub mod url; pub mod waiters; /// Handle unix signals appropriately. -pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> { +pub async fn handle_signals(token: CancellationToken) -> anyhow::Result { use tokio::signal::unix::{signal, SignalKind}; let mut hangup = signal(SignalKind::hangup())?; @@ -50,8 +51,6 @@ pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> { } /// Flattens `Result>` into `Result`. -pub async fn flatten_err( - f: impl Future, JoinError>>, -) -> anyhow::Result<()> { - f.map(|r| r.context("join error").and_then(|x| x)).await +pub fn flatten_err(r: Result, JoinError>) -> anyhow::Result { + r.context("join error").and_then(|x| x) } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 00fd7f0405..c4be7e1f08 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -4,7 +4,7 @@ use crate::{config::MetricCollectionConfig, http}; use chrono::{DateTime, Utc}; use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE}; use serde::Serialize; -use std::{collections::HashMap, time::Duration}; +use std::{collections::HashMap, convert::Infallible, time::Duration}; use tracing::{error, info, instrument, trace, warn}; const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client"; @@ -26,7 +26,7 @@ pub struct Ids { pub branch_id: String, } -pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> { +pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result { info!("metrics collector config: {config:?}"); scopeguard::defer! { info!("metrics collector has shut down"); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 12ca9c5187..2204fc62c6 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -162,7 +162,10 @@ pub async fn handle_ws_client( .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_names)) .transpose(); - async { result }.or_else(|e| stream.throw_error(e)).await? + match result { + Ok(creds) => creds, + Err(e) => stream.throw_error(e).await?, + } }; let client = Client::new(stream, creds, ¶ms, session_id, false); @@ -201,7 +204,10 @@ async fn handle_client( .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_names)) .transpose(); - async { result }.or_else(|e| stream.throw_error(e)).await? + match result { + Ok(creds) => creds, + Err(e) => stream.throw_error(e).await?, + } }; let allow_self_signed_compute = config.allow_self_signed_compute; @@ -595,15 +601,13 @@ impl Client<'_, S> { application_name: params.get("application_name"), }; - let auth_result = async { - // `&mut stream` doesn't let us merge those 2 lines. 
-            let res = creds
-                .authenticate(&extra, &mut stream, allow_cleartext)
-                .await;
-
-            async { res }.or_else(|e| stream.throw_error(e)).await
-        }
-        .await?;
+        let auth_result = match creds
+            .authenticate(&extra, &mut stream, allow_cleartext)
+            .await
+        {
+            Ok(auth_result) => auth_result,
+            Err(e) => return stream.throw_error(e).await,
+        };
 
         let AuthSuccess {
             reported_auth_ok,
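A minimal, self-contained sketch of the shutdown pattern this diff introduces: client-facing tasks and maintenance tasks live in separate `JoinSet`s, the two are raced with `futures::future::select`, and a maintenance task can only "succeed" with `Infallible`, which an arm-less `match` then rules out at compile time. The task bodies (`maintenance_listener`, `client_listener`) are invented stand-ins for the real listeners; only `flatten_err` mirrors the helper from the diff.

```rust
use std::convert::Infallible;
use std::pin::pin;

use anyhow::{bail, Context};
use futures::future::Either;
use tokio::task::{JoinError, JoinSet};

/// Flattens `Result<Result<T>>` into `Result<T>`, as in the diff.
fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
    r.context("join error").and_then(|x| x)
}

// Invented stand-in: a maintenance task can only error, never finish successfully.
async fn maintenance_listener() -> anyhow::Result<Infallible> {
    loop {
        tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
    }
}

// Invented stand-in: a client-facing task shuts down cleanly on cancellation.
async fn client_listener() -> anyhow::Result<()> {
    Ok(())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut maintenance_tasks: JoinSet<anyhow::Result<Infallible>> = JoinSet::new();
    maintenance_tasks.spawn(maintenance_listener());

    let mut client_tasks: JoinSet<anyhow::Result<()>> = JoinSet::new();
    client_tasks.spawn(client_listener());

    let maintenance = loop {
        // race "a maintenance task finished" against "a client task finished"
        match futures::future::select(
            pin!(maintenance_tasks.join_next()),
            pin!(client_tasks.join_next()),
        )
        .await
        {
            // a maintenance task finishing is always an error: propagate it
            Either::Left((Some(res), _)) => break flatten_err(res)?,
            Either::Left((None, _)) => bail!("no maintenance tasks running"),
            // a client task failed: propagate; on Ok, keep waiting for the rest
            Either::Right((Some(res), _)) => {
                flatten_err(res)?;
            }
            // every client task has shut down gracefully
            Either::Right((None, _)) => return Ok(()),
        }
    };

    // `Infallible` has no values, so this arm-less match proves the loop can only
    // exit through the error or return paths above.
    match maintenance {}
}
```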