From f62ce4bcf7efeedc12e4361cceb37ca172ea9a0a Mon Sep 17 00:00:00 2001 From: Eric Seppanen Date: Fri, 23 Apr 2021 13:55:42 -0700 Subject: [PATCH] make seqwait generic SeqWait can use any type that is Ord + Debug + Copy. Debug is not strictly necessary, but allows us to keep the panic message if a caller wants the sequence number to go backwards. --- zenith_utils/src/seqwait.rs | 45 +++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/zenith_utils/src/seqwait.rs b/zenith_utils/src/seqwait.rs index c30304ab6a..b4f3cdd454 100644 --- a/zenith_utils/src/seqwait.rs +++ b/zenith_utils/src/seqwait.rs @@ -1,6 +1,7 @@ #![warn(missing_docs)] use std::collections::BTreeMap; +use std::fmt::Debug; use std::mem; use std::sync::Mutex; use std::time::Duration; @@ -18,9 +19,12 @@ pub enum SeqWaitError { } /// Internal components of a `SeqWait` -struct SeqWaitInt { - waiters: BTreeMap, Receiver<()>)>, - current: u64, +struct SeqWaitInt +where + T: Ord, +{ + waiters: BTreeMap, Receiver<()>)>, + current: T, shutdown: bool, } @@ -38,13 +42,19 @@ struct SeqWaitInt { /// [`wait_for`]: SeqWait::wait_for /// [`advance`]: SeqWait::advance /// -pub struct SeqWait { - internal: Mutex, +pub struct SeqWait +where + T: Ord, +{ + internal: Mutex>, } -impl SeqWait { +impl SeqWait +where + T: Ord + Debug + Copy, +{ /// Create a new `SeqWait`, initialized to a particular number - pub fn new(starting_num: u64) -> Self { + pub fn new(starting_num: T) -> Self { let internal = SeqWaitInt { waiters: BTreeMap::new(), current: starting_num, @@ -82,7 +92,7 @@ impl SeqWait { /// /// This call won't complete until someone has called `advance` /// with a number greater than or equal to the one we're waiting for. - pub async fn wait_for(&self, num: u64) -> Result<(), SeqWaitError> { + pub async fn wait_for(&self, num: T) -> Result<(), SeqWaitError> { let mut rx = { let mut internal = self.internal.lock().unwrap(); if internal.current >= num { @@ -116,7 +126,7 @@ impl SeqWait { /// [`SeqWaitError::Timeout`] will be returned. pub async fn wait_for_timeout( &self, - num: u64, + num: T, timeout_duration: Duration, ) -> Result<(), SeqWaitError> { timeout(timeout_duration, self.wait_for(num)) @@ -130,23 +140,30 @@ impl SeqWait { /// /// `advance` will panic if you send it a lower number than /// a previous call. - pub fn advance(&self, num: u64) { + pub fn advance(&self, num: T) { let wake_these = { let mut internal = self.internal.lock().unwrap(); if internal.current > num { panic!( - "tried to advance backwards, from {} to {}", + "tried to advance backwards, from {:?} to {:?}", internal.current, num ); } internal.current = num; // split_off will give me all the high-numbered waiters, - // so split and then swap. Everything at or above (num + 1) - // gets to stay. - let mut split = internal.waiters.split_off(&(num + 1)); + // so split and then swap. Everything at or above `num` + // stays. + let mut split = internal.waiters.split_off(&num); std::mem::swap(&mut split, &mut internal.waiters); + + // `split_at` didn't get the value at `num`; if it's + // there take that too. + if let Some(sleeper) = internal.waiters.remove(&num) { + split.insert(num, sleeper); + } + split };