diff --git a/src/collector/mod.rs b/src/collector/mod.rs index f32bd73a2..06ada1d23 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -109,6 +109,8 @@ pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker}; mod facet_collector; pub use self::facet_collector::FacetCollector; +use crate::fastfield::DeleteBitSet; +use crate::query::Scorer; /// `Fruit` is the type for the result of our collection. /// e.g. `usize` for the `Count` collector. @@ -161,7 +163,7 @@ pub trait Collector: Sync { /// /// `.collect(doc, score)` will be called for every documents /// matching the query. -pub trait SegmentCollector: 'static { +pub trait SegmentCollector: 'static + Sized { /// `Fruit` is the type for the result of our collection. /// e.g. `usize` for the `Count` collector. type Fruit: Fruit; @@ -171,6 +173,19 @@ pub trait SegmentCollector: 'static { /// Extract the fruit of the collection from the `SegmentCollector`. fn harvest(self) -> Self::Fruit; + + fn collect_scorer(mut self, scorer: &mut dyn Scorer, delete_bitset: Option<&DeleteBitSet>) -> Self::Fruit { + if let Some(delete_bitset) = delete_bitset { + scorer.for_each(&mut |doc, score| { + if delete_bitset.is_alive(doc) { + self.collect(doc, score); + } + }); + } else { + scorer.for_each(&mut |doc, score| self.collect(doc, score)); + } + self.harvest() + } } // ----------------------------------------------- diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index f3b5742e3..69fe9091e 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -69,9 +69,7 @@ where /// # Panics /// The method panics if limit is 0 pub fn with_limit(limit: usize) -> TopCollector { - if limit < 1 { - panic!("Limit must be strictly greater than 0."); - } + assert!(limit > 0, "Limit must be strictly greater than 0."); TopCollector { limit, _marker: PhantomData, @@ -124,7 +122,7 @@ where /// The theorical complexity for collecting the top `K` out of `n` documents /// is `O(n log K)`. pub(crate) struct TopSegmentCollector { - limit: usize, + pub limit: usize, heap: BinaryHeap>, segment_id: u32, } @@ -161,6 +159,10 @@ impl TopSegmentCollector { self.heap.len() >= self.limit } + pub fn pruning_score(&self) -> Option { + self.heap.peek().map(|head| head.feature.clone()) + } + /// Collects a document scored by the given feature /// /// It collects documents until it has reached the max capacity. Once it reaches capacity, it diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 4a8cb48e0..d7eff9fc3 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -6,7 +6,7 @@ use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector; use crate::collector::{ CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector, }; -use crate::fastfield::FastFieldReader; +use crate::fastfield::{FastFieldReader, DeleteBitSet}; use crate::schema::Field; use crate::DocAddress; use crate::DocId; @@ -14,6 +14,7 @@ use crate::Score; use crate::SegmentLocalId; use crate::SegmentReader; use std::fmt; +use crate::query::Scorer; /// The `TopDocs` collector keeps track of the top `K` documents /// sorted by their score. @@ -438,6 +439,34 @@ impl SegmentCollector for TopScoreSegmentCollector { fn harvest(self) -> Vec<(Score, DocAddress)> { self.0.harvest() } + + fn collect_scorer(mut self, scorer: &mut dyn Scorer, delete_bitset: Option<&DeleteBitSet>) -> Self::Fruit { + if let Some(delete_bitset) = delete_bitset { + scorer.for_each(&mut |doc, score| { + if delete_bitset.is_alive(doc) { + self.collect(doc, score); + } + }); + return self.harvest(); + // TODO(implement the optimisation for deletes) + } + if let Some(pruning_scorer) = scorer.get_pruning_scorer() { + let limit = self.0.limit; + for _ in 0..limit { + if !pruning_scorer.advance() { + return self.harvest(); + } + self.collect(pruning_scorer.doc(), pruning_scorer.score()); + } + let mut pruning_score = self.0.pruning_score().unwrap_or(0.0f32); + while pruning_scorer.advance_with_pruning(pruning_score) { + self.collect(pruning_scorer.doc(), pruning_scorer.score()); + pruning_score = self.0.pruning_score().unwrap_or(0.0f32); + } + } + scorer.for_each(&mut |doc, score| self.collect(doc, score)); + self.harvest() + } } #[cfg(test)] diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 8aa808a8e..83054ba58 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -4,7 +4,6 @@ use crate::core::Executor; use crate::core::InvertedIndexReader; use crate::core::SegmentReader; use crate::query::Query; -use crate::query::Scorer; use crate::query::Weight; use crate::schema::Document; use crate::schema::Schema; @@ -24,17 +23,9 @@ fn collect_segment( segment_reader: &SegmentReader, ) -> crate::Result { let mut scorer = weight.scorer(segment_reader, 1.0f32)?; - let mut segment_collector = collector.for_segment(segment_ord as u32, segment_reader)?; - if let Some(delete_bitset) = segment_reader.delete_bitset() { - scorer.for_each(&mut |doc, score| { - if delete_bitset.is_alive(doc) { - segment_collector.collect(doc, score); - } - }); - } else { - scorer.for_each(&mut |doc, score| segment_collector.collect(doc, score)); - } - Ok(segment_collector.harvest()) + let segment_collector = + collector.for_segment(segment_ord as u32, segment_reader)?; + Ok(segment_collector.collect_scorer(&mut scorer, segment_reader.delete_bitset())) } /// Holds a list of `SegmentReader`s ready for search. diff --git a/src/query/block_max_scorer.rs b/src/query/block_max_scorer.rs index 3ece3adbf..8f15f224d 100644 --- a/src/query/block_max_scorer.rs +++ b/src/query/block_max_scorer.rs @@ -35,10 +35,10 @@ impl BlockMaxScorer for Box { fn block_max_score(&mut self) -> Score { self.deref_mut().block_max_score() } - fn max_score(&self) -> Score { - self.deref().max_score() - } fn block_max_doc(&mut self) -> DocId { self.deref_mut().block_max_doc() } + fn max_score(&self) -> Score { + self.deref().max_score() + } } diff --git a/src/query/block_max_wand.rs b/src/query/block_max_wand.rs index fb4f5f424..401d3a078 100644 --- a/src/query/block_max_wand.rs +++ b/src/query/block_max_wand.rs @@ -13,6 +13,10 @@ struct Pivot { /// Find the position in the sorted list of posting lists of the **pivot**. +/// +/// docsets need to be advanced, and are required to be sorted by the doc they point to. +/// +/// The pivot is then defined as the lowest DocId that has a chance of matching our condition. fn find_pivot_position<'a, TScorer, F>( mut docsets: impl Iterator, condition: &F, @@ -129,10 +133,10 @@ impl BlockMaxWand::doc); BlockMaxWand { docsets: non_empty_docsets, - doc: 0, - score: 0f32, combiner, threshold_fn, + doc: 0u32, + score: 0f32 } } @@ -151,6 +155,11 @@ impl BlockMaxWand BlockMaxWand::doc); SkipResult::Reached } else { - // The subraction is correct because otherwise we would go to the other branch. + // The substraction is correct because otherwise we would go to the other branch. let advanced_idx = pivot.first_occurrence - 1; if !self.docsets[advanced_idx].advance() { self.docsets.swap_remove(advanced_idx); diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 02a4fb021..82de40603 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -21,8 +21,21 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static { callback(self.doc(), self.score()); } } + + /// Returns `Some(&mut self)` if pruning is supported by the current scorer. + /// None, if pruning is score is not supported. + fn get_pruning_scorer(&mut self) -> Option<&mut dyn ScorerWithPruning> { + None + } } +pub trait ScorerWithPruning: Scorer { + /// Advance to the next document that has a score strictly greater than + /// `lower_bound_score`. + fn advance_with_pruning(&mut self, lower_bound_score: Score) -> bool; +} + + impl_downcast!(Scorer); impl Scorer for Box {