diff --git a/libs/utils/src/shared_retryable.rs b/libs/utils/src/shared_retryable.rs index 79f81b17db..06430d83a9 100644 --- a/libs/utils/src/shared_retryable.rs +++ b/libs/utils/src/shared_retryable.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::future::Future; /// Container using which many request handlers can come together and join a single task to /// completion instead of racing each other and their own cancellation. @@ -62,6 +63,27 @@ pub trait Retryable { } } +pub trait MakeFuture { + type Future: Future + Send + 'static; + type Output: Send + 'static; + + fn make_future(self) -> Self::Future; +} + +impl MakeFuture for Fun +where + Fun: FnOnce() -> Fut, + Fut: Future + Send + 'static, + R: Send + 'static, +{ + type Future = Fut; + type Output = R; + + fn make_future(self) -> Self::Future { + self() + } +} + /// Retried task panicked, was cancelled, or never spawned (see [`SharedRetryable::try_restart`]). #[derive(Debug, PartialEq, Eq)] pub struct RetriedTaskPanicked; @@ -69,7 +91,7 @@ pub struct RetriedTaskPanicked; impl SharedRetryable> where T: Clone + std::fmt::Debug + Send + 'static, - E1: Clone + Retryable + std::fmt::Debug + Send + 'static, + E1: Retryable + Clone + std::fmt::Debug + Send + 'static, { /// Restart a previously failed operation unless it already completed with a terminal result. /// @@ -79,10 +101,11 @@ where /// Compared to `Self::try_restart`, this method also spawns the future to run, which would /// otherwise have to be done manually. #[cfg(test)] - pub async fn try_restart_spawn(&self, retry_with: F) -> Result + pub async fn try_restart_spawn( + &self, + retry_with: impl MakeFuture>, + ) -> Result where - F: FnOnce() -> Fut, - Fut: std::future::Future> + Send + 'static, E2: From + From + Send + 'static, { let (recv, maybe_fut) = self.try_restart(retry_with).await; @@ -109,18 +132,15 @@ where /// /// This complication exists because on pageserver we cannot use `tokio::spawn` directly /// at this time. - pub async fn try_restart( + pub async fn try_restart( &self, - retry_with: F, + retry_with: impl MakeFuture>, ) -> ( - impl std::future::Future> + Send + 'static, - Option + Send + 'static>, + impl Future> + Send + 'static, + Option + Send + 'static>, ) where - F: FnOnce() -> Fut, - Fut: std::future::Future> + Send + 'static, - E2: From + Send + 'static, - E2: From, + E2: From + From + Send + 'static, { use futures::future::Either; @@ -136,21 +156,18 @@ where /// Returns a Ok if the previous attempt had resulted in a terminal result. Err is returned /// when an attempt can be joined and possibly needs to be spawned. - async fn decide_to_retry_or_join( + async fn decide_to_retry_or_join( &self, - retry_with: F, + retry_with: impl MakeFuture>, ) -> Result< Result, ( tokio::sync::broadcast::Receiver>, - Option + Send + 'static>, + Option + Send + 'static>, ), > where - F: FnOnce() -> Fut, - Fut: std::future::Future> + Send + 'static, - E2: From, - E2: From, + E2: From + From, { let mut g = self.inner.lock().await; @@ -173,7 +190,7 @@ where None => { // new attempt // panic safety: invoke the factory before configuring the pending value - let fut = retry_with(); + let fut = retry_with.make_future(); let (strong, fut) = self.make_run_and_complete(fut, &mut g); (strong, Some(fut)) @@ -192,17 +209,14 @@ where /// /// Returns an `Arc>` which is valid until the attempt completes, and the future /// which will need to run to completion outside the lifecycle of the caller. - fn make_run_and_complete( + fn make_run_and_complete( &self, - fut: Fut, + fut: impl Future> + Send + 'static, g: &mut tokio::sync::MutexGuard<'_, Option>>>, ) -> ( Arc>>, - impl std::future::Future + Send + 'static, - ) - where - Fut: std::future::Future> + Send + 'static, - { + impl Future + Send + 'static, + ) { #[cfg(debug_assertions)] match &**g { Some(MaybeDone::Pending(weak)) => { @@ -232,10 +246,9 @@ where /// times. fn make_oneshot_alike_receiver( mut rx: tokio::sync::broadcast::Receiver>, - ) -> impl std::future::Future> + Send + 'static + ) -> impl Future> + Send + 'static where - E2: From, - E2: From, + E2: From + From, { use tokio::sync::broadcast::error::RecvError; @@ -262,11 +275,10 @@ where /// Any previous attempt which panicked will be retried, but the `RetriedTaskPanicked` will be /// returned when the most recent attempt panicked. #[cfg(test)] - pub async fn attempt_spawn(&self, attempt_with: F) -> Result - where - F: FnOnce() -> Fut, - Fut: std::future::Future + Send + 'static, - { + pub async fn attempt_spawn( + &self, + attempt_with: impl MakeFuture, + ) -> Result { let (rx, maybe_fut) = { let mut g = self.inner.lock().await; @@ -282,7 +294,7 @@ where let (strong, maybe_fut) = match maybe_rx { Some(strong) => (strong, None), None => { - let fut = attempt_with(); + let fut = attempt_with.make_future(); let (strong, fut) = self.make_run_and_complete_any(fut, &mut g); (strong, Some(fut)) @@ -313,17 +325,14 @@ where /// and not running the future will then require a new attempt. /// /// Also returns an `Arc>` which is valid until the attempt completes. - fn make_run_and_complete_any( + fn make_run_and_complete_any( &self, - fut: Fut, + fut: impl Future + Send + 'static, g: &mut tokio::sync::MutexGuard<'_, Option>>, ) -> ( Arc>, - impl std::future::Future + Send + 'static, - ) - where - Fut: std::future::Future + Send + 'static, - { + impl Future + Send + 'static, + ) { let (tx, rx) = tokio::sync::broadcast::channel(1); let strong = Arc::new(rx); @@ -331,8 +340,7 @@ where let retry = { let strong = strong.clone(); - let this = self.clone(); - async move { this.run_and_complete(fut, tx, strong).await } + self.clone().run_and_complete(fut, tx, strong) }; #[cfg(debug_assertions)] @@ -350,14 +358,12 @@ where /// Run the actual attempt, and communicate the response via both: /// - setting the `MaybeDone::Done` /// - the broadcast channel - async fn run_and_complete( - &self, - fut: Fut, + async fn run_and_complete( + self, + fut: impl Future, tx: tokio::sync::broadcast::Sender, strong: Arc>, - ) where - Fut: std::future::Future, - { + ) { let res = fut.await; { @@ -380,7 +386,7 @@ where #[cfg(test)] fn make_oneshot_alike_receiver_any( mut rx: tokio::sync::broadcast::Receiver, - ) -> impl std::future::Future> + Send + 'static { + ) -> impl Future> + Send + 'static { use tokio::sync::broadcast::error::RecvError; async move { @@ -555,8 +561,8 @@ mod tests { // but we can still reach a terminal state if the api is not misused or the // should_be_spawned winner is not cancelled - let recv1 = shr.try_restart_spawn::<_, _, OuterError>(|| async move { Ok(42) }); - let recv2 = shr.try_restart_spawn::<_, _, OuterError>(|| async move { Ok(43) }); + let recv1 = shr.try_restart_spawn::(|| async move { Ok(42) }); + let recv2 = shr.try_restart_spawn::(|| async move { Ok(43) }); assert_eq!(recv1.await.unwrap(), 42); assert_eq!(recv2.await.unwrap(), 42, "43 should never be returned");