diff --git a/src/collector/mod.rs b/src/collector/mod.rs index cd56cb4e5..603aa4bbc 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -109,7 +109,7 @@ pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker}; mod facet_collector; pub use self::facet_collector::FacetCollector; -use crate::query::Scorer; +use crate::query::Weight; /// `Fruit` is the type for the result of our collection. /// e.g. `usize` for the `Count` collector. @@ -159,21 +159,22 @@ pub trait Collector: Sync { /// Created a segment collector and fn collect_segment( &self, - scorer: &mut dyn Scorer, + weight: &dyn Weight, segment_ord: u32, - segment_reader: &SegmentReader, + reader: &SegmentReader, ) -> crate::Result<::Fruit> { - let mut segment_collector = self.for_segment(segment_ord as u32, segment_reader)?; - if let Some(delete_bitset) = segment_reader.delete_bitset() { - scorer.for_each(&mut |doc, score| { + let mut segment_collector = self.for_segment(segment_ord as u32, reader)?; + + if let Some(delete_bitset) = reader.delete_bitset() { + weight.for_each(reader, &mut |doc, score| { if delete_bitset.is_alive(doc) { segment_collector.collect(doc, score); } - }); + })?; } else { - scorer.for_each(&mut |doc, score| { + weight.for_each(reader, &mut |doc, score| { segment_collector.collect(doc, score); - }) + })?; } Ok(segment_collector.harvest()) } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 1ed2c8c0c..6eeb287ac 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -6,9 +6,8 @@ use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector; use crate::collector::{ CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector, }; -use crate::docset::TERMINATED; use crate::fastfield::FastFieldReader; -use crate::query::Scorer; +use crate::query::Weight; use crate::schema::Field; use crate::DocAddress; use crate::DocId; @@ -472,45 +471,52 @@ impl Collector for TopDocs { fn collect_segment( &self, - scorer: &mut dyn Scorer, + weight: &dyn Weight, segment_ord: u32, - segment_reader: &SegmentReader, + reader: &SegmentReader, ) -> crate::Result<::Fruit> { - let mut heap: BinaryHeap> = - BinaryHeap::with_capacity(self.0.limit + self.0.offset); - // first we fill the heap with the first `limit` elements. - let mut doc = scorer.doc(); - while doc != TERMINATED && heap.len() < (self.0.limit + self.0.offset) { - if !segment_reader.is_deleted(doc) { - let score = scorer.score(); - heap.push(ComparableDoc { - feature: score, - doc, - }); - } - doc = scorer.advance(); - } + let heap_len = self.0.limit + self.0.offset; + let mut heap: BinaryHeap> = BinaryHeap::with_capacity(heap_len); - let threshold = heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN); - - if let Some(delete_bitset) = segment_reader.delete_bitset() { - scorer.for_each_pruning(threshold, &mut |doc, score| { - if delete_bitset.is_alive(doc) { - *heap.peek_mut().unwrap() = ComparableDoc { - feature: score, - doc, - }; + if let Some(delete_bitset) = reader.delete_bitset() { + let mut threshold = f32::MIN; + weight.for_each_pruning(threshold, reader, &mut |doc, score| { + if delete_bitset.is_deleted(doc) { + return threshold; } - heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN) - }); - } else { - scorer.for_each_pruning(threshold, &mut |doc, score| { - *heap.peek_mut().unwrap() = ComparableDoc { + let heap_item = ComparableDoc { feature: score, doc, }; + if heap.len() < heap_len { + heap.push(heap_item); + if heap.len() == heap_len { + threshold = heap.peek().map(|el| el.feature).unwrap_or(f32::MIN); + } + return threshold; + } + *heap.peek_mut().unwrap() = heap_item; + threshold = heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN); + threshold + })?; + } else { + weight.for_each_pruning(f32::MIN, reader, &mut |doc, score| { + let heap_item = ComparableDoc { + feature: score, + doc, + }; + if heap.len() < heap_len { + heap.push(heap_item); + // TODO the threshold is suboptimal for heap.len == heap_len + if heap.len() == heap_len { + return heap.peek().map(|el| el.feature).unwrap_or(f32::MIN); + } else { + return f32::MIN; + } + } + *heap.peek_mut().unwrap() = heap_item; heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN) - }); + })?; } let fruit = heap @@ -518,7 +524,6 @@ impl Collector for TopDocs { .into_iter() .map(|cid| (cid.feature, DocAddress(segment_ord, cid.doc))) .collect(); - Ok(fruit) } } diff --git a/src/core/searcher.rs b/src/core/searcher.rs index f9fb9da4d..8e8775efd 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -140,8 +140,7 @@ impl Searcher { let segment_readers = self.segment_readers(); let fruits = executor.map( |(segment_ord, segment_reader)| { - let mut scorer = weight.scorer(segment_reader, 1.0f32)?; - collector.collect_segment(scorer.as_mut(), segment_ord as u32, segment_reader) + collector.collect_segment(weight.as_ref(), segment_ord as u32, segment_reader) }, segment_readers.iter().enumerate(), )?; diff --git a/src/postings/skip.rs b/src/postings/skip.rs index d5a67de70..6a6d66dc4 100644 --- a/src/postings/skip.rs +++ b/src/postings/skip.rs @@ -102,6 +102,7 @@ impl SkipReader { self.remaining_docs = doc_freq; } + #[cfg(test)] #[inline(always)] pub(crate) fn last_doc_in_block(&self) -> DocId { self.last_doc_in_block diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 8b6a6c881..95f6091cc 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -2,6 +2,7 @@ use crate::core::SegmentReader; use crate::query::explanation::does_not_match; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner}; use crate::query::term_query::TermScorer; +use crate::query::weight::{for_each_pruning_scorer, for_each_scorer}; use crate::query::EmptyScorer; use crate::query::Exclude; use crate::query::Occur; @@ -10,16 +11,21 @@ use crate::query::Scorer; use crate::query::Union; use crate::query::Weight; use crate::query::{intersect_scorers, Explanation}; -use crate::DocId; +use crate::{DocId, Score}; use std::collections::HashMap; -fn scorer_union(scorers: Vec>) -> Box +enum SpecializedScorer { + TermUnion(Union), + Other(Box), +} + +fn scorer_union(scorers: Vec>) -> SpecializedScorer where TScoreCombiner: ScoreCombiner, { assert!(!scorers.is_empty()); if scorers.len() == 1 { - return scorers.into_iter().next().unwrap(); //< we checked the size beforehands + return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehands } { @@ -29,14 +35,21 @@ where .into_iter() .map(|scorer| *(scorer.downcast::().map_err(|_| ()).unwrap())) .collect(); - let scorer: Box = - Box::new(Union::::from(scorers)); - return scorer; + return SpecializedScorer::TermUnion(Union::::from( + scorers, + )); } } + SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers))) +} - let scorer: Box = Box::new(Union::<_, TScoreCombiner>::from(scorers)); - scorer +impl Into> for SpecializedScorer { + fn into(self) -> Box { + match self { + Self::TermUnion(union) => Box::new(union), + Self::Other(scorer) => scorer, + } + } } pub struct BooleanWeight { @@ -72,41 +85,50 @@ impl BooleanWeight { &self, reader: &SegmentReader, boost: f32, - ) -> crate::Result> { + ) -> crate::Result> { let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?; - let should_scorer_opt: Option> = per_occur_scorers + let should_scorer_opt: Option> = per_occur_scorers .remove(&Occur::Should) .map(scorer_union::); let exclude_scorer_opt: Option> = per_occur_scorers .remove(&Occur::MustNot) - .map(scorer_union::); + .map(scorer_union::) + .map(Into::into); let must_scorer_opt: Option> = per_occur_scorers .remove(&Occur::Must) .map(intersect_scorers); - let positive_scorer: Box = match (should_scorer_opt, must_scorer_opt) { - (Some(should_scorer), Some(must_scorer)) => { - if self.scoring_enabled { - Box::new(RequiredOptionalScorer::<_, _, TScoreCombiner>::new( - must_scorer, - should_scorer, - )) - } else { - must_scorer + let positive_scorer: SpecializedScorer = + match (should_scorer_opt, must_scorer_opt) { + (Some(should_scorer), Some(must_scorer)) => { + if self.scoring_enabled { + SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< + Box, + Box, + TScoreCombiner, + >::new( + must_scorer, should_scorer.into() + ))) + } else { + SpecializedScorer::Other(must_scorer) + } } - } - (None, Some(must_scorer)) => must_scorer, - (Some(should_scorer), None) => should_scorer, - (None, None) => { - return Ok(Box::new(EmptyScorer)); - } - }; + (None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer), + (Some(should_scorer), None) => should_scorer, + (None, None) => { + return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); + } + }; if let Some(exclude_scorer) = exclude_scorer_opt { - Ok(Box::new(Exclude::new(positive_scorer, exclude_scorer))) + let positive_scorer_boxed: Box = positive_scorer.into(); + Ok(SpecializedScorer::Other(Box::new(Exclude::new( + positive_scorer_boxed, + exclude_scorer, + )))) } else { Ok(positive_scorer) } @@ -126,8 +148,10 @@ impl Weight for BooleanWeight { } } else if self.scoring_enabled { self.complex_scorer::(reader, boost) + .map(Into::into) } else { self.complex_scorer::(reader, boost) + .map(Into::into) } } @@ -150,6 +174,51 @@ impl Weight for BooleanWeight { } Ok(explanation) } + + fn for_each( + &self, + reader: &SegmentReader, + callback: &mut dyn FnMut(DocId, Score), + ) -> crate::Result<()> { + let scorer = self.complex_scorer::(reader, 1.0f32)?; + match scorer { + SpecializedScorer::TermUnion(mut union_scorer) => { + for_each_scorer(&mut union_scorer, callback); + } + SpecializedScorer::Other(mut scorer) => { + for_each_scorer(scorer.as_mut(), callback); + } + } + Ok(()) + } + + /// Calls `callback` with all of the `(doc, score)` for which score + /// is exceeding a given threshold. + /// + /// This method is useful for the TopDocs collector. + /// For all docsets, the blanket implementation has the benefit + /// of prefiltering (doc, score) pairs, avoiding the + /// virtual dispatch cost. + /// + /// More importantly, it makes it possible for scorers to implement + /// important optimization (e.g. BlockWAND for union). + fn for_each_pruning( + &self, + threshold: f32, + reader: &SegmentReader, + callback: &mut dyn FnMut(DocId, Score) -> Score, + ) -> crate::Result<()> { + let scorer = self.complex_scorer::(reader, 1.0f32)?; + match scorer { + SpecializedScorer::TermUnion(mut union_scorer) => { + for_each_pruning_scorer(&mut union_scorer, threshold, callback); + } + SpecializedScorer::Other(mut scorer) => { + for_each_pruning_scorer(scorer.as_mut(), threshold, callback); + } + } + Ok(()) + } } fn is_positive_occur(occur: Occur) -> bool { diff --git a/src/query/reqopt_scorer.rs b/src/query/reqopt_scorer.rs index 2a02dc790..5eddb5f4d 100644 --- a/src/query/reqopt_scorer.rs +++ b/src/query/reqopt_scorer.rs @@ -101,7 +101,10 @@ mod tests { ConstScorer::from(VecDocSet::from(vec![])), ); let mut docs = vec![]; - reqoptscorer.for_each(&mut |doc, _| docs.push(doc)); + while reqoptscorer.doc() != TERMINATED { + docs.push(reqoptscorer.doc()); + reqoptscorer.advance(); + } assert_eq!(docs, req); } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 923d5ff7a..d8a14951c 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,4 +1,4 @@ -use crate::docset::{DocSet, TERMINATED}; +use crate::docset::DocSet; use crate::DocId; use crate::Score; use downcast_rs::impl_downcast; @@ -12,41 +12,6 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static { /// /// This method will perform a bit of computation and is not cached. fn score(&mut self) -> Score; - - /// Iterates through all of the document matched by the DocSet - /// `DocSet` and push the scored documents to the collector. - fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) { - let mut doc = self.doc(); - while doc != TERMINATED { - callback(doc, self.score()); - doc = self.advance(); - } - } - - /// Calls `callback` with all of the `(doc, score)` for which score - /// is exceeding a given threshold. - /// - /// This method is useful for the TopDocs collector. - /// For all docsets, the blanket implementation has the benefit - /// of prefiltering (doc, score) pairs, avoiding the - /// virtual dispatch cost. - /// - /// More importantly, it makes it possible for scorers to implement - /// important optimization (e.g. BlockWAND for union). - fn for_each_pruning( - &mut self, - mut threshold: f32, - callback: &mut dyn FnMut(DocId, Score) -> Score, - ) { - let mut doc = self.doc(); - while doc != TERMINATED { - let score = self.score(); - if score > threshold { - threshold = callback(doc, score); - } - doc = self.advance(); - } - } } impl_downcast!(Scorer); @@ -55,11 +20,6 @@ impl Scorer for Box { fn score(&mut self) -> Score { self.deref_mut().score() } - - fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) { - let scorer = self.deref_mut(); - scorer.for_each(callback); - } } /// Wraps a `DocSet` and simply returns a constant `Scorer`. diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 3d53827ee..57ca3f87e 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -4,12 +4,13 @@ use crate::docset::DocSet; use crate::postings::SegmentPostings; use crate::query::bm25::BM25Weight; use crate::query::explanation::does_not_match; +use crate::query::weight::{for_each_pruning_scorer, for_each_scorer}; use crate::query::Weight; use crate::query::{Explanation, Scorer}; use crate::schema::IndexRecordOption; -use crate::DocId; use crate::Result; use crate::Term; +use crate::{DocId, Score}; pub struct TermWeight { term: Term, @@ -43,6 +44,39 @@ impl Weight for TermWeight { .unwrap_or(0)) } } + + /// Iterates through all of the document matched by the DocSet + /// `DocSet` and push the scored documents to the collector. + fn for_each( + &self, + reader: &SegmentReader, + callback: &mut dyn FnMut(DocId, Score), + ) -> crate::Result<()> { + let mut scorer = self.scorer_specialized(reader, 1.0f32)?; + for_each_scorer(&mut scorer, callback); + Ok(()) + } + + /// Calls `callback` with all of the `(doc, score)` for which score + /// is exceeding a given threshold. + /// + /// This method is useful for the TopDocs collector. + /// For all docsets, the blanket implementation has the benefit + /// of prefiltering (doc, score) pairs, avoiding the + /// virtual dispatch cost. + /// + /// More importantly, it makes it possible for scorers to implement + /// important optimization (e.g. BlockWAND for union). + fn for_each_pruning( + &self, + threshold: f32, + reader: &SegmentReader, + callback: &mut dyn FnMut(DocId, Score) -> Score, + ) -> crate::Result<()> { + let mut scorer = self.scorer(reader, 1.0f32)?; + for_each_pruning_scorer(&mut scorer, threshold, callback); + Ok(()) + } } impl TermWeight { diff --git a/src/query/weight.rs b/src/query/weight.rs index 821cebfd7..232f1824a 100644 --- a/src/query/weight.rs +++ b/src/query/weight.rs @@ -1,7 +1,45 @@ use super::Scorer; use crate::core::SegmentReader; use crate::query::Explanation; -use crate::DocId; +use crate::{DocId, Score, TERMINATED}; + +/// Iterates through all of the document matched by the DocSet +/// `DocSet` and push the scored documents to the collector. +pub(crate) fn for_each_scorer( + scorer: &mut TScorer, + callback: &mut dyn FnMut(DocId, Score), +) { + let mut doc = scorer.doc(); + while doc != TERMINATED { + callback(doc, scorer.score()); + doc = scorer.advance(); + } +} + +/// Calls `callback` with all of the `(doc, score)` for which score +/// is exceeding a given threshold. +/// +/// This method is useful for the TopDocs collector. +/// For all docsets, the blanket implementation has the benefit +/// of prefiltering (doc, score) pairs, avoiding the +/// virtual dispatch cost. +/// +/// More importantly, it makes it possible for scorers to implement +/// important optimization (e.g. BlockWAND for union). +pub(crate) fn for_each_pruning_scorer( + scorer: &mut TScorer, + mut threshold: f32, + callback: &mut dyn FnMut(DocId, Score) -> Score, +) { + let mut doc = scorer.doc(); + while doc != TERMINATED { + let score = scorer.score(); + if score > threshold { + threshold = callback(doc, score); + } + doc = scorer.advance(); + } +} /// A Weight is the specialization of a Query /// for a given set of segments. @@ -27,4 +65,37 @@ pub trait Weight: Send + Sync + 'static { Ok(scorer.count_including_deleted()) } } + + /// Iterates through all of the document matched by the DocSet + /// `DocSet` and push the scored documents to the collector. + fn for_each( + &self, + reader: &SegmentReader, + callback: &mut dyn FnMut(DocId, Score), + ) -> crate::Result<()> { + let mut scorer = self.scorer(reader, 1.0f32)?; + for_each_scorer(scorer.as_mut(), callback); + Ok(()) + } + + /// Calls `callback` with all of the `(doc, score)` for which score + /// is exceeding a given threshold. + /// + /// This method is useful for the TopDocs collector. + /// For all docsets, the blanket implementation has the benefit + /// of prefiltering (doc, score) pairs, avoiding the + /// virtual dispatch cost. + /// + /// More importantly, it makes it possible for scorers to implement + /// important optimization (e.g. BlockWAND for union). + fn for_each_pruning( + &self, + threshold: f32, + reader: &SegmentReader, + callback: &mut dyn FnMut(DocId, Score) -> Score, + ) -> crate::Result<()> { + let mut scorer = self.scorer(reader, 1.0f32)?; + for_each_pruning_scorer(scorer.as_mut(), threshold, callback); + Ok(()) + } }