From 956b6f17ca35f002d1dcb74a7c803db798d43c94 Mon Sep 17 00:00:00 2001
From: Dmitry Ivanov
Date: Fri, 17 Feb 2023 13:16:30 +0300
Subject: [PATCH] [proxy] Handle some unix signals.

On the surface, this doesn't add much, but there are some benefits:

* We can do graceful shutdowns and thus record more code coverage data.
* We now have a foundation for the more interesting behaviors, e.g.
  "stop accepting new connections after SIGTERM but keep serving the
  existing ones".
* We give the otel machinery a chance to flush trace events before
  finally shutting down.
---
 proxy/Cargo.toml                      |  2 +-
 proxy/src/config.rs                   |  1 +
 proxy/src/console/mgmt.rs             | 20 +++++----
 proxy/src/http/websocket.rs           |  2 +-
 proxy/src/logging.rs                  |  1 +
 proxy/src/main.rs                     | 58 ++++++++++++++++++---------
 proxy/src/metrics.rs                  | 48 ++++++++++------------
 test_runner/fixtures/neon_fixtures.py | 10 +++--
 workspace_hack/Cargo.toml             |  2 +-
 9 files changed, 79 insertions(+), 65 deletions(-)

diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index 96a62d2c49..030a5f1d6e 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -51,7 +51,7 @@ thiserror.workspace = true
 tls-listener.workspace = true
 tokio-postgres.workspace = true
 tokio-rustls.workspace = true
-tokio.workspace = true
+tokio = { workspace = true, features = ["signal"] }
 tracing-opentelemetry.workspace = true
 tracing-subscriber.workspace = true
 tracing-utils.workspace = true
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index 5e285f3625..600db7f8ec 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -8,6 +8,7 @@ pub struct ProxyConfig {
     pub metric_collection: Option<MetricCollectionConfig>,
 }
 
+#[derive(Debug)]
 pub struct MetricCollectionConfig {
     pub endpoint: reqwest::Url,
     pub interval: Duration,
diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs
index 51a117d3b7..c00c06fbb7 100644
--- a/proxy/src/console/mgmt.rs
+++ b/proxy/src/console/mgmt.rs
@@ -5,10 +5,7 @@ use crate::{
 use anyhow::Context;
 use once_cell::sync::Lazy;
 use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
-use std::{
-    net::{TcpListener, TcpStream},
-    thread,
-};
+use std::{net::TcpStream, thread};
 use tracing::{error, info, info_span};
 use utils::{
     postgres_backend::{self, AuthType, PostgresBackend},
@@ -34,23 +31,24 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
     CPLANE_WAITERS.notify(psql_session_id, msg)
 }
 
-/// Console management API listener thread.
+/// Console management API listener task.
 /// It spawns console response handlers needed for the link auth.
-pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
+pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> {
     scopeguard::defer! {
         info!("mgmt has shut down");
     }
 
-    listener
-        .set_nonblocking(false)
-        .context("failed to set listener to blocking")?;
-
     loop {
-        let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
+        let (socket, peer_addr) = listener.accept().await?;
         info!("accepted connection from {peer_addr}");
+
+        let socket = socket.into_std()?;
         socket
             .set_nodelay(true)
             .context("failed to set client socket option")?;
+        socket
+            .set_nonblocking(false)
+            .context("failed to set client socket option")?;
 
         // TODO: replace with async tasks.
         thread::spawn(move || {
diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs
index d4235c2c38..1757652a90 100644
--- a/proxy/src/http/websocket.rs
+++ b/proxy/src/http/websocket.rs
@@ -186,8 +186,8 @@ async fn ws_handler(
 }
 
 pub async fn task_main(
-    ws_listener: TcpListener,
     config: &'static ProxyConfig,
+    ws_listener: TcpListener,
 ) -> anyhow::Result<()> {
     scopeguard::defer! {
         info!("websocket server has shut down");
diff --git a/proxy/src/logging.rs b/proxy/src/logging.rs
index 2baf824fc3..0c8c2858b9 100644
--- a/proxy/src/logging.rs
+++ b/proxy/src/logging.rs
@@ -41,6 +41,7 @@ impl Drop for LoggingGuard {
     fn drop(&mut self) {
         // Shutdown trace pipeline gracefully, so that it has a chance to send any
         // pending traces before we exit.
+        tracing::info!("shutting down the tracing machinery");
         tracing_utils::shutdown_tracing();
     }
 }
diff --git a/proxy/src/main.rs b/proxy/src/main.rs
index 54f49b5a3c..c319cb9cfc 100644
--- a/proxy/src/main.rs
+++ b/proxy/src/main.rs
@@ -28,7 +28,7 @@ use config::ProxyConfig;
 use futures::FutureExt;
 use std::{borrow::Cow, future::Future, net::SocketAddr};
 use tokio::{net::TcpListener, task::JoinError};
-use tracing::{info, info_span, Instrument};
+use tracing::{info, warn};
 use utils::{project_git_version, sentry_init::init_sentry};
 
 project_git_version!(GIT_VERSION);
@@ -60,16 +60,17 @@ async fn main() -> anyhow::Result<()> {
 
     let mgmt_address: SocketAddr = args.get_one::<String>("mgmt").unwrap().parse()?;
     info!("Starting mgmt on {mgmt_address}");
-    let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?;
+    let mgmt_listener = TcpListener::bind(mgmt_address).await?;
 
     let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
     info!("Starting proxy on {proxy_address}");
     let proxy_listener = TcpListener::bind(proxy_address).await?;
 
     let mut tasks = vec![
+        tokio::spawn(handle_signals()),
         tokio::spawn(http::server::task_main(http_listener)),
         tokio::spawn(proxy::task_main(config, proxy_listener)),
-        tokio::task::spawn_blocking(move || console::mgmt::thread_main(mgmt_listener)),
+        tokio::spawn(console::mgmt::task_main(mgmt_listener)),
     ];
 
     if let Some(wss_address) = args.get_one::<String>("wss") {
@@ -78,35 +79,52 @@
         let wss_listener = TcpListener::bind(wss_address).await?;
 
         tasks.push(tokio::spawn(http::websocket::task_main(
-            wss_listener,
             config,
+            wss_listener,
         )));
     }
 
-    // TODO: refactor.
-    if let Some(metric_collection) = &config.metric_collection {
-        let hostname = hostname::get()?
-            .into_string()
-            .map_err(|e| anyhow::anyhow!("failed to get hostname {e:?}"))?;
-
-        tasks.push(tokio::spawn(
-            metrics::collect_metrics(
-                &metric_collection.endpoint,
-                metric_collection.interval,
-                hostname,
-            )
-            .instrument(info_span!("collect_metrics")),
-        ));
+    if let Some(metrics_config) = &config.metric_collection {
+        tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
     }
 
-    // This will block until all tasks have completed.
-    // Furthermore, the first one to fail will cancel the rest.
+    // This combinator will block until either all tasks complete or
+    // one of them finishes with an error (others will be cancelled).
     let tasks = tasks.into_iter().map(flatten_err);
     let _: Vec<()> = futures::future::try_join_all(tasks).await?;
 
     Ok(())
 }
 
+/// Handle unix signals appropriately.
+async fn handle_signals() -> anyhow::Result<()> {
+    use tokio::signal::unix::{signal, SignalKind};
+
+    let mut hangup = signal(SignalKind::hangup())?;
+    let mut interrupt = signal(SignalKind::interrupt())?;
+    let mut terminate = signal(SignalKind::terminate())?;
+
+    loop {
+        tokio::select! {
+            // Hangup is commonly used for config reload.
+            _ = hangup.recv() => {
+                warn!("received SIGHUP; config reload is not supported");
+            }
+            // Shut down the whole application.
+            _ = interrupt.recv() => {
+                warn!("received SIGINT, exiting immediately");
+                bail!("interrupted");
+            }
+            // TODO: Don't accept new proxy connections.
+            // TODO: Shut down once all existing connections have been closed.
+            _ = terminate.recv() => {
+                warn!("received SIGTERM, exiting immediately");
+                bail!("terminated");
+            }
+        }
+    }
+}
+
 /// ProxyConfig is created at proxy startup, and lives forever.
 fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
     let tls_config = match (
diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs
index 83b28288ee..8bbae9638b 100644
--- a/proxy/src/metrics.rs
+++ b/proxy/src/metrics.rs
@@ -1,10 +1,10 @@
 //! Periodically collect proxy consumption metrics
 //! and push them to a HTTP endpoint.
 
-use crate::http;
+use crate::{config::MetricCollectionConfig, http};
 use chrono::{DateTime, Utc};
 use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
 use serde::Serialize;
-use std::{collections::HashMap, time::Duration};
+use std::collections::HashMap;
 use tracing::{debug, error, info, instrument, trace};
 
@@ -23,37 +23,31 @@ pub struct Ids {
     pub endpoint_id: String,
 }
 
-pub async fn collect_metrics(
-    metric_collection_endpoint: &reqwest::Url,
-    metric_collection_interval: Duration,
-    hostname: String,
-) -> anyhow::Result<()> {
+pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> {
+    info!("metrics collector config: {config:?}");
     scopeguard::defer! {
-        info!("collect_metrics has shut down");
+        info!("metrics collector has shut down");
     }
 
-    let mut ticker = tokio::time::interval(metric_collection_interval);
-
-    info!(
-        "starting collect_metrics. metric_collection_endpoint: {}",
-        metric_collection_endpoint
-    );
-
     let http_client = http::new_client();
     let mut cached_metrics: HashMap<Ids, (u64, DateTime<Utc>)> = HashMap::new();
+    let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned();
 
+    let mut ticker = tokio::time::interval(config.interval);
     loop {
-        tokio::select! {
-            _ = ticker.tick() => {
+        ticker.tick().await;
 
-                match collect_metrics_iteration(&http_client, &mut cached_metrics, metric_collection_endpoint, hostname.clone()).await
-                {
-                    Err(e) => {
-                        error!("Failed to send consumption metrics: {} ", e);
-                    },
-                    Ok(_) => { trace!("collect_metrics_iteration completed successfully") },
-                }
-            }
+        let res = collect_metrics_iteration(
+            &http_client,
+            &mut cached_metrics,
+            &config.endpoint,
+            &hostname,
+        )
+        .await;
+
+        match res {
+            Err(e) => error!("failed to send consumption metrics: {e}"),
+            Ok(_) => trace!("periodic metrics collection completed successfully"),
         }
     }
 }
@@ -102,7 +96,7 @@ async fn collect_metrics_iteration(
     client: &http::ClientWithMiddleware,
     cached_metrics: &mut HashMap<Ids, (u64, DateTime<Utc>)>,
     metric_collection_endpoint: &reqwest::Url,
-    hostname: String,
+    hostname: &str,
 ) -> anyhow::Result<()> {
     info!(
         "starting collect_metrics_iteration. metric_collection_endpoint: {}",
@@ -133,7 +127,7 @@ async fn collect_metrics_iteration(
                     stop_time: *curr_time,
                 },
                 metric: PROXY_IO_BYTES_PER_CLIENT,
-                idempotency_key: idempotency_key(hostname.clone()),
+                idempotency_key: idempotency_key(hostname.to_owned()),
                 value,
                 extra: Ids {
                     endpoint_id: curr_key.endpoint_id.clone(),
diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py
index ca5288fa0a..59d616eb6f 100644
--- a/test_runner/fixtures/neon_fixtures.py
+++ b/test_runner/fixtures/neon_fixtures.py
@@ -2529,15 +2529,17 @@ class NeonProxy(PgProtocol):
         tb: Optional[TracebackType],
     ):
         if self._popen is not None:
-            # NOTE the process will die when we're done with tests anyway, because
-            # it's a child process. This is mostly to clean up in between different tests.
-            self._popen.kill()
+            self._popen.terminate()
+            try:
+                self._popen.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                log.warning("failed to gracefully terminate proxy; killing")
+                self._popen.kill()
 
     @staticmethod
     async def activate_link_auth(
         local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True
     ):
-
         pg_user = "proxy"
 
         if create_user:
diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml
index 68138b3df4..c0cf3c5611 100644
--- a/workspace_hack/Cargo.toml
+++ b/workspace_hack/Cargo.toml
@@ -45,7 +45,7 @@ scopeguard = { version = "1" }
 serde = { version = "1", features = ["alloc", "derive"] }
 serde_json = { version = "1", features = ["raw_value"] }
 socket2 = { version = "0.4", default-features = false, features = ["all"] }
-tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "sync", "time"] }
+tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "sync", "time"] }
 tokio-util = { version = "0.7", features = ["codec", "io"] }
 tonic = { version = "0.8", features = ["tls-roots"] }
 tower = { version = "0.4", features = ["balance", "buffer", "limit", "retry", "timeout", "util"] }
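
Note for reviewers (not part of the patch): the SIGTERM arm above still exits
immediately; the two TODOs describe the intended follow-up from the commit
message, i.e. "stop accepting new connections after SIGTERM but keep serving
the existing ones". Below is a minimal sketch of one way that could be wired
up with a tokio watch channel. The `handle_signals(shutdown_tx)` signature,
the `serve` function, and the channel names are hypothetical and are only
meant to illustrate the idea.

    use tokio::signal::unix::{signal, SignalKind};
    use tokio::sync::watch;

    // Signal task: flip the shutdown flag on SIGTERM instead of bailing out.
    async fn handle_signals(shutdown_tx: watch::Sender<bool>) -> anyhow::Result<()> {
        let mut terminate = signal(SignalKind::terminate())?;
        terminate.recv().await;
        // Ignore the error: it only means every receiver is already gone.
        let _ = shutdown_tx.send(true);
        Ok(())
    }

    // Accept loop: stop taking new connections once shutdown is requested,
    // but let already-spawned connection tasks run to completion.
    async fn serve(
        listener: tokio::net::TcpListener,
        mut shutdown_rx: watch::Receiver<bool>,
    ) -> anyhow::Result<()> {
        loop {
            tokio::select! {
                _ = shutdown_rx.changed() => break,
                accepted = listener.accept() => {
                    let (socket, peer_addr) = accepted?;
                    tokio::spawn(async move {
                        // Handle the connection here; sessions spawned before
                        // shutdown keep running after the accept loop exits.
                        let _ = (socket, peer_addr);
                    });
                }
            }
        }
        Ok(())
    }

In main, the two halves would be connected with
`let (shutdown_tx, shutdown_rx) = watch::channel(false);` and spawned
alongside the other tasks, so each listener gets its own receiver clone.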