diff --git a/src/storage/src/scheduler.rs b/src/storage/src/scheduler.rs index d9bb71a7cf..0fdbce6aa0 100644 --- a/src/storage/src/scheduler.rs +++ b/src/storage/src/scheduler.rs @@ -315,6 +315,7 @@ mod tests { use std::sync::atomic::{AtomicBool, AtomicI32}; use std::time::Duration; + use futures_util::future::BoxFuture; use store_api::storage::RegionId; use super::*; @@ -537,10 +538,49 @@ mod tests { .unwrap(); } + struct MockAsyncHandler { + cb: F, + } + + #[async_trait::async_trait] + impl Handler for MockAsyncHandler + where + F: Fn() -> BoxFuture<'static, ()> + Send + Sync, + { + type Request = MockRequest; + + async fn handle_request( + &self, + _req: Self::Request, + token: BoxedRateLimitToken, + finish_notifier: Arc, + ) -> Result<()> { + let fut = (self.cb)(); + fut.await; + token.try_release(); + finish_notifier.notify_one(); + Ok(()) + } + } + #[tokio::test] async fn test_schedule_duplicate_tasks() { common_telemetry::init_default_ut_logging(); - let handler = MockHandler { cb: || {} }; + let (tx, rx) = tokio::sync::watch::channel(false); + let handler = MockAsyncHandler { + cb: move || { + let mut rx = rx.clone(); + Box::pin(async move { + // Block the handler so it can't handle more requests. + loop { + rx.changed().await.unwrap(); + if *rx.borrow() { + break; + } + } + }) as _ // Casts the Pin> to Pin> + }, + }; let config = SchedulerConfig { max_inflight_tasks: 30, }; @@ -557,6 +597,7 @@ mod tests { scheduled_task += 1; } } + tx.send(true).unwrap(); scheduler.stop(true).await.unwrap(); debug!("Schedule tasks: {}", scheduled_task); assert!(scheduled_task < 10);