diff --git a/libs/utils/src/sync/heavier_once_cell.rs b/libs/utils/src/sync/heavier_once_cell.rs index f733d107f1..81625b907e 100644 --- a/libs/utils/src/sync/heavier_once_cell.rs +++ b/libs/utils/src/sync/heavier_once_cell.rs @@ -69,37 +69,44 @@ impl OnceCell { F: FnOnce(InitPermit) -> Fut, Fut: std::future::Future>, { - let sem = { + loop { + let sem = { + let guard = self.inner.write().await; + if guard.value.is_some() { + return Ok(GuardMut(guard)); + } + guard.init_semaphore.clone() + }; + + { + let permit = { + // increment the count for the duration of queued + let _guard = CountWaitingInitializers::start(self); + sem.acquire().await + }; + + let Ok(permit) = permit else { + let guard = self.inner.write().await; + if !Arc::ptr_eq(&sem, &guard.init_semaphore) { + // there was a take_and_deinit in between + continue; + } + assert!( + guard.value.is_some(), + "semaphore got closed, must be initialized" + ); + return Ok(GuardMut(guard)); + }; + + permit.forget(); + } + + let permit = InitPermit(sem); + let (value, _permit) = factory(permit).await?; + let guard = self.inner.write().await; - if guard.value.is_some() { - return Ok(GuardMut(guard)); - } - guard.init_semaphore.clone() - }; - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire_owned().await - }; - - match permit { - Ok(permit) => { - let permit = InitPermit(permit); - let (value, _permit) = factory(permit).await?; - - let guard = self.inner.write().await; - - Ok(Self::set0(value, guard)) - } - Err(_closed) => { - let guard = self.inner.write().await; - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardMut(guard)); - } + return Ok(Self::set0(value, guard)); } } @@ -112,37 +119,44 @@ impl OnceCell { F: FnOnce(InitPermit) -> Fut, Fut: std::future::Future>, { - let sem = { - let guard = self.inner.read().await; - if guard.value.is_some() { - return Ok(GuardRef(guard)); - } - guard.init_semaphore.clone() - }; - - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire_owned().await - }; - - match permit { - Ok(permit) => { - let permit = InitPermit(permit); - let (value, _permit) = factory(permit).await?; - - let guard = self.inner.write().await; - - Ok(Self::set0(value, guard).downgrade()) - } - Err(_closed) => { + loop { + let sem = { let guard = self.inner.read().await; - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardRef(guard)); + if guard.value.is_some() { + return Ok(GuardRef(guard)); + } + guard.init_semaphore.clone() + }; + + { + let permit = { + // increment the count for the duration of queued + let _guard = CountWaitingInitializers::start(self); + sem.acquire().await + }; + + let Ok(permit) = permit else { + let guard = self.inner.read().await; + if !Arc::ptr_eq(&sem, &guard.init_semaphore) { + // there was a take_and_deinit in between + continue; + } + assert!( + guard.value.is_some(), + "semaphore got closed, must be initialized" + ); + return Ok(GuardRef(guard)); + }; + + permit.forget(); } + + let permit = InitPermit(sem); + let (value, _permit) = factory(permit).await?; + + let guard = self.inner.write().await; + + return Ok(Self::set0(value, guard).downgrade()); } } @@ -250,15 +264,12 @@ impl<'a, T> GuardMut<'a, T> { /// [`OnceCell::get_or_init`] will wait on it to complete. pub fn take_and_deinit(&mut self) -> (T, InitPermit) { let mut swapped = Inner::default(); - let permit = swapped - .init_semaphore - .clone() - .try_acquire_owned() - .expect("we just created this"); + let sem = swapped.init_semaphore.clone(); + sem.try_acquire().expect("we just created this").forget(); std::mem::swap(&mut *self.0, &mut swapped); swapped .value - .map(|v| (v, InitPermit(permit))) + .map(|v| (v, InitPermit(sem))) .expect("guard is not created unless value has been initialized") } @@ -282,13 +293,23 @@ impl std::ops::Deref for GuardRef<'_, T> { } /// Type held by OnceCell (de)initializing task. -pub struct InitPermit(tokio::sync::OwnedSemaphorePermit); +pub struct InitPermit(Arc); + +impl Drop for InitPermit { + fn drop(&mut self) { + debug_assert_eq!(self.0.available_permits(), 0); + self.0.add_permits(1); + } +} #[cfg(test)] mod tests { + use futures::Future; + use super::*; use std::{ convert::Infallible, + pin::{pin, Pin}, sync::atomic::{AtomicUsize, Ordering}, time::Duration, }; @@ -455,4 +476,94 @@ mod tests { .unwrap(); assert_eq!(*g, "now initialized"); } + + #[tokio::test(start_paused = true)] + async fn reproduce_init_take_deinit_race() { + init_take_deinit_scenario(|cell, factory| { + Box::pin(async { + cell.get_or_init(factory).await.unwrap(); + }) + }) + .await; + } + + #[tokio::test(start_paused = true)] + async fn reproduce_init_take_deinit_race_mut() { + init_take_deinit_scenario(|cell, factory| { + Box::pin(async { + cell.get_mut_or_init(factory).await.unwrap(); + }) + }) + .await; + } + + type BoxedInitFuture = Pin>>>; + type BoxedInitFunction = Box BoxedInitFuture>; + + /// Reproduce an assertion failure with both initialization methods. + /// + /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`. + /// Alternative would be a macro_rules! but that is the last resort. + async fn init_take_deinit_scenario(init_way: F) + where + F: for<'a> Fn( + &'a OnceCell<&'static str>, + BoxedInitFunction<&'static str, Infallible>, + ) -> Pin + 'a>>, + { + let cell = OnceCell::default(); + + // acquire the init_semaphore only permit to drive initializing tasks in order to waiting + // on the same semaphore. + let permit = cell + .inner + .read() + .await + .init_semaphore + .clone() + .try_acquire_owned() + .unwrap(); + + let mut t1 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })), + )); + + let mut t2 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })), + )); + + // drive t2 first to the init_semaphore + tokio::select! { + _ = &mut t2 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // followed by t1 in the init_semaphore + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // now let t2 proceed and initialize + drop(permit); + t2.await; + + let (s, permit) = { cell.get_mut().await.unwrap().take_and_deinit() }; + assert_eq!("t2", s); + + // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from + // the new one. + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // only now we get to initialize it + drop(permit); + t1.await; + + assert_eq!("t1", *cell.get().await.unwrap()); + } }