diff --git a/libs/utils/src/sync/gate.rs b/libs/utils/src/sync/gate.rs index 16ec563fa7..0a1ed81621 100644 --- a/libs/utils/src/sync/gate.rs +++ b/libs/utils/src/sync/gate.rs @@ -64,6 +64,12 @@ pub struct GateGuard { gate: Arc, } +impl GateGuard { + pub fn try_clone(&self) -> Result { + Gate::enter_impl(self.gate.clone()) + } +} + impl Drop for GateGuard { fn drop(&mut self) { if self.gate.closing.load(Ordering::Relaxed) { @@ -107,11 +113,11 @@ impl Gate { /// to avoid blocking close() indefinitely: typically types that contain a Gate will /// also contain a CancellationToken. pub fn enter(&self) -> Result { - let permit = self - .inner - .sem - .try_acquire() - .map_err(|_| GateError::GateClosed)?; + Self::enter_impl(self.inner.clone()) + } + + fn enter_impl(gate: Arc) -> Result { + let permit = gate.sem.try_acquire().map_err(|_| GateError::GateClosed)?; // we now have the permit, let's disable the normal raii functionality and leave // "returning" the permit to our GateGuard::drop. @@ -122,7 +128,7 @@ impl Gate { Ok(GateGuard { span_at_enter: tracing::Span::current(), - gate: self.inner.clone(), + gate, }) } @@ -252,4 +258,39 @@ mod tests { // Attempting to enter() is still forbidden gate.enter().expect_err("enter should fail finishing close"); } + + #[tokio::test(start_paused = true)] + async fn clone_gate_guard() { + let gate = Gate::default(); + let forever = Duration::from_secs(24 * 7 * 365); + + let guard1 = gate.enter().expect("gate isn't closed"); + + let guard2 = guard1.try_clone().expect("gate isn't closed"); + + let mut close_fut = std::pin::pin!(gate.close()); + + tokio::time::timeout(forever, &mut close_fut) + .await + .unwrap_err(); + + // we polled close_fut once, that should prevent all later enters and clones + gate.enter().unwrap_err(); + guard1.try_clone().unwrap_err(); + guard2.try_clone().unwrap_err(); + + // guard2 keeps gate open even if guard1 is dropped + drop(guard1); + tokio::time::timeout(forever, &mut close_fut) + .await + 
.unwrap_err(); + + drop(guard2); + + // now that the last guard is dropped, closing should complete + close_fut.await; + + // entering is still forbidden + gate.enter().expect_err("enter should still fail"); + } } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 03f2beac8c..0bf26dfca1 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -39,7 +39,7 @@ use tokio::io::{AsyncWriteExt, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::*; -use utils::sync::gate::Gate; +use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; use utils::{ auth::{Claims, Scope, SwappableJwtAuth}, @@ -91,7 +91,7 @@ pub struct Listener { pub struct Connections { cancel: CancellationToken, tasks: tokio::task::JoinSet, - gate: Arc, + gate: Gate, } pub fn spawn( @@ -101,7 +101,6 @@ pub fn spawn( tcp_listener: tokio::net::TcpListener, ) -> Listener { let cancel = CancellationToken::new(); - let gate = Arc::new(Gate::default()); let libpq_ctx = RequestContext::todo_child( TaskKind::LibpqEndpointListener, // listener task shouldn't need to download anything. (We will @@ -120,7 +119,6 @@ pub fn spawn( conf.page_service_pipelining.clone(), libpq_ctx, cancel.clone(), - gate, ) .map(anyhow::Ok), )); @@ -175,12 +173,17 @@ pub async fn libpq_listener_main( pipelining_config: PageServicePipeliningConfig, listener_ctx: RequestContext, listener_cancel: CancellationToken, - gate: Arc, ) -> Connections { let connections_cancel = CancellationToken::new(); + let connections_gate = Gate::default(); let mut connection_handler_tasks = tokio::task::JoinSet::default(); loop { + let gate_guard = match connections_gate.enter() { + Ok(guard) => guard, + Err(_) => break, + }; + let accepted = tokio::select! 
{ biased; _ = listener_cancel.cancelled() => break, @@ -207,7 +210,7 @@ pub async fn libpq_listener_main( pipelining_config.clone(), connection_ctx, connections_cancel.child_token(), - Arc::clone(&gate), + gate_guard, )); } Err(err) => { @@ -222,7 +225,7 @@ pub async fn libpq_listener_main( Connections { cancel: connections_cancel, tasks: connection_handler_tasks, - gate, + gate: connections_gate, } } @@ -237,7 +240,7 @@ async fn page_service_conn_main( pipelining_config: PageServicePipeliningConfig, connection_ctx: RequestContext, cancel: CancellationToken, - gate: Arc, + gate_guard: GateGuard, ) -> ConnectionHandlerResult { let _guard = LIVE_CONNECTIONS .with_label_values(&["page_service"]) @@ -292,7 +295,7 @@ async fn page_service_conn_main( pipelining_config, connection_ctx, cancel.clone(), - gate, + gate_guard, ); let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?; @@ -340,7 +343,7 @@ struct PageServerHandler { pipelining_config: PageServicePipeliningConfig, - gate: Arc, + gate_guard: GateGuard, } struct TimelineHandles { @@ -633,7 +636,7 @@ impl PageServerHandler { pipelining_config: PageServicePipeliningConfig, connection_ctx: RequestContext, cancel: CancellationToken, - gate: Arc, + gate_guard: GateGuard, ) -> Self { PageServerHandler { auth, @@ -642,7 +645,7 @@ impl PageServerHandler { timeline_handles: Some(TimelineHandles::new(tenant_manager)), cancel, pipelining_config, - gate, + gate_guard, } } @@ -1161,7 +1164,7 @@ impl PageServerHandler { } } - let io_concurrency = IoConcurrency::spawn_from_env(match self.gate.enter() { + let io_concurrency = IoConcurrency::spawn_from_env(match self.gate_guard.try_clone() { Ok(guard) => guard, Err(_) => { info!("shutdown request received in page handler");