mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-01 17:50:38 +00:00
Compare commits
6 Commits
refactor-c
...
heavier_on
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
767c8bb95f | ||
|
|
f5b0e723cb | ||
|
|
9d11fcca02 | ||
|
|
013496c42b | ||
|
|
0d417d7c0f | ||
|
|
fb5420a959 |
@@ -1,5 +1,5 @@
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
atomic::{AtomicBool, AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
@@ -14,6 +14,18 @@ use tokio::sync::Semaphore;
|
||||
pub struct OnceCell<T> {
|
||||
inner: tokio::sync::RwLock<Inner<T>>,
|
||||
initializers: AtomicUsize,
|
||||
/// Do we have one permit or `u32::MAX` permits?
|
||||
///
|
||||
/// Having one permit means the cell is not initialized, and one winning future could
|
||||
/// initialize it. The act of initializing the cell adds `u32::MAX` permits and set this to
|
||||
/// `false`.
|
||||
///
|
||||
/// Deinitializing an initialized cell will first take `u32::MAX` permits handing one of them
|
||||
/// out, then set this back to `true`.
|
||||
///
|
||||
/// Because we need to see all changes to this variable, always use Acquire to read, AcqRel to
|
||||
/// compare_exchange.
|
||||
has_one_permit: AtomicBool,
|
||||
}
|
||||
|
||||
impl<T> Default for OnceCell<T> {
|
||||
@@ -22,6 +34,7 @@ impl<T> Default for OnceCell<T> {
|
||||
Self {
|
||||
inner: Default::default(),
|
||||
initializers: AtomicUsize::new(0),
|
||||
has_one_permit: AtomicBool::new(true),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -47,14 +60,14 @@ impl<T> Default for Inner<T> {
|
||||
impl<T> OnceCell<T> {
|
||||
/// Creates an already initialized `OnceCell` with the given value.
|
||||
pub fn new(value: T) -> Self {
|
||||
let sem = Semaphore::new(1);
|
||||
sem.close();
|
||||
let sem = Semaphore::new(u32::MAX as usize);
|
||||
Self {
|
||||
inner: tokio::sync::RwLock::new(Inner {
|
||||
init_semaphore: Arc::new(sem),
|
||||
value: Some(value),
|
||||
}),
|
||||
initializers: AtomicUsize::new(0),
|
||||
has_one_permit: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,42 +77,47 @@ impl<T> OnceCell<T> {
|
||||
/// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization.
|
||||
///
|
||||
/// Initialization is panic-safe and cancellation-safe.
|
||||
#[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
|
||||
pub async fn get_mut_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardMut<'_, T>, E>
|
||||
where
|
||||
F: FnOnce(InitPermit) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
|
||||
{
|
||||
let sem = {
|
||||
loop {
|
||||
let sem = {
|
||||
let guard = self.inner.write().await;
|
||||
if guard.value.is_some() {
|
||||
tracing::debug!("returning GuardMut over existing value");
|
||||
return Ok(GuardMut(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
{
|
||||
let permit = {
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire().await
|
||||
};
|
||||
|
||||
let permit = permit.expect("semaphore is never closed");
|
||||
|
||||
if !self.has_one_permit.load(Ordering::Acquire) {
|
||||
// it is important that the permit is dropped here otherwise there would be a
|
||||
// deadlock with `take_and_deinit` happening at the same time.
|
||||
tracing::trace!("seems initialization happened already, trying again");
|
||||
continue;
|
||||
}
|
||||
|
||||
permit.forget();
|
||||
}
|
||||
|
||||
tracing::trace!("calling factory");
|
||||
let permit = InitPermit::from(sem);
|
||||
let (value, permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.write().await;
|
||||
if guard.value.is_some() {
|
||||
return Ok(GuardMut(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
let permit = {
|
||||
// increment the count for the duration of queued
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire_owned().await
|
||||
};
|
||||
|
||||
match permit {
|
||||
Ok(permit) => {
|
||||
let permit = InitPermit(permit);
|
||||
let (value, _permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.write().await;
|
||||
|
||||
Ok(Self::set0(value, guard))
|
||||
}
|
||||
Err(_closed) => {
|
||||
let guard = self.inner.write().await;
|
||||
assert!(
|
||||
guard.value.is_some(),
|
||||
"semaphore got closed, must be initialized"
|
||||
);
|
||||
return Ok(GuardMut(guard));
|
||||
}
|
||||
return Ok(self.set0(value, guard, permit));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,42 +125,48 @@ impl<T> OnceCell<T> {
|
||||
/// returning the guard.
|
||||
///
|
||||
/// Initialization is panic-safe and cancellation-safe.
|
||||
#[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
|
||||
pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardRef<'_, T>, E>
|
||||
where
|
||||
F: FnOnce(InitPermit) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
|
||||
{
|
||||
let sem = {
|
||||
let guard = self.inner.read().await;
|
||||
if guard.value.is_some() {
|
||||
return Ok(GuardRef(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
let permit = {
|
||||
// increment the count for the duration of queued
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire_owned().await
|
||||
};
|
||||
|
||||
match permit {
|
||||
Ok(permit) => {
|
||||
let permit = InitPermit(permit);
|
||||
let (value, _permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.write().await;
|
||||
|
||||
Ok(Self::set0(value, guard).downgrade())
|
||||
}
|
||||
Err(_closed) => {
|
||||
loop {
|
||||
let sem = {
|
||||
let guard = self.inner.read().await;
|
||||
assert!(
|
||||
guard.value.is_some(),
|
||||
"semaphore got closed, must be initialized"
|
||||
);
|
||||
return Ok(GuardRef(guard));
|
||||
if guard.value.is_some() {
|
||||
tracing::debug!("returning GuardRef over existing value");
|
||||
return Ok(GuardRef(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
{
|
||||
let permit = {
|
||||
// increment the count for the duration of queued
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire().await
|
||||
};
|
||||
|
||||
let permit = permit.expect("semaphore is never closed");
|
||||
|
||||
if !self.has_one_permit.load(Ordering::Acquire) {
|
||||
tracing::trace!("seems initialization happened already, trying again");
|
||||
continue;
|
||||
} else {
|
||||
// it is our turn to initialize for sure
|
||||
}
|
||||
|
||||
permit.forget();
|
||||
}
|
||||
|
||||
tracing::trace!("calling factory");
|
||||
let permit = InitPermit::from(sem);
|
||||
let (value, permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.write().await;
|
||||
|
||||
return Ok(self.set0(value, guard, permit).downgrade());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,26 +176,47 @@ impl<T> OnceCell<T> {
|
||||
/// # Panics
|
||||
///
|
||||
/// If the inner has already been initialized.
|
||||
pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> {
|
||||
#[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
|
||||
pub async fn set(&self, value: T, permit: InitPermit) -> GuardMut<'_, T> {
|
||||
let guard = self.inner.write().await;
|
||||
|
||||
assert!(
|
||||
self.has_one_permit.load(Ordering::Acquire),
|
||||
"cannot set when there are multiple permits"
|
||||
);
|
||||
|
||||
// cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
|
||||
// give more permits right now.
|
||||
if guard.init_semaphore.try_acquire().is_ok() {
|
||||
let available = guard.init_semaphore.available_permits();
|
||||
drop(guard);
|
||||
panic!("permit is of wrong origin");
|
||||
panic!("permit is of wrong origin: {available}");
|
||||
}
|
||||
|
||||
Self::set0(value, guard)
|
||||
self.set0(value, guard, permit)
|
||||
}
|
||||
|
||||
fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner<T>>) -> GuardMut<'_, T> {
|
||||
fn set0<'a>(
|
||||
&'a self,
|
||||
value: T,
|
||||
mut guard: tokio::sync::RwLockWriteGuard<'a, Inner<T>>,
|
||||
permit: InitPermit,
|
||||
) -> GuardMut<'a, T> {
|
||||
if guard.value.is_some() {
|
||||
drop(guard);
|
||||
unreachable!("we won permit, must not be initialized");
|
||||
}
|
||||
guard.value = Some(value);
|
||||
guard.init_semaphore.close();
|
||||
assert!(
|
||||
self.has_one_permit
|
||||
.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
|
||||
.is_ok(),
|
||||
"should had only had one permit"
|
||||
);
|
||||
permit.forget();
|
||||
guard.init_semaphore.add_permits(u32::MAX as usize);
|
||||
|
||||
tracing::debug!("value initialized");
|
||||
GuardMut(guard)
|
||||
}
|
||||
|
||||
@@ -199,6 +244,48 @@ impl<T> OnceCell<T> {
|
||||
pub fn initializer_count(&self) -> usize {
|
||||
self.initializers.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Take the current value, and a new permit for it's deinitialization.
|
||||
///
|
||||
/// The permit will be on a semaphore part of the new internal value, and any following
|
||||
/// [`OnceCell::get_or_init`] will wait on it to complete.
|
||||
#[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
|
||||
pub async fn take_and_deinit(&self, mut guard: GuardMut<'_, T>) -> (T, InitPermit) {
|
||||
// guard exists => we have been initialized
|
||||
assert!(
|
||||
!self.has_one_permit.load(Ordering::Acquire),
|
||||
"has to have all permits after initializing"
|
||||
);
|
||||
assert!(guard.0.value.is_some(), "guard exists => initialized");
|
||||
|
||||
// we must first drain out all "waiting to initialize" stragglers
|
||||
tracing::trace!("draining other initializers");
|
||||
let all_permits = guard
|
||||
.0
|
||||
.init_semaphore
|
||||
.acquire_many(u32::MAX)
|
||||
.await
|
||||
.expect("never closed");
|
||||
all_permits.forget();
|
||||
tracing::debug!("other initializers drained");
|
||||
|
||||
assert_eq!(guard.0.init_semaphore.available_permits(), 0);
|
||||
|
||||
// now that the permits have been drained, switch the state
|
||||
assert!(
|
||||
self.has_one_permit
|
||||
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
|
||||
.is_ok(),
|
||||
"there should be only one GuardMut attempting take_and_deinit"
|
||||
);
|
||||
|
||||
let value = guard.0.value.take().unwrap();
|
||||
|
||||
// act of creating an init_permit is the same as "adding back one when this is dropped"
|
||||
let init_permit = InitPermit::from(guard.0.init_semaphore.clone());
|
||||
|
||||
(value, init_permit)
|
||||
}
|
||||
}
|
||||
|
||||
/// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
|
||||
@@ -244,24 +331,6 @@ impl<T> std::ops::DerefMut for GuardMut<'_, T> {
|
||||
}
|
||||
|
||||
impl<'a, T> GuardMut<'a, T> {
|
||||
/// Take the current value, and a new permit for it's deinitialization.
|
||||
///
|
||||
/// The permit will be on a semaphore part of the new internal value, and any following
|
||||
/// [`OnceCell::get_or_init`] will wait on it to complete.
|
||||
pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
|
||||
let mut swapped = Inner::default();
|
||||
let permit = swapped
|
||||
.init_semaphore
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.expect("we just created this");
|
||||
std::mem::swap(&mut *self.0, &mut swapped);
|
||||
swapped
|
||||
.value
|
||||
.map(|v| (v, InitPermit(permit)))
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
|
||||
pub fn downgrade(self) -> GuardRef<'a, T> {
|
||||
GuardRef(self.0.downgrade())
|
||||
}
|
||||
@@ -282,13 +351,39 @@ impl<T> std::ops::Deref for GuardRef<'_, T> {
|
||||
}
|
||||
|
||||
/// Type held by OnceCell (de)initializing task.
|
||||
pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
|
||||
pub struct InitPermit(Option<Arc<tokio::sync::Semaphore>>);
|
||||
|
||||
impl From<Arc<tokio::sync::Semaphore>> for InitPermit {
|
||||
fn from(value: Arc<tokio::sync::Semaphore>) -> Self {
|
||||
InitPermit(Some(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl InitPermit {
|
||||
fn forget(mut self) {
|
||||
self.0
|
||||
.take()
|
||||
.expect("unable to forget twice, created with None?");
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for InitPermit {
|
||||
fn drop(&mut self) {
|
||||
if let Some(sem) = self.0.take() {
|
||||
debug_assert_eq!(sem.available_permits(), 0);
|
||||
sem.add_permits(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::Future;
|
||||
|
||||
use super::*;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
pin::{pin, Pin},
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
@@ -366,11 +461,8 @@ mod tests {
|
||||
let cell = cell.clone();
|
||||
let deinitialization_started = deinitialization_started.clone();
|
||||
async move {
|
||||
let (answer, _permit) = cell
|
||||
.get_mut()
|
||||
.await
|
||||
.expect("initialized to value")
|
||||
.take_and_deinit();
|
||||
let guard = cell.get_mut().await.unwrap();
|
||||
let (answer, _permit) = cell.take_and_deinit(guard).await;
|
||||
assert_eq!(answer, initial);
|
||||
|
||||
deinitialization_started.wait().await;
|
||||
@@ -399,12 +491,31 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn reinit_with_deinit_permit() {
|
||||
let cell = Arc::new(OnceCell::new(42));
|
||||
assert!(!cell.has_one_permit.load(Ordering::Acquire));
|
||||
assert_eq!(
|
||||
cell.inner.read().await.init_semaphore.available_permits(),
|
||||
u32::MAX as usize
|
||||
);
|
||||
|
||||
let guard = cell.get_mut().await.unwrap();
|
||||
assert!(!cell.has_one_permit.load(Ordering::Acquire));
|
||||
assert_eq!(
|
||||
guard.0.init_semaphore.available_permits(),
|
||||
u32::MAX as usize
|
||||
);
|
||||
|
||||
let (mol, permit) = cell.take_and_deinit(guard).await;
|
||||
assert!(cell.has_one_permit.load(Ordering::Acquire));
|
||||
assert_eq!(
|
||||
cell.inner.read().await.init_semaphore.available_permits(),
|
||||
0
|
||||
);
|
||||
|
||||
let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit();
|
||||
cell.set(5, permit).await;
|
||||
assert_eq!(*cell.get_mut().await.unwrap(), 5);
|
||||
|
||||
let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit();
|
||||
let guard = cell.get_mut().await.unwrap();
|
||||
let (five, permit) = cell.take_and_deinit(guard).await;
|
||||
assert_eq!(5, five);
|
||||
cell.set(mol, permit).await;
|
||||
assert_eq!(*cell.get_mut().await.unwrap(), 42);
|
||||
@@ -455,4 +566,110 @@ mod tests {
|
||||
.unwrap();
|
||||
assert_eq!(*g, "now initialized");
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn reproduce_init_take_deinit_race_ref() {
|
||||
init_take_deinit_scenario(|cell, factory| {
|
||||
Box::pin(async {
|
||||
cell.get_or_init(factory).await.unwrap();
|
||||
})
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn reproduce_init_take_deinit_race_mut() {
|
||||
init_take_deinit_scenario(|cell, factory| {
|
||||
Box::pin(async {
|
||||
cell.get_mut_or_init(factory).await.unwrap();
|
||||
})
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
|
||||
type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;
|
||||
|
||||
/// Reproduce an assertion failure with both initialization methods.
|
||||
///
|
||||
/// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
|
||||
/// Alternative would be a macro_rules! but that is the last resort.
|
||||
async fn init_take_deinit_scenario<F>(init_way: F)
|
||||
where
|
||||
F: for<'a> Fn(
|
||||
&'a OnceCell<&'static str>,
|
||||
BoxedInitFunction<&'static str, Infallible>,
|
||||
) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
|
||||
{
|
||||
use tracing::Instrument;
|
||||
|
||||
let cell = OnceCell::default();
|
||||
|
||||
// acquire the init_semaphore only permit to drive initializing tasks in order to waiting
|
||||
// on the same semaphore.
|
||||
let permit = cell
|
||||
.inner
|
||||
.read()
|
||||
.await
|
||||
.init_semaphore
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.unwrap();
|
||||
|
||||
let mut t1 = pin!(init_way(
|
||||
&cell,
|
||||
Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
|
||||
)
|
||||
.instrument(tracing::info_span!("t1")));
|
||||
|
||||
let mut t2 = pin!(init_way(
|
||||
&cell,
|
||||
Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
|
||||
)
|
||||
.instrument(tracing::info_span!("t2")));
|
||||
|
||||
// drive t2 first to the init_semaphore
|
||||
tokio::select! {
|
||||
_ = &mut t2 => unreachable!("it cannot get permit"),
|
||||
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
|
||||
}
|
||||
|
||||
// followed by t1 in the init_semaphore
|
||||
tokio::select! {
|
||||
_ = &mut t1 => unreachable!("it cannot get permit"),
|
||||
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
|
||||
}
|
||||
|
||||
// now let t2 proceed and initialize
|
||||
drop(permit);
|
||||
t2.await;
|
||||
|
||||
// in original implementation which did closing and re-creation of the semaphore, t1 was
|
||||
// still stuck on the first semaphore, but now that t1 and deinit are using the same
|
||||
// semaphore, deinit will have to wait for t1.
|
||||
let mut deinit = pin!(async {
|
||||
let guard = cell.get_mut().await.unwrap();
|
||||
cell.take_and_deinit(guard).await
|
||||
}
|
||||
.instrument(tracing::info_span!("deinit")));
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut deinit => unreachable!("deinit must not make progress before t1 is complete"),
|
||||
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
|
||||
}
|
||||
|
||||
// now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
|
||||
// the new one.
|
||||
tokio::select! {
|
||||
_ = &mut t1 => unreachable!("it cannot get permit"),
|
||||
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
|
||||
}
|
||||
|
||||
let (s, _) = deinit.await;
|
||||
assert_eq!("t2", s);
|
||||
|
||||
t1.await;
|
||||
|
||||
assert_eq!("t1", *cell.get().await.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user