diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 06ada1d23..79a39baf3 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -100,6 +100,9 @@ mod top_collector; mod top_score_collector; pub use self::top_score_collector::TopDocs; +#[cfg(test)] +pub(crate) use self::top_score_collector::TopScoreSegmentCollector; + mod custom_score_top_collector; pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer}; diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 69fe9091e..9af6f366b 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -128,7 +128,7 @@ pub(crate) struct TopSegmentCollector { } impl TopSegmentCollector { - fn new(segment_id: SegmentLocalId, limit: usize) -> TopSegmentCollector { + pub fn new(segment_id: SegmentLocalId, limit: usize) -> TopSegmentCollector { TopSegmentCollector { limit, heap: BinaryHeap::with_capacity(limit), diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index d7eff9fc3..69de8c8f8 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -429,6 +429,12 @@ impl Collector for TopDocs { /// Segment Collector associated to `TopDocs`. pub struct TopScoreSegmentCollector(TopSegmentCollector); +impl TopScoreSegmentCollector { + pub fn new(segment_id: SegmentLocalId, limit: usize) -> Self { + TopScoreSegmentCollector(TopSegmentCollector::new(segment_id, limit)) + } +} + impl SegmentCollector for TopScoreSegmentCollector { type Fruit = Vec<(Score, DocAddress)>; diff --git a/src/query/block_max_wand.rs b/src/query/block_max_wand.rs index 401d3a078..f714c4b24 100644 --- a/src/query/block_max_wand.rs +++ b/src/query/block_max_wand.rs @@ -3,6 +3,7 @@ use crate::query::score_combiner::ScoreCombiner; use crate::query::{BlockMaxScorer, Scorer}; use crate::DocId; use crate::Score; +use crate::query::scorer::ScorerWithPruning; #[derive(Debug, Copy, Clone, PartialEq, Eq)] struct Pivot { @@ -17,19 +18,18 @@ struct Pivot { /// docsets need to be advanced, and are required to be sorted by the doc they point to. /// /// The pivot is then defined as the lowest DocId that has a chance of matching our condition. -fn find_pivot_position<'a, TScorer, F>( +fn find_pivot_position<'a, TScorer>( mut docsets: impl Iterator, - condition: &F, + lower_bound_score: Score, ) -> Option where - F: Fn(&Score) -> bool, TScorer: BlockMaxScorer + Scorer, { let mut position = 0; let mut upper_bound = Score::default(); while let Some(docset) = docsets.next() { upper_bound += docset.max_score(); - if condition(&upper_bound) { + if lower_bound_score < upper_bound { let pivot_doc = docset.doc(); let first_occurrence = position; while let Some(docset) = docsets.next() { @@ -101,25 +101,22 @@ fn sift_down(docsets: &mut [T]) /// applying [BlockMaxWand] dynamic pruning. /// /// [BlockMaxWand]: https://dl.acm.org/doi/10.1145/2009916.2010048 -pub struct BlockMaxWand { +pub struct BlockMaxWand { docsets: Vec>, doc: DocId, score: Score, combiner: TScoreCombiner, - threshold_fn: ThresholdFn, } -impl BlockMaxWand +impl BlockMaxWand where TScoreCombiner: ScoreCombiner, TScorer: BlockMaxScorer + Scorer, - ThresholdFn: Fn(&Score) -> bool + 'static, { fn new( docsets: Vec, combiner: TScoreCombiner, - threshold_fn: ThresholdFn, - ) -> BlockMaxWand { + ) -> BlockMaxWand { let mut non_empty_docsets: Vec<_> = docsets .into_iter() .flat_map(|mut docset| { @@ -134,26 +131,24 @@ impl BlockMaxWand Option { + fn find_pivot_position(&self, lower_bound_score: Score) -> Option { find_pivot_position( self.docsets.iter().map(|docset| docset.as_ref()), - &self.threshold_fn, - ) + lower_bound_score) } - fn advance_with_pivot(&mut self, pivot: Pivot) -> SkipResult { + fn advance_with_pivot(&mut self, pivot: Pivot, lower_bound_score: Score) -> SkipResult { let block_upper_bound: Score = self.docsets[..=pivot.position] .iter_mut() .map(|docset| docset.block_max_score()) .sum(); - if (self.threshold_fn)(&block_upper_bound) { + if block_upper_bound > lower_bound_score { if pivot.doc == self.docsets[0].doc() { // Since self.docsets is sorted by their current doc, in this branch, all // docsets in [0..=pivot] are positioned on pivot.doc. @@ -206,22 +201,14 @@ impl BlockMaxWand DocSet -for BlockMaxWand +impl DocSet +for BlockMaxWand where TScorer: BlockMaxScorer + Scorer, TScoreCombiner: ScoreCombiner, - ThresholdFn: Fn(&Score) -> bool + 'static, { fn advance(&mut self) -> bool { - while let Some(pivot) = self.find_pivot_position() { - match self.advance_with_pivot(pivot) { - SkipResult::End => { return false }, - SkipResult::Reached=> { return true; } - SkipResult::OverStep => {} - } - } - false + unimplemented!(); } fn skip_next(&mut self, target: DocId) -> SkipResult { @@ -252,16 +239,38 @@ for BlockMaxWand } } -impl Scorer -for BlockMaxWand +impl Scorer +for BlockMaxWand where TScoreCombiner: ScoreCombiner, TScorer: Scorer + BlockMaxScorer, - ThresholdFn: Fn(&Score) -> bool + 'static, { fn score(&mut self) -> Score { self.score } + + /// Returns `Some(&mut self)` if pruning is supported by the current scorer. + /// None, if pruning is score is not supported. + fn get_pruning_scorer(&mut self) -> Option<&mut dyn ScorerWithPruning> { + Some(self) + } +} + +impl ScorerWithPruning +for BlockMaxWand + where + TScoreCombiner: ScoreCombiner, + TScorer: Scorer + BlockMaxScorer { + fn advance_with_pruning(&mut self, lower_bound_score: f32) -> bool { + while let Some(pivot) = self.find_pivot_position(lower_bound_score) { + match self.advance_with_pivot(pivot, lower_bound_score) { + SkipResult::End => { return false }, + SkipResult::Reached=> { return true; } + SkipResult::OverStep => {} + } + } + false + } } #[cfg(test)] @@ -277,8 +286,8 @@ mod tests { use float_cmp::approx_eq; use proptest::strategy::Strategy; use std::cmp::Ordering; - use std::collections::BinaryHeap; use std::num::Wrapping; + use crate::collector::{SegmentCollector, TopScoreSegmentCollector}; #[derive(Debug, Clone)] pub struct VecDocSet { @@ -403,94 +412,17 @@ mod tests { impl Eq for ComparableDoc {} - #[derive(Debug)] - struct TopSegmentCollector { - limit: usize, - heap: BinaryHeap>, - } - - impl TopSegmentCollector { - fn new(limit: usize) -> TopSegmentCollector { - TopSegmentCollector { - limit, - heap: BinaryHeap::with_capacity(limit), - } - } - } - - impl TopSegmentCollector { - pub fn harvest(self) -> Vec<(T, DocId)> { - self.heap - .into_sorted_vec() - .into_iter() - .map(|comparable_doc| (comparable_doc.feature, comparable_doc.doc)) - .collect() - } - - /// Return true iff at least K documents have gone through - /// the collector. - #[inline(always)] - pub(crate) fn at_capacity(&self) -> bool { - self.heap.len() >= self.limit - } - - #[inline(always)] - pub(crate) fn above_threshold(&self, elem: &T) -> bool { - if self.at_capacity() { - elem > &self.heap.peek().unwrap().feature - } else { - true - } - } - - /// Collects a document scored by the given feature - /// - /// It collects documents until it has reached the max capacity. Once it reaches capacity, it - /// will compare the lowest scoring item with the given one and keep whichever is greater. - #[inline(always)] - pub fn collect(&mut self, doc: DocId, feature: T) { - if self.at_capacity() { - // It's ok to unwrap as long as a limit of 0 is forbidden. - if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) { - if limit_feature < feature { - if let Some(mut head) = self.heap.peek_mut() { - head.feature = feature; - head.doc = doc; - } - } - } - } else { - // we have not reached capacity yet, so we can just push the - // element. - self.heap.push(ComparableDoc { feature, doc }); - } - } - } - fn union_vs_bmw(posting_lists: Vec) { let mut union = Union::::from(posting_lists.clone()); - let mut top_union = TopSegmentCollector::::new(10); + let mut top_union = TopScoreSegmentCollector::new(0, 10); while union.advance() { top_union.collect(union.doc(), union.score()); } - let top_bmw = std::rc::Rc::new(std::cell::RefCell::new(TopSegmentCollector::::new( - 10, - ))); - let inner = std::rc::Rc::clone(&top_bmw); - let mut bmw = BlockMaxWand::new(posting_lists, SumCombiner::default(), move |score| { - inner.borrow().above_threshold(score) - }); - while bmw.advance() { - top_bmw.borrow_mut().collect(bmw.doc(), bmw.score()); - } - drop(bmw); + let top_bmw = TopScoreSegmentCollector::new(0, 10 ); + let mut bmw = BlockMaxWand::new(posting_lists, SumCombiner::default()); + let top_docs_bnw = top_bmw.collect_scorer(&mut bmw, None); for ((expected_score, expected_doc), (actual_score, actual_doc)) in - top_union.harvest().into_iter().zip( - std::rc::Rc::try_unwrap(top_bmw) - .unwrap() - .into_inner() - .harvest(), - ) + top_union.harvest().into_iter().zip( top_docs_bnw ) { assert!(approx_eq!( f32, @@ -595,7 +527,7 @@ mod tests { VecDocSet::started(vec![(3, 6.0)], 1), ]; assert_eq!( - find_pivot_position(postings.iter(), &|&score| score > 2.0), + find_pivot_position(postings.iter(), 2.0f32), Some(Pivot { position: 1, doc: 1, @@ -603,7 +535,7 @@ mod tests { }) ); assert_eq!( - find_pivot_position(postings.iter(), &|&score| score > 5.0), + find_pivot_position(postings.iter(), 5.0f32), Some(Pivot { position: 2, doc: 2, @@ -611,7 +543,7 @@ mod tests { }) ); assert_eq!( - find_pivot_position(postings.iter(), &|&score| score > 9.0), + find_pivot_position(postings.iter(), 9.0f32), Some(Pivot { position: 4, doc: 3, @@ -619,7 +551,7 @@ mod tests { }) ); assert_eq!( - find_pivot_position(postings.iter(), &|&score| score > 20.0), + find_pivot_position(postings.iter(), 20.0f32), None ); }