mirror of https://github.com/neondatabase/neon.git
synced 2026-01-11 07:22:55 +00:00
proxy: refactor some error handling and shutdowns (#4684)
## Problem

It took me a while to understand the purpose of all the tasks spawned in the main functions.

## Summary of changes

Utilising the type system and fewer macros, plus many more comments, document the shutdown procedure of each task in detail.
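The core type-system trick in this refactor is giving tasks that should never complete successfully the return type `anyhow::Result<Infallible>`, then consuming the impossible success value with an empty `match`. A minimal self-contained sketch of that pattern (the task body, names, and error message are illustrative, not code from this PR; assumes the `tokio` and `anyhow` crates):

```rust
use std::convert::Infallible;

// A maintenance-style task: on success it would run forever, so its
// success type is Infallible and the only way it returns is with Err.
async fn maintenance_task() -> anyhow::Result<Infallible> {
    loop {
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        // pretend the periodic work hit a fatal problem
        anyhow::bail!("maintenance task failed");
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let handle = tokio::spawn(maintenance_task());
    // unwrap the JoinError layer, then the task's own error
    let never: Infallible = handle.await.map_err(anyhow::Error::new)??;
    // Infallible has no values, so this empty match proves to the
    // compiler that this point is unreachable.
    match never {}
}
```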
@@ -5,6 +5,7 @@
 /// the outside. Similar to an ingress controller for HTTPS.
 use std::{net::SocketAddr, sync::Arc};
 
+use futures::future::Either;
 use tokio::net::TcpListener;
 
 use anyhow::{anyhow, bail, ensure, Context};
@@ -109,20 +110,25 @@ async fn main() -> anyhow::Result<()> {
 
     let cancellation_token = CancellationToken::new();
 
-    let main = proxy::flatten_err(tokio::spawn(task_main(
+    let main = tokio::spawn(task_main(
         Arc::new(destination),
         tls_config,
         proxy_listener,
         cancellation_token.clone(),
-    )));
-    let signals_task = proxy::flatten_err(tokio::spawn(proxy::handle_signals(cancellation_token)));
+    ));
+    let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
 
-    tokio::select! {
-        res = main => { res?; },
-        res = signals_task => { res?; },
-    }
+    // the signal task cant ever succeed.
+    // the main task can error, or can succeed on cancellation.
+    // we want to immediately exit on either of these cases
+    let signal = match futures::future::select(signals_task, main).await {
+        Either::Left((res, _)) => proxy::flatten_err(res)?,
+        Either::Right((res, _)) => return proxy::flatten_err(res),
+    };
 
-    Ok(())
+    // maintenance tasks return `Infallible` success values, this is an impossible value
+    // so this match statically ensures that there are no possibilities for that value
+    match signal {}
 }
 
 async fn task_main(
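For context on why "the main task can error, or can succeed on cancellation": client-facing tasks watch the shared `CancellationToken` and treat cancellation as a clean `Ok(())` exit, while real errors propagate. A hedged sketch of that shape with an illustrative accept loop (not this PR's actual `task_main` body; assumes `tokio`, `tokio-util`, `tracing`, and `anyhow`):

```rust
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;

// Sketch of a client-facing task: cancellation (set by the signal
// handler) is a successful exit, while real I/O errors return Err.
async fn accept_loop(listener: TcpListener, token: CancellationToken) -> anyhow::Result<()> {
    loop {
        tokio::select! {
            // graceful shutdown requested
            _ = token.cancelled() => return Ok(()),
            res = listener.accept() => {
                let (_socket, peer) = res?; // a real error exits the task
                tracing::info!("accepted connection from {}", peer);
                // ... hand _socket off to a connection handler here ...
            }
        }
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let token = CancellationToken::new();
    // cancel immediately so the example exits right away
    token.cancel();
    accept_loop(listener, token).await
}
```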
@@ -1,3 +1,4 @@
+use futures::future::Either;
 use proxy::auth;
 use proxy::console;
 use proxy::http;
@@ -6,8 +7,10 @@ use proxy::metrics;
+use anyhow::bail;
 use clap::{self, Arg};
 use proxy::config::{self, ProxyConfig};
+use std::pin::pin;
 use std::{borrow::Cow, net::SocketAddr};
 use tokio::net::TcpListener;
+use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
 use tracing::info;
 use tracing::warn;
@@ -43,43 +46,59 @@ async fn main() -> anyhow::Result<()> {
     let proxy_listener = TcpListener::bind(proxy_address).await?;
     let cancellation_token = CancellationToken::new();
 
-    let mut client_tasks = vec![tokio::spawn(proxy::proxy::task_main(
+    // client facing tasks. these will exit on error or on cancellation
+    // cancellation returns Ok(())
+    let mut client_tasks = JoinSet::new();
+    client_tasks.spawn(proxy::proxy::task_main(
         config,
         proxy_listener,
         cancellation_token.clone(),
-    ))];
+    ));
 
     if let Some(wss_address) = args.get_one::<String>("wss") {
         let wss_address: SocketAddr = wss_address.parse()?;
         info!("Starting wss on {wss_address}");
         let wss_listener = TcpListener::bind(wss_address).await?;
 
-        client_tasks.push(tokio::spawn(http::websocket::task_main(
+        client_tasks.spawn(http::websocket::task_main(
             config,
             wss_listener,
             cancellation_token.clone(),
-        )));
+        ));
     }
 
-    let mut tasks = vec![
-        tokio::spawn(proxy::handle_signals(cancellation_token)),
-        tokio::spawn(http::server::task_main(http_listener)),
-        tokio::spawn(console::mgmt::task_main(mgmt_listener)),
-    ];
+    // maintenance tasks. these never return unless there's an error
+    let mut maintenance_tasks = JoinSet::new();
+    maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
+    maintenance_tasks.spawn(http::server::task_main(http_listener));
+    maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
 
     if let Some(metrics_config) = &config.metric_collection {
-        tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
+        maintenance_tasks.spawn(metrics::task_main(metrics_config));
    }
 
-    let tasks = futures::future::try_join_all(tasks.into_iter().map(proxy::flatten_err));
-    let client_tasks =
-        futures::future::try_join_all(client_tasks.into_iter().map(proxy::flatten_err));
-    tokio::select! {
-        // We are only expecting an error from these forever tasks
-        res = tasks => { res?; },
-        res = client_tasks => { res?; },
-    }
-    Ok(())
+    let maintenance = loop {
+        // get one complete task
+        match futures::future::select(
+            pin!(maintenance_tasks.join_next()),
+            pin!(client_tasks.join_next()),
+        )
+        .await
+        {
+            // exit immediately on maintenance task completion
+            Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
+            // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
+            Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
+            // exit immediately on client task error
+            Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
+            // exit if all our client tasks have shutdown gracefully
+            Either::Right((None, _)) => return Ok(()),
+        }
+    };
 
+    // maintenance tasks return Infallible success values, this is an impossible value
+    // so this match statically ensures that there are no possibilities for that value
+    match maintenance {}
 }
 
 /// ProxyConfig is created at proxy startup, and lives forever.
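A standalone sketch of the `JoinSet` race above, runnable on its own (the spawned task bodies are illustrative stand-ins; the select arms mirror the loop in the diff; assumes `tokio`, `futures`, and `anyhow`):

```rust
use std::pin::pin;

use futures::future::Either;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut client_tasks = JoinSet::new();
    client_tasks.spawn(async { anyhow::Ok(()) }); // finishes cleanly

    let mut maintenance_tasks = JoinSet::new();
    maintenance_tasks.spawn(async {
        futures::future::pending::<anyhow::Result<()>>().await // never completes
    });

    loop {
        // race the next completion from either set; join_next() yields
        // None once its set is empty
        match futures::future::select(
            pin!(maintenance_tasks.join_next()),
            pin!(client_tasks.join_next()),
        )
        .await
        {
            // a maintenance task finishing (even with Ok) ends the program
            Either::Left((Some(res), _)) => return res.map_err(anyhow::Error::new)?,
            Either::Left((None, _)) => anyhow::bail!("no maintenance tasks running"),
            // a client task error is fatal; a clean exit just shrinks the set
            Either::Right((Some(res), _)) => {
                res.map_err(anyhow::Error::new)??;
            }
            // all client tasks have shut down gracefully
            Either::Right((None, _)) => return Ok(()),
        }
    }
}
```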
@@ -6,7 +6,7 @@ use anyhow::Context;
 use once_cell::sync::Lazy;
 use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
 use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
-use std::future;
+use std::{convert::Infallible, future};
 use tokio::net::{TcpListener, TcpStream};
 use tracing::{error, info, info_span, Instrument};
 
@@ -31,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
 
 /// Console management API listener task.
 /// It spawns console response handlers needed for the link auth.
-pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
+pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
     scopeguard::defer! {
         info!("mgmt has shut down");
     }
@@ -1,6 +1,6 @@
-use anyhow::anyhow;
+use anyhow::{anyhow, bail};
 use hyper::{Body, Request, Response, StatusCode};
-use std::net::TcpListener;
+use std::{convert::Infallible, net::TcpListener};
 use tracing::info;
 use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
 
@@ -12,7 +12,7 @@ fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
     endpoint::make_router().get("/v1/status", status_handler)
 }
 
-pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
+pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible> {
     scopeguard::defer! {
         info!("http has shut down");
     }
@@ -23,5 +23,5 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<()> {
         .serve(service().map_err(|e| anyhow!(e))?)
         .await?;
 
-    Ok(())
+    bail!("hyper server without shutdown handling cannot shutdown successfully");
 }
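Note on the `bail!` above: with the return type changed to `anyhow::Result<Infallible>`, the old `Ok(())` no longer typechecks. A hyper server with no graceful-shutdown handling is only expected to resolve with an error, so if control ever reaches the line after `.await?`, something has gone wrong; the `bail!` turns that surprise into an explicit error instead of a silent success.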
@@ -1,5 +1,6 @@
+use std::convert::Infallible;
+
 use anyhow::{bail, Context};
-use futures::{Future, FutureExt};
 use tokio::task::JoinError;
 use tokio_util::sync::CancellationToken;
 use tracing::warn;
@@ -23,7 +24,7 @@ pub mod url;
 pub mod waiters;
 
 /// Handle unix signals appropriately.
-pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
+pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
     use tokio::signal::unix::{signal, SignalKind};
 
     let mut hangup = signal(SignalKind::hangup())?;
@@ -50,8 +51,6 @@ pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
 }
 
 /// Flattens `Result<Result<T>>` into `Result<T>`.
-pub async fn flatten_err(
-    f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
-) -> anyhow::Result<()> {
-    f.map(|r| r.context("join error").and_then(|x| x)).await
+pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
+    r.context("join error").and_then(|x| x)
 }
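Since `flatten_err` is now a plain, generic function over an already-joined result rather than an async combinator, it applies directly to a `JoinHandle`'s output. A small usage sketch (assuming the `anyhow` and `tokio` crates, as in the PR):

```rust
use anyhow::Context;
use tokio::task::JoinError;

// same shape as the new flatten_err above
fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
    r.context("join error").and_then(|x| x)
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // spawn a task whose own output is itself a Result
    let handle = tokio::spawn(async { anyhow::Ok(42) });
    // one call collapses the JoinError layer and the task's error layer
    let value = flatten_err(handle.await)?;
    assert_eq!(value, 42);
    Ok(())
}
```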
@@ -4,7 +4,7 @@ use crate::{config::MetricCollectionConfig, http};
 use chrono::{DateTime, Utc};
 use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
 use serde::Serialize;
-use std::{collections::HashMap, time::Duration};
+use std::{collections::HashMap, convert::Infallible, time::Duration};
 use tracing::{error, info, instrument, trace, warn};
 
 const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
@@ -26,7 +26,7 @@ pub struct Ids {
     pub branch_id: String,
 }
 
-pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> {
+pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<Infallible> {
     info!("metrics collector config: {config:?}");
     scopeguard::defer! {
         info!("metrics collector has shut down");
@@ -162,7 +162,10 @@ pub async fn handle_ws_client(
             .map(|_| auth::ClientCredentials::parse(&params, hostname, common_names))
             .transpose();
 
-        async { result }.or_else(|e| stream.throw_error(e)).await?
+        match result {
+            Ok(creds) => creds,
+            Err(e) => stream.throw_error(e).await?,
+        }
     };
 
     let client = Client::new(stream, creds, &params, session_id, false);
@@ -201,7 +204,10 @@ async fn handle_client(
             .map(|_| auth::ClientCredentials::parse(&params, sni, common_names))
             .transpose();
 
-        async { result }.or_else(|e| stream.throw_error(e)).await?
+        match result {
+            Ok(creds) => creds,
+            Err(e) => stream.throw_error(e).await?,
+        }
     };
 
     let allow_self_signed_compute = config.allow_self_signed_compute;
@@ -595,15 +601,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
             application_name: params.get("application_name"),
         };
 
-        let auth_result = async {
-            // `&mut stream` doesn't let us merge those 2 lines.
-            let res = creds
-                .authenticate(&extra, &mut stream, allow_cleartext)
-                .await;
-
-            async { res }.or_else(|e| stream.throw_error(e)).await
-        }
-        .await?;
+        let auth_result = match creds
+            .authenticate(&extra, &mut stream, allow_cleartext)
+            .await
+        {
+            Ok(auth_result) => auth_result,
+            Err(e) => return stream.throw_error(e).await,
+        };
 
         let AuthSuccess {
             reported_auth_ok,