avoid Arc<Gate> by having clonable GateGuard

This commit is contained in:
Christian Schwarz
2024-12-21 20:13:38 +01:00
parent dc58846f0c
commit d776ee66d7
2 changed files with 63 additions and 19 deletions

View File

@@ -64,6 +64,12 @@ pub struct GateGuard {
gate: Arc<GateInner>,
}
impl GateGuard {
pub fn try_clone(&self) -> Result<Self, GateError> {
Gate::enter_impl(self.gate.clone())
}
}
impl Drop for GateGuard {
fn drop(&mut self) {
if self.gate.closing.load(Ordering::Relaxed) {
@@ -107,11 +113,11 @@ impl Gate {
/// to avoid blocking close() indefinitely: typically types that contain a Gate will
/// also contain a CancellationToken.
pub fn enter(&self) -> Result<GateGuard, GateError> {
let permit = self
.inner
.sem
.try_acquire()
.map_err(|_| GateError::GateClosed)?;
Self::enter_impl(self.inner.clone())
}
fn enter_impl(gate: Arc<GateInner>) -> Result<GateGuard, GateError> {
let permit = gate.sem.try_acquire().map_err(|_| GateError::GateClosed)?;
// we now have the permit, let's disable the normal raii functionality and leave
// "returning" the permit to our GateGuard::drop.
@@ -122,7 +128,7 @@ impl Gate {
Ok(GateGuard {
span_at_enter: tracing::Span::current(),
gate: self.inner.clone(),
gate,
})
}
@@ -252,4 +258,39 @@ mod tests {
// Attempting to enter() is still forbidden
gate.enter().expect_err("enter should fail finishing close");
}
#[tokio::test(start_paused = true)]
async fn clone_gate_guard() {
let gate = Gate::default();
let forever = Duration::from_secs(24 * 7 * 365);
let guard1 = gate.enter().expect("gate isn't closed");
let guard2 = guard1.try_clone().expect("gate isn't clsoed");
let mut close_fut = std::pin::pin!(gate.close());
tokio::time::timeout(forever, &mut close_fut)
.await
.unwrap_err();
// we polled close_fut once, that should prevent all later enters and clones
gate.enter().unwrap_err();
guard1.try_clone().unwrap_err();
guard2.try_clone().unwrap_err();
// guard2 keeps gate open even if guard1 is closed
drop(guard1);
tokio::time::timeout(forever, &mut close_fut)
.await
.unwrap_err();
drop(guard2);
// now that the last guard is dropped, closing should complete
close_fut.await;
// entering is still forbidden
gate.enter().expect_err("enter should stilll fail");
}
}

View File

@@ -39,7 +39,7 @@ use tokio::io::{AsyncWriteExt, BufWriter};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::sync::gate::Gate;
use utils::sync::gate::{Gate, GateGuard};
use utils::sync::spsc_fold;
use utils::{
auth::{Claims, Scope, SwappableJwtAuth},
@@ -91,7 +91,7 @@ pub struct Listener {
pub struct Connections {
cancel: CancellationToken,
tasks: tokio::task::JoinSet<ConnectionHandlerResult>,
gate: Arc<Gate>,
gate: Gate,
}
pub fn spawn(
@@ -101,7 +101,6 @@ pub fn spawn(
tcp_listener: tokio::net::TcpListener,
) -> Listener {
let cancel = CancellationToken::new();
let gate = Arc::new(Gate::default());
let libpq_ctx = RequestContext::todo_child(
TaskKind::LibpqEndpointListener,
// listener task shouldn't need to download anything. (We will
@@ -120,7 +119,6 @@ pub fn spawn(
conf.page_service_pipelining.clone(),
libpq_ctx,
cancel.clone(),
gate,
)
.map(anyhow::Ok),
));
@@ -175,12 +173,17 @@ pub async fn libpq_listener_main(
pipelining_config: PageServicePipeliningConfig,
listener_ctx: RequestContext,
listener_cancel: CancellationToken,
gate: Arc<Gate>,
) -> Connections {
let connections_cancel = CancellationToken::new();
let connections_gate = Gate::default();
let mut connection_handler_tasks = tokio::task::JoinSet::default();
loop {
let gate_guard = match connections_gate.enter() {
Ok(guard) => guard,
Err(_) => break,
};
let accepted = tokio::select! {
biased;
_ = listener_cancel.cancelled() => break,
@@ -207,7 +210,7 @@ pub async fn libpq_listener_main(
pipelining_config.clone(),
connection_ctx,
connections_cancel.child_token(),
Arc::clone(&gate),
gate_guard,
));
}
Err(err) => {
@@ -222,7 +225,7 @@ pub async fn libpq_listener_main(
Connections {
cancel: connections_cancel,
tasks: connection_handler_tasks,
gate,
gate: connections_gate,
}
}
@@ -237,7 +240,7 @@ async fn page_service_conn_main(
pipelining_config: PageServicePipeliningConfig,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate: Arc<Gate>,
gate_guard: GateGuard,
) -> ConnectionHandlerResult {
let _guard = LIVE_CONNECTIONS
.with_label_values(&["page_service"])
@@ -292,7 +295,7 @@ async fn page_service_conn_main(
pipelining_config,
connection_ctx,
cancel.clone(),
gate,
gate_guard,
);
let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;
@@ -340,7 +343,7 @@ struct PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
gate: Arc<Gate>,
gate_guard: GateGuard,
}
struct TimelineHandles {
@@ -633,7 +636,7 @@ impl PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate: Arc<Gate>,
gate_guard: GateGuard,
) -> Self {
PageServerHandler {
auth,
@@ -642,7 +645,7 @@ impl PageServerHandler {
timeline_handles: Some(TimelineHandles::new(tenant_manager)),
cancel,
pipelining_config,
gate,
gate_guard,
}
}
@@ -1161,7 +1164,7 @@ impl PageServerHandler {
}
}
let io_concurrency = IoConcurrency::spawn_from_env(match self.gate.enter() {
let io_concurrency = IoConcurrency::spawn_from_env(match self.gate_guard.try_clone() {
Ok(guard) => guard,
Err(_) => {
info!("shutdown request received in page handler");