Make proxy shutdown when all connections are closed (#3764)

## Describe your changes
Makes Proxy start draining connections on SIGTERM.
## Issue ticket number and link
#3333
This commit is contained in:
Sasha Krassovsky
2023-04-13 09:31:30 -07:00
committed by GitHub
parent db8dd6f380
commit fd31fafeee
7 changed files with 96 additions and 33 deletions

1
Cargo.lock generated
View File

@@ -2965,6 +2965,7 @@ dependencies = [
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tokio-util",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",

View File

@@ -64,6 +64,7 @@ webpki-roots.workspace = true
x509-parser.workspace = true
workspace_hack.workspace = true
tokio-util.workspace = true
[dev-dependencies]
rcgen.workspace = true

View File

@@ -22,6 +22,7 @@ use tokio::{
io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
net::TcpListener,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
@@ -188,6 +189,7 @@ async fn ws_handler(
pub async fn task_main(
config: &'static ProxyConfig,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -231,6 +233,7 @@ pub async fn task_main(
hyper::Server::builder(accept::from_stream(tls_listener))
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
Ok(())

View File

@@ -28,6 +28,7 @@ use config::ProxyConfig;
use futures::FutureExt;
use std::{borrow::Cow, future::Future, net::SocketAddr};
use tokio::{net::TcpListener, task::JoinError};
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use utils::{project_git_version, sentry_init::init_sentry};
@@ -66,39 +67,48 @@ async fn main() -> anyhow::Result<()> {
let proxy_address: SocketAddr = args.get_one::<String>("proxy").unwrap().parse()?;
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
let mut tasks = vec![
tokio::spawn(handle_signals()),
tokio::spawn(http::server::task_main(http_listener)),
tokio::spawn(proxy::task_main(config, proxy_listener)),
tokio::spawn(console::mgmt::task_main(mgmt_listener)),
];
let mut client_tasks = vec![tokio::spawn(proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
))];
if let Some(wss_address) = args.get_one::<String>("wss") {
let wss_address: SocketAddr = wss_address.parse()?;
info!("Starting wss on {wss_address}");
let wss_listener = TcpListener::bind(wss_address).await?;
tasks.push(tokio::spawn(http::websocket::task_main(
client_tasks.push(tokio::spawn(http::websocket::task_main(
config,
wss_listener,
cancellation_token.clone(),
)));
}
let mut tasks = vec![
tokio::spawn(handle_signals(cancellation_token)),
tokio::spawn(http::server::task_main(http_listener)),
tokio::spawn(console::mgmt::task_main(mgmt_listener)),
];
if let Some(metrics_config) = &config.metric_collection {
tasks.push(tokio::spawn(metrics::task_main(metrics_config)));
}
// This combinator will block until either all tasks complete or
// one of them finishes with an error (others will be cancelled).
let tasks = tasks.into_iter().map(flatten_err);
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
let tasks = futures::future::try_join_all(tasks.into_iter().map(flatten_err));
let client_tasks = futures::future::try_join_all(client_tasks.into_iter().map(flatten_err));
tokio::select! {
// We are only expecting an error from these forever tasks
res = tasks => { res?; },
res = client_tasks => { res?; },
}
Ok(())
}
/// Handle unix signals appropriately.
async fn handle_signals() -> anyhow::Result<()> {
async fn handle_signals(token: CancellationToken) -> anyhow::Result<()> {
use tokio::signal::unix::{signal, SignalKind};
let mut hangup = signal(SignalKind::hangup())?;
@@ -116,11 +126,9 @@ async fn handle_signals() -> anyhow::Result<()> {
warn!("received SIGINT, exiting immediately");
bail!("interrupted");
}
// TODO: Don't accept new proxy connections.
// TODO: Shut down once all exisiting connections have been closed.
_ = terminate.recv() => {
warn!("received SIGTERM, exiting immediately");
bail!("terminated");
warn!("received SIGTERM, shutting down once all existing connections have closed");
token.cancel();
}
}
}

View File

@@ -17,6 +17,7 @@ use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::measured_stream::MeasuredStream;
@@ -63,6 +64,7 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
pub async fn task_main(
config: &'static ProxyConfig,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -72,29 +74,48 @@ pub async fn task_main(
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let mut connections = tokio::task::JoinSet::new();
let cancel_map = Arc::new(CancelMap::default());
loop {
let (socket, peer_addr) = listener.accept().await?;
info!("accepted postgres client connection from {peer_addr}");
tokio::select! {
accept_result = listener.accept() => {
let (socket, peer_addr) = accept_result?;
info!("accepted postgres client connection from {peer_addr}");
let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
tokio::spawn(
async move {
info!("spawned a task for {peer_addr}");
let session_id = uuid::Uuid::new_v4();
let cancel_map = Arc::clone(&cancel_map);
connections.spawn(
async move {
info!("spawned a task for {peer_addr}");
socket
.set_nodelay(true)
.context("failed to set socket option")?;
socket
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, &cancel_map, session_id, socket).await
handle_client(config, &cancel_map, session_id, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
}),
);
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
}),
);
_ = cancellation_token.cancelled() => {
drop(listener);
break;
}
}
}
// Drain connections
while let Some(res) = connections.join_next().await {
if let Err(e) = res {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
}
Ok(())
}
// TODO(tech debt): unite this with its twin below.

View File

@@ -2041,6 +2041,17 @@ class NeonProxy(PgProtocol):
self._wait_until_ready()
return self
# Sends SIGTERM to the proxy if it has been started
def terminate(self):
if self._popen:
self._popen.terminate()
# Waits for proxy to exit if it has been opened with a default timeout of
# two seconds. Raises subprocess.TimeoutExpired if the proxy does not exit in time.
def wait_for_exit(self, timeout=2):
if self._popen:
self._popen.wait(timeout=2)
@backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10)
def _wait_until_ready(self):
requests.get(f"http://{self.host}:{self.http_port}/v1/status")

View File

@@ -1,3 +1,5 @@
import subprocess
import psycopg2
import pytest
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
@@ -134,3 +136,19 @@ def test_forward_params_to_client(static_proxy: NeonProxy):
for name, value in cur.fetchall():
# Check that proxy has forwarded this parameter.
assert conn.get_parameter_status(name) == value
@pytest.mark.timeout(5)
def test_close_on_connections_exit(static_proxy: NeonProxy):
# Open two connections, send SIGTERM, then ensure that proxy doesn't exit
# until after connections close.
with static_proxy.connect(options="project=irrelevant"), static_proxy.connect(
options="project=irrelevant"
):
static_proxy.terminate()
with pytest.raises(subprocess.TimeoutExpired):
static_proxy.wait_for_exit(timeout=2)
# Ensure we don't accept any more connections
with pytest.raises(psycopg2.OperationalError):
static_proxy.connect(options="project=irrelevant")
static_proxy.wait_for_exit()