mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-09 14:32:57 +00:00
[proxy] Handle some unix signals.
On the surface, this doesn't add much, but there are some benefits:
* We can do graceful shutdowns and thus record more code coverage data.
* We now have a foundation for the more interesting behaviors, e.g. "stop accepting new connections after SIGTERM but keep serving the existing ones".
* We give the otel machinery a chance to flush trace events before finally shutting down.
This commit is contained in:
@@ -51,7 +51,7 @@ thiserror.workspace = true
|
||||
tls-listener.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
tracing-opentelemetry.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing-utils.workspace = true
|
||||
|
||||
@@ -8,6 +8,7 @@ pub struct ProxyConfig {
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MetricCollectionConfig {
|
||||
pub endpoint: reqwest::Url,
|
||||
pub interval: Duration,
|
||||
|
||||
@@ -5,10 +5,7 @@ use crate::{
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::{
|
||||
net::{TcpListener, TcpStream},
|
||||
thread,
|
||||
};
|
||||
use std::{net::TcpStream, thread};
|
||||
use tracing::{error, info, info_span};
|
||||
use utils::{
|
||||
postgres_backend::{self, AuthType, PostgresBackend},
|
||||
@@ -34,23 +31,24 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Console management API listener thread.
|
||||
/// Console management API listener task.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
}
|
||||
|
||||
listener
|
||||
.set_nonblocking(false)
|
||||
.context("failed to set listener to blocking")?;
|
||||
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
|
||||
let (socket, peer_addr) = listener.accept().await?;
|
||||
info!("accepted connection from {peer_addr}");
|
||||
|
||||
let socket = socket.into_std()?;
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
socket
|
||||
.set_nonblocking(false)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
// TODO: replace with async tasks.
|
||||
thread::spawn(move || {
|
||||
|
||||
@@ -186,8 +186,8 @@ async fn ws_handler(
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
ws_listener: TcpListener,
|
||||
config: &'static ProxyConfig,
|
||||
ws_listener: TcpListener,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
|
||||
@@ -41,6 +41,7 @@ impl Drop for LoggingGuard {
|
||||
fn drop(&mut self) {
|
||||
// Shutdown trace pipeline gracefully, so that it has a chance to send any
|
||||
// pending traces before we exit.
|
||||
tracing::info!("shutting down the tracing machinery");
|
||||
tracing_utils::shutdown_tracing();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::{borrow::Cow, future::Future, net::SocketAddr};
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
use tracing::{info, warn};
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
@@ -60,16 +60,17 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let mgmt_address: SocketAddr = args.get_one::<String>("mgmt").unwrap().parse()?;
|
||||
info!("Starting mgmt on {mgmt_address}");
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?;
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
||||
|
||||
let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
|
||||
let mut tasks = vec![
|
||||
tokio::spawn(handle_signals()),
|
||||
tokio::spawn(http::server::task_main(http_listener)),
|
||||
tokio::spawn(proxy::task_main(config, proxy_listener)),
|
||||
tokio::task::spawn_blocking(move || console::mgmt::thread_main(mgmt_listener)),
|
||||
tokio::spawn(console::mgmt::task_main(mgmt_listener)),
|
||||
];
|
||||
|
||||
if let Some(wss_address) = args.get_one::<String>("wss") {
|
||||
@@ -78,35 +79,52 @@ async fn main() -> anyhow::Result<()> {
|
||||
let wss_listener = TcpListener::bind(wss_address).await?;
|
||||
|
||||
tasks.push(tokio::spawn(http::websocket::task_main(
|
||||
wss_listener,
|
||||
config,
|
||||
wss_listener,
|
||||
)));
|
||||
}
|
||||
|
||||
// TODO: refactor.
|
||||
if let Some(metric_collection) = &config.metric_collection {
|
||||
let hostname = hostname::get()?
|
||||
.into_string()
|
||||
.map_err(|e| anyhow::anyhow!("failed to get hostname {e:?}"))?;
|
||||
|
||||
tasks.push(tokio::spawn(
|
||||
metrics::collect_metrics(
|
||||
&metric_collection.endpoint,
|
||||
metric_collection.interval,
|
||||
hostname,
|
||||
)
|
||||
.instrument(info_span!("collect_metrics")),
|
||||
));
|
||||
if let Some(metrics_config) = &config.metric_collection {
|
||||
tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
|
||||
}
|
||||
|
||||
// This will block until all tasks have completed.
|
||||
// Furthermore, the first one to fail will cancel the rest.
|
||||
// This combinator will block until either all tasks complete or
|
||||
// one of them finishes with an error (others will be cancelled).
|
||||
let tasks = tasks.into_iter().map(flatten_err);
|
||||
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
async fn handle_signals() -> anyhow::Result<()> {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
let mut hangup = signal(SignalKind::hangup())?;
|
||||
let mut interrupt = signal(SignalKind::interrupt())?;
|
||||
let mut terminate = signal(SignalKind::terminate())?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Hangup is commonly used for config reload.
|
||||
_ = hangup.recv() => {
|
||||
warn!("received SIGHUP; config reload is not supported");
|
||||
}
|
||||
// Shut down the whole application.
|
||||
_ = interrupt.recv() => {
|
||||
warn!("received SIGINT, exiting immediately");
|
||||
bail!("interrupted");
|
||||
}
|
||||
// TODO: Don't accept new proxy connections.
|
||||
// TODO: Shut down once all existing connections have been closed.
|
||||
_ = terminate.recv() => {
|
||||
warn!("received SIGTERM, exiting immediately");
|
||||
bail!("terminated");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let tls_config = match (
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
//! Periodically collect proxy consumption metrics
|
||||
//! and push them to a HTTP endpoint.
|
||||
use crate::http;
|
||||
use crate::{config::MetricCollectionConfig, http};
|
||||
use chrono::{DateTime, Utc};
|
||||
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
|
||||
use serde::Serialize;
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, error, info, instrument, trace};
|
||||
|
||||
const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
|
||||
@@ -23,37 +23,31 @@ pub struct Ids {
|
||||
pub endpoint_id: String,
|
||||
}
|
||||
|
||||
pub async fn collect_metrics(
|
||||
metric_collection_endpoint: &reqwest::Url,
|
||||
metric_collection_interval: Duration,
|
||||
hostname: String,
|
||||
) -> anyhow::Result<()> {
|
||||
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> {
|
||||
info!("metrics collector config: {config:?}");
|
||||
scopeguard::defer! {
|
||||
info!("collect_metrics has shut down");
|
||||
info!("metrics collector has shut down");
|
||||
}
|
||||
|
||||
let mut ticker = tokio::time::interval(metric_collection_interval);
|
||||
|
||||
info!(
|
||||
"starting collect_metrics. metric_collection_endpoint: {}",
|
||||
metric_collection_endpoint
|
||||
);
|
||||
|
||||
let http_client = http::new_client();
|
||||
let mut cached_metrics: HashMap<Ids, (u64, DateTime<Utc>)> = HashMap::new();
|
||||
let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned();
|
||||
|
||||
let mut ticker = tokio::time::interval(config.interval);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ticker.tick() => {
|
||||
ticker.tick().await;
|
||||
|
||||
match collect_metrics_iteration(&http_client, &mut cached_metrics, metric_collection_endpoint, hostname.clone()).await
|
||||
{
|
||||
Err(e) => {
|
||||
error!("Failed to send consumption metrics: {} ", e);
|
||||
},
|
||||
Ok(_) => { trace!("collect_metrics_iteration completed successfully") },
|
||||
}
|
||||
}
|
||||
let res = collect_metrics_iteration(
|
||||
&http_client,
|
||||
&mut cached_metrics,
|
||||
&config.endpoint,
|
||||
&hostname,
|
||||
)
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => error!("failed to send consumption metrics: {e} "),
|
||||
Ok(_) => trace!("periodic metrics collection completed successfully"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -102,7 +96,7 @@ async fn collect_metrics_iteration(
|
||||
client: &http::ClientWithMiddleware,
|
||||
cached_metrics: &mut HashMap<Ids, (u64, DateTime<Utc>)>,
|
||||
metric_collection_endpoint: &reqwest::Url,
|
||||
hostname: String,
|
||||
hostname: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
info!(
|
||||
"starting collect_metrics_iteration. metric_collection_endpoint: {}",
|
||||
@@ -133,7 +127,7 @@ async fn collect_metrics_iteration(
|
||||
stop_time: *curr_time,
|
||||
},
|
||||
metric: PROXY_IO_BYTES_PER_CLIENT,
|
||||
idempotency_key: idempotency_key(hostname.clone()),
|
||||
idempotency_key: idempotency_key(hostname.to_owned()),
|
||||
value,
|
||||
extra: Ids {
|
||||
endpoint_id: curr_key.endpoint_id.clone(),
|
||||
|
||||
@@ -2529,15 +2529,17 @@ class NeonProxy(PgProtocol):
|
||||
tb: Optional[TracebackType],
|
||||
):
|
||||
if self._popen is not None:
|
||||
# NOTE the process will die when we're done with tests anyway, because
|
||||
# it's a child process. This is mostly to clean up in between different tests.
|
||||
self._popen.kill()
|
||||
self._popen.terminate()
|
||||
try:
|
||||
self._popen.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
log.warn("failed to gracefully terminate proxy; killing")
|
||||
self._popen.kill()
|
||||
|
||||
@staticmethod
|
||||
async def activate_link_auth(
|
||||
local_vanilla_pg, proxy_with_metric_collector, psql_session_id, create_user=True
|
||||
):
|
||||
|
||||
pg_user = "proxy"
|
||||
|
||||
if create_user:
|
||||
|
||||
@@ -45,7 +45,7 @@ scopeguard = { version = "1" }
|
||||
serde = { version = "1", features = ["alloc", "derive"] }
|
||||
serde_json = { version = "1", features = ["raw_value"] }
|
||||
socket2 = { version = "0.4", default-features = false, features = ["all"] }
|
||||
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "sync", "time"] }
|
||||
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "sync", "time"] }
|
||||
tokio-util = { version = "0.7", features = ["codec", "io"] }
|
||||
tonic = { version = "0.8", features = ["tls-roots"] }
|
||||
tower = { version = "0.4", features = ["balance", "buffer", "limit", "retry", "timeout", "util"] }
|
||||
|
||||
Reference in New Issue
Block a user