diff --git a/Cargo.lock b/Cargo.lock
index fc587c57bf..f67311cf09 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2965,6 +2965,7 @@ dependencies = [
  "tokio-postgres",
  "tokio-postgres-rustls",
  "tokio-rustls",
+ "tokio-util",
  "tracing",
  "tracing-opentelemetry",
  "tracing-subscriber",
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index add8b14c95..9d702b29c3 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -64,6 +64,7 @@ webpki-roots.workspace = true
 x509-parser.workspace = true
 workspace_hack.workspace = true
+tokio-util.workspace = true
 
 [dev-dependencies]
 rcgen.workspace = true
diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs
index 1757652a90..c7676e8e14 100644
--- a/proxy/src/http/websocket.rs
+++ b/proxy/src/http/websocket.rs
@@ -22,6 +22,7 @@ use tokio::{
     io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
     net::TcpListener,
 };
+use tokio_util::sync::CancellationToken;
 use tracing::{error, info, info_span, warn, Instrument};
 use utils::http::{error::ApiError, json::json_response};
 
@@ -188,6 +189,7 @@ async fn ws_handler(
 pub async fn task_main(
     config: &'static ProxyConfig,
     ws_listener: TcpListener,
+    cancellation_token: CancellationToken,
 ) -> anyhow::Result<()> {
     scopeguard::defer! {
         info!("websocket server has shut down");
     }
@@ -231,6 +233,7 @@ pub async fn task_main(
 
     hyper::Server::builder(accept::from_stream(tls_listener))
         .serve(make_svc)
+        .with_graceful_shutdown(cancellation_token.cancelled())
        .await?;
 
     Ok(())
diff --git a/proxy/src/main.rs b/proxy/src/main.rs
index c6526e9aff..1fd13c9f68 100644
--- a/proxy/src/main.rs
+++ b/proxy/src/main.rs
@@ -28,6 +28,7 @@ use config::ProxyConfig;
 use futures::FutureExt;
 use std::{borrow::Cow, future::Future, net::SocketAddr};
 use tokio::{net::TcpListener, task::JoinError};
+use tokio_util::sync::CancellationToken;
 use tracing::{info, warn};
 use utils::{project_git_version, sentry_init::init_sentry};
 
@@ -66,39 +67,48 @@ async fn main() -> anyhow::Result<()> {
     let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
     info!("Starting proxy on {proxy_address}");
     let proxy_listener = TcpListener::bind(proxy_address).await?;
+    let cancellation_token = CancellationToken::new();
 
-    let mut tasks = vec![
-        tokio::spawn(handle_signals()),
-        tokio::spawn(http::server::task_main(http_listener)),
-        tokio::spawn(proxy::task_main(config, proxy_listener)),
-        tokio::spawn(console::mgmt::task_main(mgmt_listener)),
-    ];
+    let mut client_tasks = vec![tokio::spawn(proxy::task_main(
+        config,
+        proxy_listener,
+        cancellation_token.clone(),
+    ))];
 
     if let Some(wss_address) = args.get_one::<String>("wss") {
         let wss_address: SocketAddr = wss_address.parse()?;
         info!("Starting wss on {wss_address}");
         let wss_listener = TcpListener::bind(wss_address).await?;
 
-        tasks.push(tokio::spawn(http::websocket::task_main(
+        client_tasks.push(tokio::spawn(http::websocket::task_main(
             config,
             wss_listener,
+            cancellation_token.clone(),
         )));
     }
 
+    let mut tasks = vec![
+        tokio::spawn(handle_signals(cancellation_token)),
+        tokio::spawn(http::server::task_main(http_listener)),
+        tokio::spawn(console::mgmt::task_main(mgmt_listener)),
+    ];
+
     if let Some(metrics_config) = &config.metric_collection {
         tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
     }
 
-    // This combinator will block until either all tasks complete or
-    // one of them finishes with an error (others will be cancelled).
-    let tasks = tasks.into_iter().map(flatten_err);
-    let _: Vec<()> = futures::future::try_join_all(tasks).await?;
-
+    let tasks = futures::future::try_join_all(tasks.into_iter().map(flatten_err));
+    let client_tasks = futures::future::try_join_all(client_tasks.into_iter().map(flatten_err));
+    tokio::select! {
+        // We are only expecting an error from these forever tasks
+        res = tasks => { res?; },
+        res = client_tasks => { res?; },
+    }
     Ok(())
 }
 
 /// Handle unix signals appropriately.
-async fn handle_signals() -> anyhow::Result<()> {
+async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
     use tokio::signal::unix::{signal, SignalKind};
 
     let mut hangup = signal(SignalKind::hangup())?;
@@ -116,11 +126,9 @@ async fn handle_signals() -> anyhow::Result<()> {
                 warn!("received SIGINT, exiting immediately");
                 bail!("interrupted");
             }
-            // TODO: Don't accept new proxy connections.
-            // TODO: Shut down once all exisiting connections have been closed.
             _ = terminate.recv() => {
-                warn!("received SIGTERM, exiting immediately");
-                bail!("terminated");
+                warn!("received SIGTERM, shutting down once all existing connections have closed");
+                token.cancel();
             }
         }
     }
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index 70fb25474e..9945e3697f 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -17,6 +17,7 @@ use once_cell::sync::Lazy;
 use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
 use std::sync::Arc;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
+use tokio_util::sync::CancellationToken;
 use tracing::{error, info, warn};
 use utils::measured_stream::MeasuredStream;
 
@@ -63,6 +64,7 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
 pub async fn task_main(
     config: &'static ProxyConfig,
     listener: tokio::net::TcpListener,
+    cancellation_token: CancellationToken,
 ) -> anyhow::Result<()> {
     scopeguard::defer! {
         info!("proxy has shut down");
     }
@@ -72,29 +74,48 @@ pub async fn task_main(
     // will be inherited by all accepted client sockets.
     socket2::SockRef::from(&listener).set_keepalive(true)?;
 
+    let mut connections = tokio::task::JoinSet::new();
     let cancel_map = Arc::new(CancelMap::default());
+
     loop {
-        let (socket, peer_addr) = listener.accept().await?;
-        info!("accepted postgres client connection from {peer_addr}");
+        tokio::select! {
+            accept_result = listener.accept() => {
+                let (socket, peer_addr) = accept_result?;
+                info!("accepted postgres client connection from {peer_addr}");
 
-        let session_id = uuid::Uuid::new_v4();
-        let cancel_map = Arc::clone(&cancel_map);
-        tokio::spawn(
-            async move {
-                info!("spawned a task for {peer_addr}");
+                let session_id = uuid::Uuid::new_v4();
+                let cancel_map = Arc::clone(&cancel_map);
+                connections.spawn(
+                    async move {
+                        info!("spawned a task for {peer_addr}");
 
-                socket
-                    .set_nodelay(true)
-                    .context("failed to set socket option")?;
+                        socket
+                            .set_nodelay(true)
+                            .context("failed to set socket option")?;
 
-                handle_client(config, &cancel_map, session_id, socket).await
-            }
-            .unwrap_or_else(|e| {
-                // Acknowledge that the task has finished with an error.
- error!("per-client task finished with an error: {e:#}"); - }), - ); + _ = cancellation_token.cancelled() => { + drop(listener); + break; + } + } } + // Drain connections + while let Some(res) = connections.join_next().await { + if let Err(e) = res { + if !e.is_panic() && !e.is_cancelled() { + warn!("unexpected error from joined connection task: {e:?}"); + } + } + } + Ok(()) } // TODO(tech debt): unite this with its twin below. diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index e9f0363843..fb12752d3c 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2041,6 +2041,17 @@ class NeonProxy(PgProtocol): self._wait_until_ready() return self + # Sends SIGTERM to the proxy if it has been started + def terminate(self): + if self._popen: + self._popen.terminate() + + # Waits for proxy to exit if it has been opened with a default timeout of + # two seconds. Raises subprocess.TimeoutExpired if the proxy does not exit in time. + def wait_for_exit(self, timeout=2): + if self._popen: + self._popen.wait(timeout=2) + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) def _wait_until_ready(self): requests.get(f"http://{self.host}:{self.http_port}/v1/status") diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 51fabdd2a1..ee6349436b 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -1,3 +1,5 @@ +import subprocess + import psycopg2 import pytest from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres @@ -134,3 +136,19 @@ def test_forward_params_to_client(static_proxy: NeonProxy): for name, value in cur.fetchall(): # Check that proxy has forwarded this parameter. assert conn.get_parameter_status(name) == value + + +@pytest.mark.timeout(5) +def test_close_on_connections_exit(static_proxy: NeonProxy): + # Open two connections, send SIGTERM, then ensure that proxy doesn't exit + # until after connections close. + with static_proxy.connect(options="project=irrelevant"), static_proxy.connect( + options="project=irrelevant" + ): + static_proxy.terminate() + with pytest.raises(subprocess.TimeoutExpired): + static_proxy.wait_for_exit(timeout=2) + # Ensure we don't accept any more connections + with pytest.raises(psycopg2.OperationalError): + static_proxy.connect(options="project=irrelevant") + static_proxy.wait_for_exit()