diff --git a/libs/utils/src/sync/spsc_watch.rs b/libs/utils/src/sync/spsc_watch.rs index d29a820645..6666aa3eaf 100644 --- a/libs/utils/src/sync/spsc_watch.rs +++ b/libs/utils/src/sync/spsc_watch.rs @@ -1,11 +1,12 @@ -//! Like [`tokio::sync::watch`] but the Sender gets an error when the Receiver is gone. -//! -//! TODO: an actually more efficient implementation +//! watch is probably not the right word, because we do take out use tokio_util::sync::CancellationToken; -pub fn channel(init: T) -> (Sender, Receiver) { - let (tx, rx) = tokio::sync::watch::channel(init); +use crate::sync::spsc_fold; + +pub fn channel(init: T) -> (Sender, Receiver) { + let (mut tx, rx) = spsc_fold::channel(); + poll_ready(tx.send(init, |_, _| unreachable!("init"))); let cancel = CancellationToken::new(); ( Sender { @@ -17,45 +18,39 @@ pub fn channel(init: T) -> (Sender, Receiver) { } pub struct Sender { - tx: tokio::sync::watch::Sender, + tx: spsc_fold::Sender, cancel: tokio_util::sync::DropGuard, } pub struct Receiver { - rx: tokio::sync::watch::Receiver, + rx: spsc_fold::Receiver, cancel: CancellationToken, } -pub enum RecvError { - SenderGone, -} -pub enum SendError { - ReceiverGone, -} - -impl Sender { - pub fn send_replace(&mut self, value: T) -> Result<(), (T, SendError)> { - if self.tx.receiver_count() == 0 { - // we don't provide outside access to tx, so, we know the only - // rx that is ever going to exist is gone - return Err((value, SendError::ReceiverGone)); - } - self.tx.send_replace(value); - Ok(()) +impl Sender { + pub fn send_replace(&mut self, value: T) -> Result<(), spsc_fold::SendError> { + poll_ready(self.tx.send(value, |old, new| { + *old = new; + Ok(()) + })) } } impl Receiver { - pub async fn changed(&mut self) -> Result<(), RecvError> { - match self.rx.changed().await { - Ok(()) => Ok(self.rx.borrow()), - Err(e) => Err(RecvError::SenderGone), - } + pub async fn recv(&mut self) -> Result<(), spsc_fold::RecvError> { + todo!() } - pub fn borrow(&self) -> impl Deref { - self.rx.borrow() - } - pub async fn closed(&mut self) { + pub async fn cancelled(&mut self) { self.cancel.cancelled().await } } + +fn poll_ready, O>(f: F) -> O { + futures::executor::block_on(async move { + let f = std::pin::pin!(f); + match futures::poll!(f) { + std::task::Poll::Ready(r) => r, + std::task::Poll::Pending => unreachable!("expecting future to always return Ready"), + } + }) +} diff --git a/safekeeper/src/wal_advertiser.rs b/safekeeper/src/wal_advertiser.rs index 1ca9133f6a..16ae869b9f 100644 --- a/safekeeper/src/wal_advertiser.rs +++ b/safekeeper/src/wal_advertiser.rs @@ -19,7 +19,7 @@ pub async fn task_main( for (node_id, advs) in advertisements { loop { let tx = senders.entry(node_id).or_insert_with(|| { - let (tx, rx) = utils::sync::spsc_fold::channel(); + let (tx, rx) = spsc_watch::channel(advs); tokio::spawn( NodeTask { ps_id: node_id, @@ -40,55 +40,37 @@ pub async fn task_main( struct PageserverTask { ps_id: NodeId, - advs: spsc_fold::Receiver>, + advs: spsc_watch::Receiver>, } impl PageserverTask { /// Cancellation: happens through last PageserverHandle being dropped. async fn run(mut self) { + let mut current; loop { - let Ok(advs) = self.advs.recv().await else { - return; - }; - tokio::select! { - _ = self.advs.cancelled() => { - return; + let res = self.run0(advs).await; + match res { + Ok(()) => {} + Err(err) => { + error!(?err, "error sending advertisements"); + // TODO: proper backoff? + tokio::time::sleep(Duration::from_secs(5)).await; } - res = self.run0() => { - if let Err(err) = res { - error!(?err, "failure sending advertisements, restarting after back-off"); - // TODO: backoff? + cancellation sensitivity - tokio::time::sleep(Duration::from_secs(10)).await; - } - continue; - } - }; + } } } - async fn run0(&mut self) -> anyhow::Result<()> { + async fn run0(&mut self, advs: HashMap) -> anyhow::Result<()> { use storage_broker::wal_advertisement as proto; use storage_broker::wal_advertisement::pageserver_client::PageserverClient; let stream = async_stream::stream! { loop { - while self.pending_advertisements.is_empty() { - tokio::select! { - _ = self.advs.cancelled() => { - return; - } - _ = self.notify_pending_advertisements.notified() => {} - } - let mut state = self.state.lock().unwrap(); - std::mem::swap( - &mut state.pending_advertisements, - &mut self.pending_advertisements, - ); - } - for (tenant_timeline_id, commit_lsn) in self.pending_advertisements.drain() { + for (tenant_timeline_id, commit_lsn) in advs { yield proto::CommitLsnAdvertisement {tenant_timeline_id: Some(tenant_timeline_id), commit_lsn: Some(commit_lsn) }; } - } }; - let client: PageserverClient<_> = PageserverClient::connect(todo!()) - .await - .context("connect")?; + }}; + let client: PageserverClient<_> = + PageserverClient::connect(todo!("how do we learn pageserver hostnames?")) + .await + .context("connect")?; let publish_stream = client .publish_commit_lsn_advertisements(stream) .await