[proxy] Handle some unix signals.

On the surface, this doesn't add much, but there are some benefits:

* We can do graceful shutdowns and thus record more code coverage data.

* We now have a foundation for more interesting behaviors, e.g. "stop
  accepting new connections after SIGTERM but keep serving the existing ones"
  (see the sketch after this list).

* We give the otel machinery a chance to flush trace events before
  finally shutting down.
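
To illustrate the second bullet, here is a minimal sketch of how a SIGTERM
"drain" could be built on tokio's signal and JoinSet APIs. None of this is in
the commit; `accept_until_term` and the connection-handling body are
hypothetical names for illustration only:

    use tokio::net::TcpListener;
    use tokio::signal::unix::{signal, SignalKind};
    use tokio::task::JoinSet;

    /// Hypothetical sketch: after SIGTERM, stop accepting new
    /// connections but let the existing ones run to completion.
    async fn accept_until_term(listener: TcpListener) -> anyhow::Result<()> {
        let mut terminate = signal(SignalKind::terminate())?;
        let mut connections = JoinSet::new();
        loop {
            tokio::select! {
                // Stop accepting and fall through to the drain phase.
                _ = terminate.recv() => break,
                accepted = listener.accept() => {
                    let (socket, peer_addr) = accepted?;
                    connections.spawn(async move {
                        // ... serve the client connection here (elided) ...
                        let _ = (socket, peer_addr);
                    });
                }
            }
        }
        // Drain: wait until every in-flight connection task finishes.
        while connections.join_next().await.is_some() {}
        Ok(())
    }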
Author: Dmitry Ivanov
Date: 2023-02-17 13:16:30 +03:00
parent 6f9af0aa8c
commit 956b6f17ca

9 changed files with 79 additions and 65 deletions


@@ -51,7 +51,7 @@ thiserror.workspace = true
tls-listener.workspace = true
tokio-postgres.workspace = true
tokio-rustls.workspace = true
-tokio.workspace = true
+tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true


@@ -8,6 +8,7 @@ pub struct ProxyConfig {
    pub metric_collection: Option<MetricCollectionConfig>,
}

+#[derive(Debug)]
pub struct MetricCollectionConfig {
    pub endpoint: reqwest::Url,
    pub interval: Duration,


@@ -5,10 +5,7 @@ use crate::{
use anyhow::Context;
use once_cell::sync::Lazy;
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
-use std::{
-    net::{TcpListener, TcpStream},
-    thread,
-};
+use std::{net::TcpStream, thread};
use tracing::{error, info, info_span};
use utils::{
postgres_backend::{self, AuthType, PostgresBackend},
@@ -34,23 +31,24 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
CPLANE_WAITERS.notify(psql_session_id, msg)
}
-/// Console management API listener thread.
+/// Console management API listener task.
/// It spawns console response handlers needed for the link auth.
-pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
+pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("mgmt has shut down");
    }

-    listener
-        .set_nonblocking(false)
-        .context("failed to set listener to blocking")?;
    loop {
-        let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
+        let (socket, peer_addr) = listener.accept().await?;
        info!("accepted connection from {peer_addr}");

+        let socket = socket.into_std()?;
        socket
            .set_nodelay(true)
            .context("failed to set client socket option")?;
+        socket
+            .set_nonblocking(false)
+            .context("failed to set client socket option")?;

+        // TODO: replace with async tasks.
        thread::spawn(move || {


@@ -186,8 +186,8 @@ async fn ws_handler(
}
pub async fn task_main(
-    ws_listener: TcpListener,
    config: &'static ProxyConfig,
+    ws_listener: TcpListener,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("websocket server has shut down");


@@ -41,6 +41,7 @@ impl Drop for LoggingGuard {
    fn drop(&mut self) {
        // Shutdown trace pipeline gracefully, so that it has a chance to send any
        // pending traces before we exit.
+        tracing::info!("shutting down the tracing machinery");
        tracing_utils::shutdown_tracing();
    }
}


@@ -28,7 +28,7 @@ use config::ProxyConfig;
use futures::FutureExt;
use std::{borrow::Cow, future::Future, net::SocketAddr};
use tokio::{net::TcpListener, task::JoinError};
-use tracing::{info, info_span, Instrument};
+use tracing::{info, warn};
use utils::{project_git_version, sentry_init::init_sentry};
project_git_version!(GIT_VERSION);
@@ -60,16 +60,17 @@ async fn main() -> anyhow::Result<()> {
    let mgmt_address: SocketAddr = args.get_one::<String>("mgmt").unwrap().parse()?;
    info!("Starting mgmt on {mgmt_address}");
-    let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?;
+    let mgmt_listener = TcpListener::bind(mgmt_address).await?;

    let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
    info!("Starting proxy on {proxy_address}");
    let proxy_listener = TcpListener::bind(proxy_address).await?;

    let mut tasks = vec![
+        tokio::spawn(handle_signals()),
        tokio::spawn(http::server::task_main(http_listener)),
        tokio::spawn(proxy::task_main(config, proxy_listener)),
-        tokio::task::spawn_blocking(move || console::mgmt::thread_main(mgmt_listener)),
+        tokio::spawn(console::mgmt::task_main(mgmt_listener)),
    ];

    if let Some(wss_address) = args.get_one::<String>("wss") {
@@ -78,35 +79,52 @@ async fn main() -> anyhow::Result<()> {
        let wss_listener = TcpListener::bind(wss_address).await?;
        tasks.push(tokio::spawn(http::websocket::task_main(
-            wss_listener,
            config,
+            wss_listener,
        )));
    }

+    // TODO: refactor.
-    if let Some(metric_collection) = &config.metric_collection {
-        let hostname = hostname::get()?
-            .into_string()
-            .map_err(|e| anyhow::anyhow!("failed to get hostname {e:?}"))?;
-        tasks.push(tokio::spawn(
-            metrics::collect_metrics(
-                &metric_collection.endpoint,
-                metric_collection.interval,
-                hostname,
-            )
-            .instrument(info_span!("collect_metrics")),
-        ));
+    if let Some(metrics_config) = &config.metric_collection {
+        tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
    }

-    // This will block until all tasks have completed.
-    // Furthermore, the first one to fail will cancel the rest.
+    // This combinator will block until either all tasks complete or
+    // one of them finishes with an error (others will be cancelled).
    let tasks = tasks.into_iter().map(flatten_err);
    let _: Vec<()> = futures::future::try_join_all(tasks).await?;

    Ok(())
}

+/// Handle unix signals appropriately.
+async fn handle_signals() -> anyhow::Result<()> {
+    use tokio::signal::unix::{signal, SignalKind};
+
+    let mut hangup = signal(SignalKind::hangup())?;
+    let mut interrupt = signal(SignalKind::interrupt())?;
+    let mut terminate = signal(SignalKind::terminate())?;
+
+    loop {
+        tokio::select! {
+            // Hangup is commonly used for config reload.
+            _ = hangup.recv() => {
+                warn!("received SIGHUP; config reload is not supported");
+            }
+            // Shut down the whole application.
+            _ = interrupt.recv() => {
+                warn!("received SIGINT, exiting immediately");
+                bail!("interrupted");
+            }
+            // TODO: Don't accept new proxy connections.
+            // TODO: Shut down once all existing connections have been closed.
+            _ = terminate.recv() => {
+                warn!("received SIGTERM, exiting immediately");
+                bail!("terminated");
+            }
+        }
+    }
+}

/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
    let tls_config = match (


@@ -1,10 +1,10 @@
//! Periodically collect proxy consumption metrics
//! and push them to an HTTP endpoint.
-use crate::http;
+use crate::{config::MetricCollectionConfig, http};
use chrono::{DateTime, Utc};
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
use serde::Serialize;
-use std::{collections::HashMap, time::Duration};
+use std::collections::HashMap;
use tracing::{debug, error, info, instrument, trace};
const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
@@ -23,37 +23,31 @@ pub struct Ids {
    pub endpoint_id: String,
}

-pub async fn collect_metrics(
-    metric_collection_endpoint: &reqwest::Url,
-    metric_collection_interval: Duration,
-    hostname: String,
-) -> anyhow::Result<()> {
+pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> {
+    info!("metrics collector config: {config:?}");
    scopeguard::defer! {
-        info!("collect_metrics has shut down");
+        info!("metrics collector has shut down");
    }

-    let mut ticker = tokio::time::interval(metric_collection_interval);
-    info!(
-        "starting collect_metrics. metric_collection_endpoint: {}",
-        metric_collection_endpoint
-    );
    let http_client = http::new_client();
    let mut cached_metrics: HashMap<Ids, (u64, DateTime<Utc>)> = HashMap::new();
+    let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned();

+    let mut ticker = tokio::time::interval(config.interval);
    loop {
-        tokio::select! {
-            _ = ticker.tick() => {
-                match collect_metrics_iteration(&http_client, &mut cached_metrics, metric_collection_endpoint, hostname.clone()).await
-                {
-                    Err(e) => {
-                        error!("Failed to send consumption metrics: {} ", e);
-                    },
-                    Ok(_) => { trace!("collect_metrics_iteration completed successfully") },
-                }
-            }
-        }
+        ticker.tick().await;
+
+        let res = collect_metrics_iteration(
+            &http_client,
+            &mut cached_metrics,
+            &config.endpoint,
+            &hostname,
+        )
+        .await;
+
+        match res {
+            Err(e) => error!("failed to send consumption metrics: {e} "),
+            Ok(_) => trace!("periodic metrics collection completed successfully"),
+        }
    }
}
@@ -102,7 +96,7 @@ async fn collect_metrics_iteration(
    client: &http::ClientWithMiddleware,
    cached_metrics: &mut HashMap<Ids, (u64, DateTime<Utc>)>,
    metric_collection_endpoint: &reqwest::Url,
-    hostname: String,
+    hostname: &str,
) -> anyhow::Result<()> {
    info!(
        "starting collect_metrics_iteration. metric_collection_endpoint: {}",
@@ -133,7 +127,7 @@ async fn collect_metrics_iteration(
                stop_time: *curr_time,
            },
            metric: PROXY_IO_BYTES_PER_CLIENT,
-            idempotency_key: idempotency_key(hostname.clone()),
+            idempotency_key: idempotency_key(hostname.to_owned()),
            value,
            extra: Ids {
                endpoint_id: curr_key.endpoint_id.clone(),


@@ -2529,15 +2529,17 @@ class NeonProxy(PgProtocol):
        tb: Optional[TracebackType],
    ):
        if self._popen is not None:
-            # NOTE the process will die when we're done with tests anyway, because
-            # it's a child process. This is mostly to clean up in between different tests.
-            self._popen.kill()
+            self._popen.terminate()
+            try:
+                self._popen.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                log.warn("failed to gracefully terminate proxy; killing")
+                self._popen.kill()

    @staticmethod
    async def activate_link_auth(
        local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True
    ):
        pg_user = "proxy"
        if create_user:


@@ -45,7 +45,7 @@ scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
socket2 = { version = "0.4", default-features = false, features = ["all"] }
-tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "sync", "time"] }
+tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "sync", "time"] }
tokio-util = { version = "0.7", features = ["codec", "io"] }
tonic = { version = "0.8", features = ["tls-roots"] }
tower = { version = "0.4", features = ["balance", "buffer", "limit", "retry", "timeout", "util"] }