diff --git a/pageserver/src/walreceiver.rs b/pageserver/src/walreceiver.rs index aaf46579a7..30d19de405 100644 --- a/pageserver/src/walreceiver.rs +++ b/pageserver/src/walreceiver.rs @@ -31,6 +31,7 @@ use once_cell::sync::OnceCell; use std::future::Future; use storage_broker::BrokerClientChannel; use tokio::sync::watch; +use tokio_util::sync::CancellationToken; use tracing::*; pub use connection_manager::spawn_connection_manager_task; @@ -76,7 +77,7 @@ pub fn is_broker_client_initialized() -> bool { /// A handle of an asynchronous task. /// The task has a channel that it can use to communicate its lifecycle events in a certain form, see [`TaskEvent`] -/// and a cancellation channel that it can listen to for earlier interrupts. +/// and a cancellation token that it can listen to for earlier interrupts. /// /// Note that the communication happens via the `watch` channel, that does not accumulate the events, replacing the old one with the never one on submission. /// That may lead to certain events not being observed by the listener. @@ -84,7 +85,7 @@ pub fn is_broker_client_initialized() -> bool { pub struct TaskHandle { join_handle: Option>>, events_receiver: watch::Receiver>, - cancellation: watch::Sender<()>, + cancellation: CancellationToken, } pub enum TaskEvent { @@ -102,20 +103,19 @@ pub enum TaskStateUpdate { impl TaskHandle { /// Initializes the task, starting it immediately after the creation. pub fn spawn( - task: impl FnOnce(watch::Sender>, watch::Receiver<()>) -> Fut - + Send - + 'static, + task: impl FnOnce(watch::Sender>, CancellationToken) -> Fut + Send + 'static, ) -> Self where Fut: Future> + Send, E: Send + Sync + 'static, { - let (cancellation, cancellation_receiver) = watch::channel(()); + let cancellation = CancellationToken::new(); let (events_sender, events_receiver) = watch::channel(TaskStateUpdate::Started); + let cancellation_clone = cancellation.clone(); let join_handle = WALRECEIVER_RUNTIME.spawn(async move { events_sender.send(TaskStateUpdate::Started).ok(); - task(events_sender, cancellation_receiver).await + task(events_sender, cancellation_clone).await }); TaskHandle { @@ -157,7 +157,7 @@ impl TaskHandle { /// Aborts current task, waiting for it to finish. pub async fn shutdown(self) { if let Some(jh) = self.join_handle { - self.cancellation.send(()).ok(); + self.cancellation.cancel(); match jh.await { Ok(Ok(())) => debug!("Shutdown success"), Ok(Err(e)) => error!("Shutdown task error: {e:?}"), diff --git a/pageserver/src/walreceiver/walreceiver_connection.rs b/pageserver/src/walreceiver/walreceiver_connection.rs index aca5e8e019..1b9e4923fb 100644 --- a/pageserver/src/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/walreceiver/walreceiver_connection.rs @@ -19,6 +19,7 @@ use postgres_protocol::message::backend::ReplicationMessage; use postgres_types::PgLsn; use tokio::{pin, select, sync::watch, time}; use tokio_postgres::{replication::ReplicationStream, Client}; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, trace, warn}; use crate::{metrics::LIVE_CONNECTIONS_COUNT, walreceiver::TaskStateUpdate}; @@ -59,7 +60,7 @@ pub async fn handle_walreceiver_connection( timeline: Arc, wal_source_connconf: PgConnectionConfig, events_sender: watch::Sender>, - mut cancellation: watch::Receiver<()>, + cancellation: CancellationToken, connect_timeout: Duration, ) -> anyhow::Result<()> { // Connect to the database in replication mode. @@ -98,7 +99,7 @@ pub async fn handle_walreceiver_connection( // The connection object performs the actual communication with the database, // so spawn it off to run on its own. - let mut connection_cancellation = cancellation.clone(); + let connection_cancellation = cancellation.clone(); task_mgr::spawn( WALRECEIVER_RUNTIME.handle(), TaskKind::WalReceiverConnection, @@ -117,7 +118,7 @@ pub async fn handle_walreceiver_connection( } }, - _ = connection_cancellation.changed() => info!("Connection cancelled"), + _ = connection_cancellation.cancelled() => info!("Connection cancelled"), } Ok(()) }, @@ -183,7 +184,7 @@ pub async fn handle_walreceiver_connection( while let Some(replication_message) = { select! { - _ = cancellation.changed() => { + _ = cancellation.cancelled() => { info!("walreceiver interrupted"); None }