Files
neon/zenith_utils/src/seqwait.rs
Heikki Linnakangas cff671c1bd Remove duplicated LSN fields from the page cache.
Having multiple copies of the same values is a source of confusion.
Commit da9bf5dc63 fixed one race condition caused by that, for example.
See also discussion at
https://github.com/zenithdb/zenith/issues/57#issuecomment-824393470

This changes SeqWait.advance() to return the old number, and not panic if
you try to move the value backwards. The caller should check for that and
act accordingly.
2021-04-27 10:32:39 +03:00

265 lines
7.8 KiB
Rust

#![warn(missing_docs)]
use std::cmp::{Eq, Ordering, PartialOrd};
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::mem;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Mutex;
use std::time::Duration;
/// An error happened while waiting for a number
///
/// Returned by [`SeqWait::wait_for`] and [`SeqWait::wait_for_timeout`].
// NOTE(review): the `#[error(...)]` attribute sits on the enum, so both
// variants presumably display as the same "SeqWaitError" text -- confirm
// against the thiserror documentation if distinct messages are wanted.
#[derive(Debug, PartialEq, thiserror::Error)]
#[error("SeqWaitError")]
pub enum SeqWaitError {
/// The wait timeout was reached
///
/// Only produced by [`SeqWait::wait_for_timeout`], when the number did
/// not arrive within the requested duration.
Timeout,
/// [`SeqWait::shutdown`] was called
///
/// Also what a waiter observes when its wake-up channel is disconnected
/// before the number arrives, which is how `shutdown` wakes queued waiters.
Shutdown,
}
/// Internal components of a `SeqWait`
///
/// All access goes through the `Mutex` held by `SeqWait`.
struct SeqWaitInt<T>
where
T: Ord,
{
// Waiters queued by `queue_for_wait`. `Waiter`'s ordering is reversed, so
// the waiter with the smallest `wake_num` is the one `peek`/`pop` return.
waiters: BinaryHeap<Waiter<T>>,
// The starting number, or the latest number stored by `advance`.
current: T,
// Checked by `queue_for_wait`: when set, no new waiter is queued and the
// caller gets `SeqWaitError::Shutdown` instead.
shutdown: bool,
}
/// One queued waiter: the number it is waiting for plus the sending half of
/// the channel used to wake it.
///
/// Dropping a `Waiter` without sending drops its `Sender`, so the matching
/// `Receiver` sees a disconnect; `wait_for` maps that to
/// `SeqWaitError::Shutdown`.
struct Waiter<T>
where
T: Ord,
{
wake_num: T, // wake me when this number arrives ...
wake_channel: Sender<()>, // ... by sending a message to this channel
}
// `BinaryHeap` is a max-heap, and we want a min-heap keyed on `wake_num`, so
// `Waiter` orders itself in reverse. The reversal is implemented once, in the
// `Ord` impl; `partial_cmp` simply delegates to it so that the partial and
// total orderings can never disagree.
impl<T: Ord> PartialOrd for Waiter<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
// Total order used by the waiter heap. The operands are deliberately swapped
// (`rhs` first, `self` second) so that a *smaller* `wake_num` compares as
// *greater*, which puts the earliest wake-up on top of a max-heap.
impl<T: Ord> Ord for Waiter<T> {
    fn cmp(&self, rhs: &Self) -> Ordering {
        Ord::cmp(&rhs.wake_num, &self.wake_num)
    }
}
// Two waiters are equal when they wait for the same number; the wake-up
// channel takes no part in the comparison.
impl<T: Ord> PartialEq for Waiter<T> {
    fn eq(&self, rhs: &Self) -> bool {
        rhs.wake_num.eq(&self.wake_num)
    }
}
// Marker impl: `Ord` has `Eq` as a supertrait, and `BinaryHeap` needs `Ord`.
impl<T: Ord> Eq for Waiter<T> {}
/// A tool for waiting on a sequence number
///
/// This provides a way to wait for the arrival of a number.
/// As soon as the number arrives, by another caller calling
/// [`advance`] with that number or a greater one, the waiter is woken up.
///
/// This implementation takes a blocking Mutex on both [`wait_for`]
/// and [`advance`], meaning there may be unexpected executor blocking
/// due to thread scheduling unfairness. There are probably better
/// implementations, but we can probably live with this for now.
///
/// [`wait_for`]: SeqWait::wait_for
/// [`advance`]: SeqWait::advance
///
pub struct SeqWait<T>
where
T: Ord,
{
// All mutable state: the current number, the queued waiters and the
// shutdown flag, guarded by a single lock.
internal: Mutex<SeqWaitInt<T>>,
}
impl<T> SeqWait<T>
where
    T: Ord + Debug + Copy,
{
    /// Create a new `SeqWait`, initialized to a particular number
    ///
    /// The new instance has no waiters and is not shut down.
    pub fn new(starting_num: T) -> Self {
        let internal = SeqWaitInt {
            waiters: BinaryHeap::new(),
            current: starting_num,
            shutdown: false,
        };
        SeqWait {
            internal: Mutex::new(internal),
        }
    }

    /// Shut down a `SeqWait`, causing all waiters (present and
    /// future) to return an error.
    ///
    /// Waiters that are already queued are woken with
    /// [`SeqWaitError::Shutdown`]. Later calls to [`SeqWait::wait_for`] or
    /// [`SeqWait::wait_for_timeout`] fail with the same error, unless the
    /// requested number has already arrived, in which case they still
    /// succeed immediately.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn shutdown(&self) {
        let waiters = {
            let mut internal = self.internal.lock().unwrap();

            // Block any future waiters from queueing. Without this flag a
            // waiter arriving after shutdown would register a fresh channel
            // that nobody ever signals or drops, and block forever.
            internal.shutdown = true;

            // Steal the entire waiters heap while holding the lock.
            mem::take(&mut internal.waiters)

            // The lock is released as we leave this scope.
        };

        // Dropping the heap drops every `Sender`, so each waiting `Receiver`
        // is woken with a disconnect error. The drop doesn't need to be
        // explicit; it is spelled out to make the order of events obvious:
        // waiters are woken only after the lock has been released.
        drop(waiters);
    }

    /// Wait for a number to arrive
    ///
    /// This call won't complete until someone has called `advance`
    /// with a number greater than or equal to the one we're waiting for.
    ///
    /// # Errors
    ///
    /// Returns [`SeqWaitError::Shutdown`] if [`SeqWait::shutdown`] was called
    /// before the number arrived.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait_for(&self, num: T) -> Result<(), SeqWaitError> {
        match self.queue_for_wait(num)? {
            // The number has already arrived.
            None => Ok(()),
            // A disconnect means our `Sender` was dropped without a wake-up.
            Some(rx) => rx.recv().map_err(|_| SeqWaitError::Shutdown),
        }
    }

    /// Wait for a number to arrive
    ///
    /// This call won't complete until someone has called `advance`
    /// with a number greater than or equal to the one we're waiting for.
    ///
    /// If that hasn't happened after the specified timeout duration,
    /// [`SeqWaitError::Timeout`] will be returned.
    ///
    /// Note that a timed-out waiter is not removed from the queue; its entry
    /// stays there until `advance` passes its number or `shutdown` is called.
    ///
    /// # Errors
    ///
    /// Returns [`SeqWaitError::Timeout`] on timeout and
    /// [`SeqWaitError::Shutdown`] if [`SeqWait::shutdown`] was called before
    /// the number arrived.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait_for_timeout(&self, num: T, timeout_duration: Duration) -> Result<(), SeqWaitError> {
        match self.queue_for_wait(num)? {
            None => Ok(()),
            Some(rx) => rx.recv_timeout(timeout_duration).map_err(|e| match e {
                std::sync::mpsc::RecvTimeoutError::Timeout => SeqWaitError::Timeout,
                std::sync::mpsc::RecvTimeoutError::Disconnected => SeqWaitError::Shutdown,
            }),
        }
    }

    /// Register and return a channel that will be notified when a number arrives,
    /// or None, if it has already arrived.
    ///
    /// The "already arrived" check comes first, so a number that has been
    /// reached is reported as `Ok(None)` even after shutdown.
    fn queue_for_wait(&self, num: T) -> Result<Option<Receiver<()>>, SeqWaitError> {
        let mut internal = self.internal.lock().unwrap();
        if internal.current >= num {
            return Ok(None);
        }
        if internal.shutdown {
            return Err(SeqWaitError::Shutdown);
        }

        // Queue the sending half; the caller blocks on the receiving half
        // after the lock has been released at the end of this function.
        let (tx, rx) = channel();
        internal.waiters.push(Waiter {
            wake_num: num,
            wake_channel: tx,
        });
        Ok(Some(rx))
    }

    /// Announce a new number has arrived
    ///
    /// All waiters at this value or below will be woken.
    ///
    /// Returns the old number. If `num` is not greater than the old number
    /// the call is a no-op: the value never moves backwards and nobody is
    /// woken. Callers that care should compare the return value with `num`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn advance(&self, num: T) -> T {
        let (old_value, wake_these) = {
            let mut internal = self.internal.lock().unwrap();

            let old_value = internal.current;
            if old_value >= num {
                return old_value;
            }
            internal.current = num;

            // Pop all waiters <= num from the heap. Collect them in a vector,
            // and wake them up after releasing the lock.
            let mut wake_these = Vec::new();
            while let Some(next) = internal.waiters.peek() {
                if next.wake_num > num {
                    break;
                }
                if let Some(waiter) = internal.waiters.pop() {
                    wake_these.push(waiter.wake_channel);
                }
            }
            (old_value, wake_these)
        };

        for tx in wake_these {
            // This can fail if the receiver is gone (e.g. the waiter timed
            // out). We don't care; discard the error.
            let _ = tx.send(());
        }
        old_value
    }

    /// Read the current value, without waiting.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn load(&self) -> T {
        self.internal.lock().unwrap().current
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread::sleep;
use std::thread::spawn;
use std::time::Duration;
// Exercises wait_for / advance / load / shutdown with three threads sharing
// one `SeqWait` that starts at 0.
//
// NOTE(review): the JoinHandles returned by `spawn` are discarded, so a
// failing `expect`/`assert_eq!` inside a spawned closure only panics that
// thread and does not fail this test. Consider joining the threads.
#[test]
fn seqwait() {
let seq = Arc::new(SeqWait::new(0));
let seq2 = Arc::clone(&seq);
let seq3 = Arc::clone(&seq);
// Waiter A: waits for 42, then advances to 100 itself, then waits for a
// number (999) that nobody ever announces and expects an error from it.
spawn(move || {
seq2.wait_for(42).expect("wait_for 42");
let old = seq2.advance(100);
assert_eq!(old, 99);
seq2.wait_for(999).expect_err("no 999");
});
// Waiter B: waits for 42, then for 0, which has arrived from the start.
spawn(move || {
seq3.wait_for(42).expect("wait_for 42");
seq3.wait_for(0).expect("wait_for 0");
});
// Presumably gives both threads time to queue up before the first advance.
sleep(Duration::from_secs(1));
let old = seq.advance(99);
assert_eq!(old, 0);
// Satisfied once waiter A has advanced the value to 100.
seq.wait_for(100).expect("wait_for 100");
// Calling advance with a smaller value is a no-op
assert_eq!(seq.advance(98), 100);
assert_eq!(seq.load(), 100);
seq.shutdown();
}
// A waiter with a 1 ms timeout gives up long before the value is advanced.
//
// NOTE(review): as above, the spawned thread is not joined, so the
// `assert_eq!` on the timeout result cannot fail this test.
#[test]
fn seqwait_timeout() {
let seq = Arc::new(SeqWait::new(0));
let seq2 = Arc::clone(&seq);
spawn(move || {
let timeout = Duration::from_millis(1);
let res = seq2.wait_for_timeout(42, timeout);
assert_eq!(res, Err(SeqWaitError::Timeout));
});
sleep(Duration::from_secs(1));
// This will attempt to wake, but nothing will happen
// because the waiter already dropped its Receiver.
let old = seq.advance(99);
assert_eq!(old, 0)
}
}