From 8cd7ddc535df2efe465e543ad4a55e22338d236f Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Wed, 8 May 2024 12:22:44 +0200 Subject: [PATCH] run block decompression from executor (#2386) * run block decompression from executor * add a wrapper with is_closed to oneshot channel * add cancelation test to Executor::spawn_blocking --- src/core/executor.rs | 107 ++++++++++++++++++++++++++++++++++++++++++- src/core/searcher.rs | 3 +- src/store/reader.rs | 32 ++++++++++--- 3 files changed, 133 insertions(+), 9 deletions(-) diff --git a/src/core/executor.rs b/src/core/executor.rs index f4d7d2a13..915534009 100644 --- a/src/core/executor.rs +++ b/src/core/executor.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "quickwit")] +use futures_util::{future::Either, FutureExt}; use rayon::{ThreadPool, ThreadPoolBuilder}; use crate::TantivyError; @@ -91,11 +93,84 @@ impl Executor { } } } + + /// Spawn a task on the pool, returning a future completing on task success. + /// + /// If the task panics, returns `Err(())`. 
+ #[cfg(feature = "quickwit")] + pub fn spawn_blocking( + &self, + cpu_intensive_task: impl FnOnce() -> T + Send + 'static, + ) -> impl std::future::Future> { + match self { + Executor::SingleThread => Either::Left(std::future::ready(Ok(cpu_intensive_task()))), + Executor::ThreadPool(pool) => { + let (sender, receiver) = oneshot_with_sentinel::channel(); + pool.spawn(|| { + if sender.is_closed() { + return; + } + let task_result = cpu_intensive_task(); + let _ = sender.send(task_result); + }); + + let res = receiver.map(|res| res.map_err(|_| ())); + Either::Right(res) + } + } + } +} + +#[cfg(feature = "quickwit")] +mod oneshot_with_sentinel { + use std::pin::Pin; + use std::sync::Arc; + use std::task::{Context, Poll}; + // TODO get rid of this if oneshot ever gains an is_closed() + + pub struct SenderWithSentinel { + tx: oneshot::Sender, + guard: Arc<()>, + } + + pub struct ReceiverWithSentinel { + rx: oneshot::Receiver, + _guard: Arc<()>, + } + + pub fn channel() -> (SenderWithSentinel, ReceiverWithSentinel) { + let (tx, rx) = oneshot::channel(); + let guard = Arc::new(()); + ( + SenderWithSentinel { + tx, + guard: guard.clone(), + }, + ReceiverWithSentinel { rx, _guard: guard }, + ) + } + + impl SenderWithSentinel { + pub fn send(self, message: T) -> Result<(), oneshot::SendError> { + self.tx.send(message) + } + + pub fn is_closed(&self) -> bool { + Arc::strong_count(&self.guard) == 1 + } + } + + impl std::future::Future for ReceiverWithSentinel { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.rx).poll(ctx) + } + } } #[cfg(test)] mod tests { - use super::Executor; #[test] @@ -147,4 +222,34 @@ mod tests { assert_eq!(result[i], i * 2); } } + + #[cfg(feature = "quickwit")] + #[test] + fn test_cancel_cpu_intensive_tasks() { + use std::sync::atomic::{AtomicU64, Ordering}; + use std::sync::Arc; + use std::time::Duration; + + let counter: Arc = Default::default(); + let mut futures = Vec::new(); + let 
executor = Executor::multi_thread(3, "search-test").unwrap(); + for _ in 0..1_000 { + let counter_clone = counter.clone(); + let fut = executor.spawn_blocking(move || { + std::thread::sleep(Duration::from_millis(4)); + counter_clone.fetch_add(1, Ordering::SeqCst) + }); + futures.push(fut); + } + std::thread::sleep(Duration::from_millis(5)); + // The first few num_cores tasks should run, but the others should get cancelled. + drop(futures); + while Arc::strong_count(&counter) > 1 { + std::thread::sleep(Duration::from_millis(10)); + } + // with ideal timing, we expect the result to always be 6, but as long as we run some, and + // cancelled most, the test is a success + assert!(counter.load(Ordering::SeqCst) > 0); + assert!(counter.load(Ordering::SeqCst) < 50); + } } diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 56816145e..f74c837c4 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -109,8 +109,9 @@ impl Searcher { &self, doc_address: DocAddress, ) -> crate::Result { + let executor = self.inner.index.search_executor(); let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize]; - store_reader.get_async(doc_address.doc_id).await + store_reader.get_async(doc_address.doc_id, executor).await } /// Access the schema associated with the index of this searcher. diff --git a/src/store/reader.rs b/src/store/reader.rs index b7f243003..44f0df993 100644 --- a/src/store/reader.rs +++ b/src/store/reader.rs @@ -18,6 +18,8 @@ use crate::schema::document::{BinaryDocumentDeserializer, DocumentDeserialize}; use crate::space_usage::StoreSpaceUsage; use crate::store::index::Checkpoint; use crate::DocId; +#[cfg(feature = "quickwit")] +use crate::Executor; pub(crate) const DOCSTORE_CACHE_CAPACITY: usize = 100; @@ -341,7 +343,11 @@ impl StoreReader { /// In most cases use [`get_async`](Self::get_async) /// /// Loads and decompresses a block asynchronously. 
- async fn read_block_async(&self, checkpoint: &Checkpoint) -> io::Result { + async fn read_block_async( + &self, + checkpoint: &Checkpoint, + executor: &Executor, + ) -> io::Result { let cache_key = checkpoint.byte_range.start; if let Some(block) = self.cache.get_from_cache(checkpoint.byte_range.start) { return Ok(block); @@ -353,8 +359,12 @@ impl StoreReader { .read_bytes_async() .await?; - let decompressed_block = - OwnedBytes::new(self.decompressor.decompress(compressed_block.as_ref())?); + let decompressor = self.decompressor; + let maybe_decompressed_block = executor + .spawn_blocking(move || decompressor.decompress(compressed_block.as_ref())) + .await + .expect("decompression panicked"); + let decompressed_block = OwnedBytes::new(maybe_decompressed_block?); self.cache .put_into_cache(cache_key, decompressed_block.clone()); @@ -363,15 +373,23 @@ impl StoreReader { } /// Reads raw bytes of a given document asynchronously. - pub async fn get_document_bytes_async(&self, doc_id: DocId) -> crate::Result { + pub async fn get_document_bytes_async( + &self, + doc_id: DocId, + executor: &Executor, + ) -> crate::Result { let checkpoint = self.block_checkpoint(doc_id)?; - let block = self.read_block_async(&checkpoint).await?; + let block = self.read_block_async(&checkpoint, executor).await?; Self::get_document_bytes_from_block(block, doc_id, &checkpoint) } /// Fetches a document asynchronously. Async version of [`get`](Self::get). - pub async fn get_async(&self, doc_id: DocId) -> crate::Result { - let mut doc_bytes = self.get_document_bytes_async(doc_id).await?; + pub async fn get_async( + &self, + doc_id: DocId, + executor: &Executor, + ) -> crate::Result { + let mut doc_bytes = self.get_document_bytes_async(doc_id, executor).await?; let deserializer = BinaryDocumentDeserializer::from_reader(&mut doc_bytes) .map_err(crate::TantivyError::from)?;