make seqwait generic

SeqWait can use any type that is Ord + Debug + Copy. Debug is not
strictly necessary, but allows us to keep the panic message if a caller
wants the sequence number to go backwards.
This commit is contained in:
Eric Seppanen
2021-04-23 13:55:42 -07:00
parent 3d3eb0ed16
commit f62ce4bcf7

View File

@@ -1,6 +1,7 @@
#![warn(missing_docs)]
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::mem;
use std::sync::Mutex;
use std::time::Duration;
@@ -18,9 +19,12 @@ pub enum SeqWaitError {
}
/// Internal components of a `SeqWait`
struct SeqWaitInt {
waiters: BTreeMap<u64, (Sender<()>, Receiver<()>)>,
current: u64,
struct SeqWaitInt<T>
where
T: Ord,
{
waiters: BTreeMap<T, (Sender<()>, Receiver<()>)>,
current: T,
shutdown: bool,
}
@@ -38,13 +42,19 @@ struct SeqWaitInt {
/// [`wait_for`]: SeqWait::wait_for
/// [`advance`]: SeqWait::advance
///
pub struct SeqWait {
internal: Mutex<SeqWaitInt>,
pub struct SeqWait<T>
where
T: Ord,
{
internal: Mutex<SeqWaitInt<T>>,
}
impl SeqWait {
impl<T> SeqWait<T>
where
T: Ord + Debug + Copy,
{
/// Create a new `SeqWait`, initialized to a particular number
pub fn new(starting_num: u64) -> Self {
pub fn new(starting_num: T) -> Self {
let internal = SeqWaitInt {
waiters: BTreeMap::new(),
current: starting_num,
@@ -82,7 +92,7 @@ impl SeqWait {
///
/// This call won't complete until someone has called `advance`
/// with a number greater than or equal to the one we're waiting for.
pub async fn wait_for(&self, num: u64) -> Result<(), SeqWaitError> {
pub async fn wait_for(&self, num: T) -> Result<(), SeqWaitError> {
let mut rx = {
let mut internal = self.internal.lock().unwrap();
if internal.current >= num {
@@ -116,7 +126,7 @@ impl SeqWait {
/// [`SeqWaitError::Timeout`] will be returned.
pub async fn wait_for_timeout(
&self,
num: u64,
num: T,
timeout_duration: Duration,
) -> Result<(), SeqWaitError> {
timeout(timeout_duration, self.wait_for(num))
@@ -130,23 +140,30 @@ impl SeqWait {
///
/// `advance` will panic if you send it a lower number than
/// a previous call.
pub fn advance(&self, num: u64) {
pub fn advance(&self, num: T) {
let wake_these = {
let mut internal = self.internal.lock().unwrap();
if internal.current > num {
panic!(
"tried to advance backwards, from {} to {}",
"tried to advance backwards, from {:?} to {:?}",
internal.current, num
);
}
internal.current = num;
// split_off will give me all the high-numbered waiters,
// so split and then swap. Everything at or above (num + 1)
// gets to stay.
let mut split = internal.waiters.split_off(&(num + 1));
// so split and then swap. Everything at or above `num`
// stays.
let mut split = internal.waiters.split_off(&num);
std::mem::swap(&mut split, &mut internal.waiters);
// `split_at` didn't get the value at `num`; if it's
// there take that too.
if let Some(sleeper) = internal.waiters.remove(&num) {
split.insert(num, sleeper);
}
split
};