diff --git a/libs/utils/src/sync/heavier_once_cell.rs b/libs/utils/src/sync/heavier_once_cell.rs index 81625b907e..e81f768533 100644 --- a/libs/utils/src/sync/heavier_once_cell.rs +++ b/libs/utils/src/sync/heavier_once_cell.rs @@ -1,5 +1,5 @@ use std::sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, }; use tokio::sync::Semaphore; @@ -14,6 +14,18 @@ use tokio::sync::Semaphore; pub struct OnceCell { inner: tokio::sync::RwLock>, initializers: AtomicUsize, + /// Do we have one permit or `u32::MAX` permits? + /// + /// Having one permit means the cell is not initialized, and one winning future could + /// initialize it. The act of initializing the cell adds `u32::MAX` permits and set this to + /// `false`. + /// + /// Deinitializing an initialized cell will first take `u32::MAX` permits handing one of them + /// out, then set this back to `true`. + /// + /// Because we need to see all changes to this variable, always use Acquire to read, AcqRel to + /// compare_exchange. + has_one_permit: AtomicBool, } impl Default for OnceCell { @@ -22,6 +34,7 @@ impl Default for OnceCell { Self { inner: Default::default(), initializers: AtomicUsize::new(0), + has_one_permit: AtomicBool::new(true), } } } @@ -47,14 +60,14 @@ impl Default for Inner { impl OnceCell { /// Creates an already initialized `OnceCell` with the given value. pub fn new(value: T) -> Self { - let sem = Semaphore::new(1); - sem.close(); + let sem = Semaphore::new(u32::MAX as usize); Self { inner: tokio::sync::RwLock::new(Inner { init_semaphore: Arc::new(sem), value: Some(value), }), initializers: AtomicUsize::new(0), + has_one_permit: AtomicBool::new(false), } } @@ -64,6 +77,7 @@ impl OnceCell { /// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization. /// /// Initialization is panic-safe and cancellation-safe. + #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)] pub async fn get_mut_or_init(&self, factory: F) -> Result, E> where F: FnOnce(InitPermit) -> Fut, @@ -73,6 +87,7 @@ impl OnceCell { let sem = { let guard = self.inner.write().await; if guard.value.is_some() { + tracing::debug!("returning GuardMut over existing value"); return Ok(GuardMut(guard)); } guard.init_semaphore.clone() @@ -80,33 +95,29 @@ impl OnceCell { { let permit = { - // increment the count for the duration of queued let _guard = CountWaitingInitializers::start(self); sem.acquire().await }; - let Ok(permit) = permit else { - let guard = self.inner.write().await; - if !Arc::ptr_eq(&sem, &guard.init_semaphore) { - // there was a take_and_deinit in between - continue; - } - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardMut(guard)); - }; + let permit = permit.expect("semaphore is never closed"); + + if !self.has_one_permit.load(Ordering::Acquire) { + // it is important that the permit is dropped here otherwise there would be a + // deadlock with `take_and_deinit` happening at the same time. + tracing::trace!("seems initialization happened already, trying again"); + continue; + } permit.forget(); } - let permit = InitPermit(sem); - let (value, _permit) = factory(permit).await?; + tracing::trace!("calling factory"); + let permit = InitPermit::from(sem); + let (value, permit) = factory(permit).await?; let guard = self.inner.write().await; - return Ok(Self::set0(value, guard)); + return Ok(self.set0(value, guard, permit)); } } @@ -114,6 +125,7 @@ impl OnceCell { /// returning the guard. /// /// Initialization is panic-safe and cancellation-safe. + #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)] pub async fn get_or_init(&self, factory: F) -> Result, E> where F: FnOnce(InitPermit) -> Fut, @@ -123,6 +135,7 @@ impl OnceCell { let sem = { let guard = self.inner.read().await; if guard.value.is_some() { + tracing::debug!("returning GuardRef over existing value"); return Ok(GuardRef(guard)); } guard.init_semaphore.clone() @@ -135,28 +148,25 @@ impl OnceCell { sem.acquire().await }; - let Ok(permit) = permit else { - let guard = self.inner.read().await; - if !Arc::ptr_eq(&sem, &guard.init_semaphore) { - // there was a take_and_deinit in between - continue; - } - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardRef(guard)); - }; + let permit = permit.expect("semaphore is never closed"); + + if !self.has_one_permit.load(Ordering::Acquire) { + tracing::trace!("seems initialization happened already, trying again"); + continue; + } else { + // it is our turn to initialize for sure + } permit.forget(); } - let permit = InitPermit(sem); - let (value, _permit) = factory(permit).await?; + tracing::trace!("calling factory"); + let permit = InitPermit::from(sem); + let (value, permit) = factory(permit).await?; let guard = self.inner.write().await; - return Ok(Self::set0(value, guard).downgrade()); + return Ok(self.set0(value, guard, permit).downgrade()); } } @@ -166,26 +176,47 @@ impl OnceCell { /// # Panics /// /// If the inner has already been initialized. - pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> { + #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)] + pub async fn set(&self, value: T, permit: InitPermit) -> GuardMut<'_, T> { let guard = self.inner.write().await; + assert!( + self.has_one_permit.load(Ordering::Acquire), + "cannot set when there are multiple permits" + ); + // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot // give more permits right now. if guard.init_semaphore.try_acquire().is_ok() { + let available = guard.init_semaphore.available_permits(); drop(guard); - panic!("permit is of wrong origin"); + panic!("permit is of wrong origin: {available}"); } - Self::set0(value, guard) + self.set0(value, guard, permit) } - fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner>) -> GuardMut<'_, T> { + fn set0<'a>( + &'a self, + value: T, + mut guard: tokio::sync::RwLockWriteGuard<'a, Inner>, + permit: InitPermit, + ) -> GuardMut<'a, T> { if guard.value.is_some() { drop(guard); unreachable!("we won permit, must not be initialized"); } guard.value = Some(value); - guard.init_semaphore.close(); + assert!( + self.has_one_permit + .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed) + .is_ok(), + "should had only had one permit" + ); + permit.forget(); + guard.init_semaphore.add_permits(u32::MAX as usize); + + tracing::debug!("value initialized"); GuardMut(guard) } @@ -213,6 +244,48 @@ impl OnceCell { pub fn initializer_count(&self) -> usize { self.initializers.load(Ordering::Relaxed) } + + /// Take the current value, and a new permit for it's deinitialization. + /// + /// The permit will be on a semaphore part of the new internal value, and any following + /// [`OnceCell::get_or_init`] will wait on it to complete. + #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)] + pub async fn take_and_deinit(&self, mut guard: GuardMut<'_, T>) -> (T, InitPermit) { + // guard exists => we have been initialized + assert!( + !self.has_one_permit.load(Ordering::Acquire), + "has to have all permits after initializing" + ); + assert!(guard.0.value.is_some(), "guard exists => initialized"); + + // we must first drain out all "waiting to initialize" stragglers + tracing::trace!("draining other initializers"); + let all_permits = guard + .0 + .init_semaphore + .acquire_many(u32::MAX) + .await + .expect("never closed"); + all_permits.forget(); + tracing::debug!("other initializers drained"); + + assert_eq!(guard.0.init_semaphore.available_permits(), 0); + + // now that the permits have been drained, switch the state + assert!( + self.has_one_permit + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) + .is_ok(), + "there should be only one GuardMut attempting take_and_deinit" + ); + + let value = guard.0.value.take().unwrap(); + + // act of creating an init_permit is the same as "adding back one when this is dropped" + let init_permit = InitPermit::from(guard.0.init_semaphore.clone()); + + (value, init_permit) + } } /// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the @@ -258,21 +331,6 @@ impl std::ops::DerefMut for GuardMut<'_, T> { } impl<'a, T> GuardMut<'a, T> { - /// Take the current value, and a new permit for it's deinitialization. - /// - /// The permit will be on a semaphore part of the new internal value, and any following - /// [`OnceCell::get_or_init`] will wait on it to complete. - pub fn take_and_deinit(&mut self) -> (T, InitPermit) { - let mut swapped = Inner::default(); - let sem = swapped.init_semaphore.clone(); - sem.try_acquire().expect("we just created this").forget(); - std::mem::swap(&mut *self.0, &mut swapped); - swapped - .value - .map(|v| (v, InitPermit(sem))) - .expect("guard is not created unless value has been initialized") - } - pub fn downgrade(self) -> GuardRef<'a, T> { GuardRef(self.0.downgrade()) } @@ -293,12 +351,28 @@ impl std::ops::Deref for GuardRef<'_, T> { } /// Type held by OnceCell (de)initializing task. -pub struct InitPermit(Arc); +pub struct InitPermit(Option>); + +impl From> for InitPermit { + fn from(value: Arc) -> Self { + InitPermit(Some(value)) + } +} + +impl InitPermit { + fn forget(mut self) { + self.0 + .take() + .expect("unable to forget twice, created with None?"); + } +} impl Drop for InitPermit { fn drop(&mut self) { - debug_assert_eq!(self.0.available_permits(), 0); - self.0.add_permits(1); + if let Some(sem) = self.0.take() { + debug_assert_eq!(sem.available_permits(), 0); + sem.add_permits(1); + } } } @@ -387,11 +461,8 @@ mod tests { let cell = cell.clone(); let deinitialization_started = deinitialization_started.clone(); async move { - let (answer, _permit) = cell - .get_mut() - .await - .expect("initialized to value") - .take_and_deinit(); + let guard = cell.get_mut().await.unwrap(); + let (answer, _permit) = cell.take_and_deinit(guard).await; assert_eq!(answer, initial); deinitialization_started.wait().await; @@ -420,12 +491,31 @@ mod tests { #[tokio::test] async fn reinit_with_deinit_permit() { let cell = Arc::new(OnceCell::new(42)); + assert!(!cell.has_one_permit.load(Ordering::Acquire)); + assert_eq!( + cell.inner.read().await.init_semaphore.available_permits(), + u32::MAX as usize + ); + + let guard = cell.get_mut().await.unwrap(); + assert!(!cell.has_one_permit.load(Ordering::Acquire)); + assert_eq!( + guard.0.init_semaphore.available_permits(), + u32::MAX as usize + ); + + let (mol, permit) = cell.take_and_deinit(guard).await; + assert!(cell.has_one_permit.load(Ordering::Acquire)); + assert_eq!( + cell.inner.read().await.init_semaphore.available_permits(), + 0 + ); - let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit(); cell.set(5, permit).await; assert_eq!(*cell.get_mut().await.unwrap(), 5); - let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit(); + let guard = cell.get_mut().await.unwrap(); + let (five, permit) = cell.take_and_deinit(guard).await; assert_eq!(5, five); cell.set(mol, permit).await; assert_eq!(*cell.get_mut().await.unwrap(), 42); @@ -478,7 +568,7 @@ mod tests { } #[tokio::test(start_paused = true)] - async fn reproduce_init_take_deinit_race() { + async fn reproduce_init_take_deinit_race_ref() { init_take_deinit_scenario(|cell, factory| { Box::pin(async { cell.get_or_init(factory).await.unwrap(); @@ -511,6 +601,8 @@ mod tests { BoxedInitFunction<&'static str, Infallible>, ) -> Pin + 'a>>, { + use tracing::Instrument; + let cell = OnceCell::default(); // acquire the init_semaphore only permit to drive initializing tasks in order to waiting @@ -527,12 +619,14 @@ mod tests { let mut t1 = pin!(init_way( &cell, Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })), - )); + ) + .instrument(tracing::info_span!("t1"))); let mut t2 = pin!(init_way( &cell, Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })), - )); + ) + .instrument(tracing::info_span!("t2"))); // drive t2 first to the init_semaphore tokio::select! { @@ -550,8 +644,19 @@ mod tests { drop(permit); t2.await; - let (s, permit) = { cell.get_mut().await.unwrap().take_and_deinit() }; - assert_eq!("t2", s); + // in original implementation which did closing and re-creation of the semaphore, t1 was + // still stuck on the first semaphore, but now that t1 and deinit are using the same + // semaphore, deinit will have to wait for t1. + let mut deinit = pin!(async { + let guard = cell.get_mut().await.unwrap(); + cell.take_and_deinit(guard).await + } + .instrument(tracing::info_span!("deinit"))); + + tokio::select! { + _ = &mut deinit => unreachable!("deinit must not make progress before t1 is complete"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from // the new one. @@ -560,8 +665,9 @@ mod tests { _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} } - // only now we get to initialize it - drop(permit); + let (s, _) = deinit.await; + assert_eq!("t2", s); + t1.await; assert_eq!("t1", *cell.get().await.unwrap());