diff --git a/libs/utils/src/shared_retryable.rs b/libs/utils/src/shared_retryable.rs index 1d2897860f..52bae1fe3b 100644 --- a/libs/utils/src/shared_retryable.rs +++ b/libs/utils/src/shared_retryable.rs @@ -41,96 +41,7 @@ use std::sync::Arc; /// A shared service value completes the infallible work once, even if called concurrently by /// multiple cancellable tasks. /// -/// ``` -/// use utils::shared_retryable::{SharedRetryable, Retryable, RetriedTaskPanicked}; -/// use std::sync::Arc; -/// -/// #[derive(Debug, Clone, Copy)] -/// enum OneLevelError { -/// TaskPanicked -/// } -/// -/// impl Retryable for OneLevelError { -/// fn is_permanent(&self) -> bool { -/// // for a single level errors, this wording is weird -/// !matches!(self, OneLevelError::TaskPanicked) -/// } -/// } -/// -/// impl From for OneLevelError { -/// fn from(_: RetriedTaskPanicked) -> Self { -/// OneLevelError::TaskPanicked -/// } -/// } -/// -/// #[derive(Clone, Default)] -/// struct Service(SharedRetryable>); -/// -/// impl Service { -/// async fn work(&self, completions: Arc) -> Result { -/// self.0.try_restart_spawn( -/// || async move { -/// // give time to cancel some of the tasks -/// tokio::time::sleep(std::time::Duration::from_secs(1)).await; -/// completions.fetch_add(1, std::sync::atomic::Ordering::Relaxed); -/// Self::work_once().await -/// } -/// ) -/// .await -/// } -/// -/// async fn work_once() -> Result { -/// Ok(42) -/// } -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// let svc = Service::default(); -/// -/// let mut js = tokio::task::JoinSet::new(); -/// -/// let barrier = Arc::new(tokio::sync::Barrier::new(10 + 1)); -/// let completions = Arc::new(std::sync::atomic::AtomicUsize::new(0)); -/// -/// let handles = (0..10).map(|_| js.spawn({ -/// let svc = svc.clone(); -/// let barrier = barrier.clone(); -/// let completions = completions.clone(); -/// async move { -/// // make sure all tasks are ready to start at the same time -/// barrier.wait().await; -/// // after 
successfully starting the work, any of the futures could get cancelled -/// svc.work(completions).await -/// } -/// })).collect::>(); -/// -/// barrier.wait().await; -/// -/// tokio::time::sleep(std::time::Duration::from_millis(100)).await; -/// -/// handles[5].abort(); -/// -/// let mut cancellations = 0; -/// -/// while let Some(res) = js.join_next().await { -/// // all complete with the same result -/// match res { -/// Ok(res) => assert_eq!(res.unwrap(), 42), -/// Err(je) => { -/// // except for the one task we cancelled; it's cancelling -/// // does not interfere with the result -/// assert!(je.is_cancelled()); -/// cancellations += 1; -/// assert_eq!(cancellations, 1, "only 6th task was aborted"); -/// } -/// } -/// } -/// -/// // there will be at most one terminal completion -/// assert_eq!(completions.load(std::sync::atomic::Ordering::Relaxed), 1); -/// } -/// ``` +/// Example moved as a test `service_example`. #[derive(Clone)] pub struct SharedRetryable { inner: Arc>>>, @@ -167,6 +78,7 @@ where /// /// Compared to `Self::try_restart`, this method also spawns the future to run, which would /// otherwise have to be done manually. + #[cfg(test)] pub async fn try_restart_spawn(&self, retry_with: F) -> Result where F: FnOnce() -> Fut, @@ -349,6 +261,7 @@ where /// /// Any previous attempt which panicked will be retried, but the `RetriedTaskPanicked` will be /// returned when the most recent attempt panicked. 
+ #[cfg(test)] pub async fn attempt_spawn(&self, attempt_with: F) -> Result where F: FnOnce() -> Fut, @@ -464,6 +377,7 @@ where drop(tx.send(res)); } + #[cfg(test)] fn make_oneshot_alike_receiver_any( mut rx: tokio::sync::broadcast::Receiver, ) -> impl std::future::Future> + Send + 'static { @@ -521,8 +435,8 @@ impl MaybeDone { #[cfg(test)] mod tests { - use super::{RetriedTaskPanicked, Retryable, SharedRetryable}; + use std::sync::Arc; #[derive(Debug)] enum OuterError { @@ -647,4 +561,97 @@ mod tests { assert_eq!(recv1.await.unwrap(), 42); assert_eq!(recv2.await.unwrap(), 42, "43 should never be returned"); } + + #[tokio::test] + async fn service_example() { + #[derive(Debug, Clone, Copy)] + enum OneLevelError { + TaskPanicked, + } + + impl Retryable for OneLevelError { + fn is_permanent(&self) -> bool { + // for a single level errors, this wording is weird + !matches!(self, OneLevelError::TaskPanicked) + } + } + + impl From for OneLevelError { + fn from(_: RetriedTaskPanicked) -> Self { + OneLevelError::TaskPanicked + } + } + + #[derive(Clone, Default)] + struct Service(SharedRetryable>); + + impl Service { + async fn work( + &self, + completions: Arc, + ) -> Result { + self.0 + .try_restart_spawn(|| async move { + // give time to cancel some of the tasks + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + completions.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Self::work_once().await + }) + .await + } + + async fn work_once() -> Result { + Ok(42) + } + } + + let svc = Service::default(); + + let mut js = tokio::task::JoinSet::new(); + + let barrier = Arc::new(tokio::sync::Barrier::new(10 + 1)); + let completions = Arc::new(std::sync::atomic::AtomicUsize::new(0)); + + let handles = (0..10) + .map(|_| { + js.spawn({ + let svc = svc.clone(); + let barrier = barrier.clone(); + let completions = completions.clone(); + async move { + // make sure all tasks are ready to start at the same time + barrier.wait().await; + // after successfully 
starting the work, any of the futures could get cancelled + svc.work(completions).await + } + }) + }) + .collect::>(); + + barrier.wait().await; + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + handles[5].abort(); + + let mut cancellations = 0; + + while let Some(res) = js.join_next().await { + // all complete with the same result + match res { + Ok(res) => assert_eq!(res.unwrap(), 42), + Err(je) => { + // except for the one task we cancelled; its cancellation + // does not interfere with the result + assert!(je.is_cancelled()); + cancellations += 1; + assert_eq!(cancellations, 1, "only 6th task was aborted"); + // however, we cannot assert that every time we get to cancel the 6th task + } + } + } + + // there will be at most one terminal completion + assert_eq!(completions.load(std::sync::atomic::Ordering::Relaxed), 1); + } }