use rayon::{ThreadPool, ThreadPoolBuilder}; use crate::TantivyError; /// Search executor whether search request are single thread or multithread. /// /// We don't expose Rayon thread pool directly here for several reasons. /// /// First dependency hell. It is not a good idea to expose the /// API of a dependency, knowing it might conflict with a different version /// used by the client. Second, we may stop using rayon in the future. pub enum Executor { /// Single thread variant of an Executor SingleThread, /// Thread pool variant of an Executor ThreadPool(ThreadPool), } impl Executor { /// Creates an Executor that performs all task in the caller thread. pub fn single_thread() -> Executor { Executor::SingleThread } /// Creates an Executor that dispatches the tasks in a thread pool. pub fn multi_thread(num_threads: usize, prefix: &'static str) -> crate::Result { let pool = ThreadPoolBuilder::new() .num_threads(num_threads) .thread_name(move |num| format!("{prefix}{num}")) .build()?; Ok(Executor::ThreadPool(pool)) } /// Perform a map in the thread pool. /// /// Regardless of the executor (`SingleThread` or `ThreadPool`), panics in the task /// will propagate to the caller. pub fn map< A: Send, R: Send, AIterator: Iterator, F: Sized + Sync + Fn(A) -> crate::Result, >( &self, f: F, args: AIterator, ) -> crate::Result> { match self { Executor::SingleThread => args.map(f).collect::>(), Executor::ThreadPool(pool) => { let args: Vec = args.collect(); let num_fruits = args.len(); let fruit_receiver = { let (fruit_sender, fruit_receiver) = crossbeam_channel::unbounded(); pool.scope(|scope| { for (idx, arg) in args.into_iter().enumerate() { // We name references for f and fruit_sender_ref because we do not // want these two to be moved into the closure. let f_ref = &f; let fruit_sender_ref = &fruit_sender; scope.spawn(move |_| { let fruit = f_ref(arg); if let Err(err) = fruit_sender_ref.send((idx, fruit)) { error!( "Failed to send search task. It probably means all search \ threads have panicked. {:?}", err ); } }); } }); fruit_receiver // This ends the scope of fruit_sender. // This is important as it makes it possible for the fruit_receiver iteration to // terminate. }; let mut result_placeholders: Vec> = std::iter::repeat_with(|| None).take(num_fruits).collect(); for (pos, fruit_res) in fruit_receiver { let fruit = fruit_res?; result_placeholders[pos] = Some(fruit); } let results: Vec = result_placeholders.into_iter().flatten().collect(); if results.len() != num_fruits { return Err(TantivyError::InternalError( "One of the mapped execution failed.".to_string(), )); } Ok(results) } } } } #[cfg(test)] mod tests { use super::Executor; #[test] #[should_panic(expected = "panic should propagate")] fn test_panic_propagates_single_thread() { let _result: Vec = Executor::single_thread() .map( |_| { panic!("panic should propagate"); }, vec![0].into_iter(), ) .unwrap(); } #[test] #[should_panic] //< unfortunately the panic message is not propagated fn test_panic_propagates_multi_thread() { let _result: Vec = Executor::multi_thread(1, "search-test") .unwrap() .map( |_| { panic!("panic should propagate"); }, vec![0].into_iter(), ) .unwrap(); } #[test] fn test_map_singlethread() { let result: Vec = Executor::single_thread() .map(|i| Ok(i * 2), 0..1_000) .unwrap(); assert_eq!(result.len(), 1_000); for i in 0..1_000 { assert_eq!(result[i], i * 2); } } #[test] fn test_map_multithread() { let result: Vec = Executor::multi_thread(3, "search-test") .unwrap() .map(|i| Ok(i * 2), 0..10) .unwrap(); assert_eq!(result.len(), 10); for i in 0..10 { assert_eq!(result[i], i * 2); } } }