mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
make seqwait generic
SeqWait can use any type that is Ord + Debug + Copy. Debug is not strictly necessary, but allows us to keep the panic message if a caller wants the sequence number to go backwards.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#![warn(missing_docs)]
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::Debug;
|
||||
use std::mem;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
@@ -18,9 +19,12 @@ pub enum SeqWaitError {
|
||||
}
|
||||
|
||||
/// Internal components of a `SeqWait`
|
||||
struct SeqWaitInt {
|
||||
waiters: BTreeMap<u64, (Sender<()>, Receiver<()>)>,
|
||||
current: u64,
|
||||
struct SeqWaitInt<T>
|
||||
where
|
||||
T: Ord,
|
||||
{
|
||||
waiters: BTreeMap<T, (Sender<()>, Receiver<()>)>,
|
||||
current: T,
|
||||
shutdown: bool,
|
||||
}
|
||||
|
||||
@@ -38,13 +42,19 @@ struct SeqWaitInt {
|
||||
/// [`wait_for`]: SeqWait::wait_for
|
||||
/// [`advance`]: SeqWait::advance
|
||||
///
|
||||
pub struct SeqWait {
|
||||
internal: Mutex<SeqWaitInt>,
|
||||
pub struct SeqWait<T>
|
||||
where
|
||||
T: Ord,
|
||||
{
|
||||
internal: Mutex<SeqWaitInt<T>>,
|
||||
}
|
||||
|
||||
impl SeqWait {
|
||||
impl<T> SeqWait<T>
|
||||
where
|
||||
T: Ord + Debug + Copy,
|
||||
{
|
||||
/// Create a new `SeqWait`, initialized to a particular number
|
||||
pub fn new(starting_num: u64) -> Self {
|
||||
pub fn new(starting_num: T) -> Self {
|
||||
let internal = SeqWaitInt {
|
||||
waiters: BTreeMap::new(),
|
||||
current: starting_num,
|
||||
@@ -82,7 +92,7 @@ impl SeqWait {
|
||||
///
|
||||
/// This call won't complete until someone has called `advance`
|
||||
/// with a number greater than or equal to the one we're waiting for.
|
||||
pub async fn wait_for(&self, num: u64) -> Result<(), SeqWaitError> {
|
||||
pub async fn wait_for(&self, num: T) -> Result<(), SeqWaitError> {
|
||||
let mut rx = {
|
||||
let mut internal = self.internal.lock().unwrap();
|
||||
if internal.current >= num {
|
||||
@@ -116,7 +126,7 @@ impl SeqWait {
|
||||
/// [`SeqWaitError::Timeout`] will be returned.
|
||||
pub async fn wait_for_timeout(
|
||||
&self,
|
||||
num: u64,
|
||||
num: T,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<(), SeqWaitError> {
|
||||
timeout(timeout_duration, self.wait_for(num))
|
||||
@@ -130,23 +140,30 @@ impl SeqWait {
|
||||
///
|
||||
/// `advance` will panic if you send it a lower number than
|
||||
/// a previous call.
|
||||
pub fn advance(&self, num: u64) {
|
||||
pub fn advance(&self, num: T) {
|
||||
let wake_these = {
|
||||
let mut internal = self.internal.lock().unwrap();
|
||||
|
||||
if internal.current > num {
|
||||
panic!(
|
||||
"tried to advance backwards, from {} to {}",
|
||||
"tried to advance backwards, from {:?} to {:?}",
|
||||
internal.current, num
|
||||
);
|
||||
}
|
||||
internal.current = num;
|
||||
|
||||
// split_off will give me all the high-numbered waiters,
|
||||
// so split and then swap. Everything at or above (num + 1)
|
||||
// gets to stay.
|
||||
let mut split = internal.waiters.split_off(&(num + 1));
|
||||
// so split and then swap. Everything at or above `num`
|
||||
// stays.
|
||||
let mut split = internal.waiters.split_off(&num);
|
||||
std::mem::swap(&mut split, &mut internal.waiters);
|
||||
|
||||
// `split_at` didn't get the value at `num`; if it's
|
||||
// there take that too.
|
||||
if let Some(sleeper) = internal.waiters.remove(&num) {
|
||||
split.insert(num, sleeper);
|
||||
}
|
||||
|
||||
split
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user