#![warn(missing_docs)] use std::cmp::{Eq, Ordering, PartialOrd}; use std::collections::BinaryHeap; use std::fmt::Debug; use std::mem; use std::sync::Mutex; use std::time::Duration; use tokio::sync::watch::{channel, Receiver, Sender}; use tokio::time::timeout; /// An error happened while waiting for a number #[derive(Debug, PartialEq, Eq, thiserror::Error)] #[error("SeqWaitError")] pub enum SeqWaitError { /// The wait timeout was reached Timeout, /// [`SeqWait::shutdown`] was called Shutdown, } /// Monotonically increasing value /// /// It is handy to store some other fields under the same mutex in SeqWait /// (e.g. store prev_record_lsn). So we allow SeqWait to be parametrized with /// any type that can expose counter. is the type of exposed counter. pub trait MonotonicCounter { /// Bump counter value and check that it goes forward /// N.B.: new_val is an actual new value, not a difference. fn cnt_advance(&mut self, new_val: V); /// Get counter value fn cnt_value(&self) -> V; } /// Internal components of a `SeqWait` struct SeqWaitInt where S: MonotonicCounter, V: Ord, { waiters: BinaryHeap>, current: S, shutdown: bool, } struct Waiter where T: Ord, { wake_num: T, // wake me when this number arrives ... wake_channel: Sender<()>, // ... by sending a message to this channel } // BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here // to get that. impl PartialOrd for Waiter { fn partial_cmp(&self, other: &Self) -> Option { other.wake_num.partial_cmp(&self.wake_num) } } impl Ord for Waiter { fn cmp(&self, other: &Self) -> Ordering { other.wake_num.cmp(&self.wake_num) } } impl PartialEq for Waiter { fn eq(&self, other: &Self) -> bool { other.wake_num == self.wake_num } } impl Eq for Waiter {} /// A tool for waiting on a sequence number /// /// This provides a way to wait the arrival of a number. /// As soon as the number arrives by another caller calling /// [`advance`], then the waiter will be woken up. /// /// This implementation takes a blocking Mutex on both [`wait_for`] /// and [`advance`], meaning there may be unexpected executor blocking /// due to thread scheduling unfairness. There are probably better /// implementations, but we can probably live with this for now. /// /// [`wait_for`]: SeqWait::wait_for /// [`advance`]: SeqWait::advance /// /// means Storage, is type of counter that this storage exposes. /// pub struct SeqWait where S: MonotonicCounter, V: Ord, { internal: Mutex>, } impl SeqWait where S: MonotonicCounter + Copy, V: Ord + Copy, { /// Create a new `SeqWait`, initialized to a particular number pub fn new(starting_num: S) -> Self { let internal = SeqWaitInt { waiters: BinaryHeap::new(), current: starting_num, shutdown: false, }; SeqWait { internal: Mutex::new(internal), } } /// Shut down a `SeqWait`, causing all waiters (present and /// future) to return an error. pub fn shutdown(&self) { let waiters = { // Prevent new waiters; wake all those that exist. // Wake everyone with an error. let mut internal = self.internal.lock().unwrap(); // This will steal the entire waiters map. // When we drop it all waiters will be woken. mem::take(&mut internal.waiters) // Drop the lock as we exit this scope. }; // When we drop the waiters list, each Receiver will // be woken with an error. // This drop doesn't need to be explicit; it's done // here to make it easier to read the code and understand // the order of events. drop(waiters); } /// Wait for a number to arrive /// /// This call won't complete until someone has called `advance` /// with a number greater than or equal to the one we're waiting for. pub async fn wait_for(&self, num: V) -> Result<(), SeqWaitError> { match self.queue_for_wait(num) { Ok(None) => Ok(()), Ok(Some(mut rx)) => rx.changed().await.map_err(|_| SeqWaitError::Shutdown), Err(e) => Err(e), } } /// Wait for a number to arrive /// /// This call won't complete until someone has called `advance` /// with a number greater than or equal to the one we're waiting for. /// /// If that hasn't happened after the specified timeout duration, /// [`SeqWaitError::Timeout`] will be returned. pub async fn wait_for_timeout( &self, num: V, timeout_duration: Duration, ) -> Result<(), SeqWaitError> { match self.queue_for_wait(num) { Ok(None) => Ok(()), Ok(Some(mut rx)) => match timeout(timeout_duration, rx.changed()).await { Ok(Ok(())) => Ok(()), Ok(Err(_)) => Err(SeqWaitError::Shutdown), Err(_) => Err(SeqWaitError::Timeout), }, Err(e) => Err(e), } } /// Register and return a channel that will be notified when a number arrives, /// or None, if it has already arrived. fn queue_for_wait(&self, num: V) -> Result>, SeqWaitError> { let mut internal = self.internal.lock().unwrap(); if internal.current.cnt_value() >= num { return Ok(None); } if internal.shutdown { return Err(SeqWaitError::Shutdown); } // Create a new channel. let (tx, rx) = channel(()); internal.waiters.push(Waiter { wake_num: num, wake_channel: tx, }); // Drop the lock as we exit this scope. Ok(Some(rx)) } /// Announce a new number has arrived /// /// All waiters at this value or below will be woken. /// /// Returns the old number. pub fn advance(&self, num: V) -> V { let old_value; let wake_these = { let mut internal = self.internal.lock().unwrap(); old_value = internal.current.cnt_value(); if old_value >= num { return old_value; } internal.current.cnt_advance(num); // Pop all waiters <= num from the heap. Collect them in a vector, and // wake them up after releasing the lock. let mut wake_these = Vec::new(); while let Some(n) = internal.waiters.peek() { if n.wake_num > num { break; } wake_these.push(internal.waiters.pop().unwrap().wake_channel); } wake_these }; for tx in wake_these { // This can fail if there are no receivers. // We don't care; discard the error. let _ = tx.send(()); } old_value } /// Read the current value, without waiting. pub fn load(&self) -> S { self.internal.lock().unwrap().current } } #[cfg(test)] mod tests { use super::*; use std::sync::Arc; use std::time::Duration; impl MonotonicCounter for i32 { fn cnt_advance(&mut self, val: i32) { assert!(*self <= val); *self = val; } fn cnt_value(&self) -> i32 { *self } } #[tokio::test] async fn seqwait() { let seq = Arc::new(SeqWait::new(0)); let seq2 = Arc::clone(&seq); let seq3 = Arc::clone(&seq); let jh1 = tokio::task::spawn(async move { seq2.wait_for(42).await.expect("wait_for 42"); let old = seq2.advance(100); assert_eq!(old, 99); seq2.wait_for_timeout(999, Duration::from_millis(100)) .await .expect_err("no 999"); }); let jh2 = tokio::task::spawn(async move { seq3.wait_for(42).await.expect("wait_for 42"); seq3.wait_for(0).await.expect("wait_for 0"); }); tokio::time::sleep(Duration::from_millis(200)).await; let old = seq.advance(99); assert_eq!(old, 0); seq.wait_for(100).await.expect("wait_for 100"); // Calling advance with a smaller value is a no-op assert_eq!(seq.advance(98), 100); assert_eq!(seq.load(), 100); jh1.await.unwrap(); jh2.await.unwrap(); seq.shutdown(); } #[tokio::test] async fn seqwait_timeout() { let seq = Arc::new(SeqWait::new(0)); let seq2 = Arc::clone(&seq); let jh = tokio::task::spawn(async move { let timeout = Duration::from_millis(1); let res = seq2.wait_for_timeout(42, timeout).await; assert_eq!(res, Err(SeqWaitError::Timeout)); }); tokio::time::sleep(Duration::from_millis(200)).await; // This will attempt to wake, but nothing will happen // because the waiter already dropped its Receiver. let old = seq.advance(99); assert_eq!(old, 0); jh.await.unwrap(); seq.shutdown(); } }