diff --git a/src/docset.rs b/src/docset.rs index 96ea30709..7de138da6 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -87,6 +87,17 @@ pub trait DocSet: Send { /// length of the docset. fn size_hint(&self) -> u32; + /// Returns a best-effort hint of the cost to consume the entire docset. + /// + /// Consuming means calling advance until [`TERMINATED`] is returned. + /// The cost should be relative to the cost of driving a Term query, + /// which would be the number of documents in the DocSet. + /// + /// By default this returns `size_hint()`. + fn cost(&self) -> u64 { + self.size_hint() as u64 + } + /// Returns the number documents matching. /// Calling this method consumes the `DocSet`. fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 { @@ -134,6 +145,10 @@ impl DocSet for &mut dyn DocSet { (**self).size_hint() } + fn cost(&self) -> u64 { + (**self).cost() + } + fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 { (**self).count(alive_bitset) } @@ -169,6 +184,11 @@ impl DocSet for Box { unboxed.size_hint() } + fn cost(&self) -> u64 { + let unboxed: &TDocSet = self.borrow(); + unboxed.cost() + } + fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.count(alive_bitset) diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 4dee06aad..efc0e069d 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -667,12 +667,15 @@ mod bench { .read_postings(&TERM_D, IndexRecordOption::Basic) .unwrap() .unwrap(); - let mut intersection = Intersection::new(vec![ - segment_postings_a, - segment_postings_b, - segment_postings_c, - segment_postings_d, - ]); + let mut intersection = Intersection::new( + vec![ + segment_postings_a, + segment_postings_b, + segment_postings_c, + segment_postings_d, + ], + reader.searcher().num_docs() as u32, + ); while intersection.advance() != TERMINATED {} }); } diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand.rs index 
c5f3ac36d..c6710b09c 100644 --- a/src/query/boolean_query/block_wand.rs +++ b/src/query/boolean_query/block_wand.rs @@ -367,10 +367,14 @@ mod tests { checkpoints } - fn compute_checkpoints_manual(term_scorers: Vec, n: usize) -> Vec<(DocId, Score)> { + fn compute_checkpoints_manual( + term_scorers: Vec, + n: usize, + max_doc: u32, + ) -> Vec<(DocId, Score)> { let mut heap: BinaryHeap = BinaryHeap::with_capacity(n); let mut checkpoints: Vec<(DocId, Score)> = Vec::new(); - let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default); + let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default, max_doc); let mut limit = Score::MIN; loop { @@ -478,7 +482,8 @@ mod tests { for top_k in 1..4 { let checkpoints_for_each_pruning = compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k); - let checkpoints_manual = compute_checkpoints_manual(term_scorers.clone(), top_k); + let checkpoints_manual = + compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000); assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len()); for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning .iter() diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 96642ea59..1b72c44d2 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -1,3 +1,4 @@ +use core::num; use std::collections::HashMap; use crate::docset::COLLECT_BLOCK_BUFFER_LEN; @@ -39,9 +40,11 @@ where )) } +/// num_docs is the number of documents in the segment. 
fn scorer_union( scorers: Vec>, score_combiner_fn: impl Fn() -> TScoreCombiner, + num_docs: u32, ) -> SpecializedScorer where TScoreCombiner: ScoreCombiner, @@ -68,6 +71,7 @@ where return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build( scorers, score_combiner_fn, + num_docs, ))); } } @@ -75,16 +79,19 @@ where SpecializedScorer::Other(Box::new(BufferedUnionScorer::build( scorers, score_combiner_fn, + num_docs, ))) } fn into_box_scorer( scorer: SpecializedScorer, score_combiner_fn: impl Fn() -> TScoreCombiner, + num_docs: u32, ) -> Box { match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let union_scorer = BufferedUnionScorer::build(term_scorers, score_combiner_fn); + let union_scorer = + BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs); Box::new(union_scorer) } SpecializedScorer::Other(scorer) => scorer, @@ -151,6 +158,7 @@ impl BooleanWeight { boost: Score, score_combiner_fn: impl Fn() -> TComplexScoreCombiner, ) -> crate::Result { + let num_docs = reader.num_docs(); let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?; // Indicate how should clauses are combined with other clauses. enum CombinationMethod { @@ -167,11 +175,16 @@ impl BooleanWeight { return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); } match self.minimum_number_should_match { - 0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)), - 1 => { - let scorer_union = scorer_union(should_scorers, &score_combiner_fn); - CombinationMethod::Required(scorer_union) - } + 0 => CombinationMethod::Optional(scorer_union( + should_scorers, + &score_combiner_fn, + num_docs, + )), + 1 => CombinationMethod::Required(scorer_union( + should_scorers, + &score_combiner_fn, + num_docs, + )), n if num_of_should_scorers == n => { // When num_of_should_scorers equals the number of should clauses, // they are no different from must clauses. 
@@ -200,21 +213,21 @@ impl BooleanWeight { }; let exclude_scorer_opt: Option> = per_occur_scorers .remove(&Occur::MustNot) - .map(|scorers| scorer_union(scorers, DoNothingCombiner::default)) + .map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs)) .map(|specialized_scorer: SpecializedScorer| { - into_box_scorer(specialized_scorer, DoNothingCombiner::default) + into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs) }); let positive_scorer = match (should_opt, must_scorers) { (CombinationMethod::Ignored, Some(must_scorers)) => { - SpecializedScorer::Other(intersect_scorers(must_scorers)) + SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs)) } (CombinationMethod::Optional(should_scorer), Some(must_scorers)) => { - let must_scorer = intersect_scorers(must_scorers); + let must_scorer = intersect_scorers(must_scorers, num_docs); if self.scoring_enabled { SpecializedScorer::Other(Box::new( RequiredOptionalScorer::<_, _, TScoreCombiner>::new( must_scorer, - into_box_scorer(should_scorer, &score_combiner_fn), + into_box_scorer(should_scorer, &score_combiner_fn, num_docs), ), )) } else { @@ -222,8 +235,8 @@ impl BooleanWeight { } } (CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => { - must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn)); - SpecializedScorer::Other(intersect_scorers(must_scorers)) + must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs)); + SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs)) } (CombinationMethod::Ignored, None) => { return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))) @@ -233,7 +246,8 @@ impl BooleanWeight { (CombinationMethod::Optional(should_scorer), None) => should_scorer, }; if let Some(exclude_scorer) = exclude_scorer_opt { - let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn); + let positive_scorer_boxed = + into_box_scorer(positive_scorer, &score_combiner_fn, 
num_docs); Ok(SpecializedScorer::Other(Box::new(Exclude::new( positive_scorer_boxed, exclude_scorer, @@ -246,6 +260,7 @@ impl BooleanWeight { impl Weight for BooleanWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { + let num_docs = reader.num_docs(); if self.weights.is_empty() { Ok(Box::new(EmptyScorer)) } else if self.weights.len() == 1 { @@ -258,12 +273,12 @@ impl Weight for BooleanWeight Weight for BooleanWeight { - let mut union_scorer = - BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn); + let mut union_scorer = BufferedUnionScorer::build( + term_scorers, + &self.score_combiner_fn, + reader.num_docs(), + ); for_each_scorer(&mut union_scorer, callback); } SpecializedScorer::Other(mut scorer) => { @@ -317,8 +335,11 @@ impl Weight for BooleanWeight { - let mut union_scorer = - BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn); + let mut union_scorer = BufferedUnionScorer::build( + term_scorers, + &self.score_combiner_fn, + reader.num_docs(), + ); for_each_docset_buffered(&mut union_scorer, &mut buffer, callback); } SpecializedScorer::Other(mut scorer) => { diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index 4d2352d4d..06678287f 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -117,6 +117,10 @@ impl DocSet for BoostScorer { self.underlying.size_hint() } + fn cost(&self) -> u64 { + self.underlying.cost() + } + fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 { self.underlying.count(alive_bitset) } diff --git a/src/query/const_score_query.rs b/src/query/const_score_query.rs index 8f27b8285..570c7feca 100644 --- a/src/query/const_score_query.rs +++ b/src/query/const_score_query.rs @@ -130,6 +130,10 @@ impl DocSet for ConstScorer { fn size_hint(&self) -> u32 { self.docset.size_hint() } + + fn cost(&self) -> u64 { + self.docset.cost() + } } impl Scorer for ConstScorer { diff --git a/src/query/disjunction.rs b/src/query/disjunction.rs index 
81723af9a..910e207df 100644 --- a/src/query/disjunction.rs +++ b/src/query/disjunction.rs @@ -70,6 +70,10 @@ impl DocSet for ScorerWrapper { fn size_hint(&self) -> u32 { self.scorer.size_hint() } + + fn cost(&self) -> u64 { + self.scorer.cost() + } } impl Disjunction { @@ -146,6 +150,14 @@ impl DocSet .max() .unwrap_or(0u32) } + + fn cost(&self) -> u64 { + self.chains + .iter() + .map(|docset| docset.cost()) + .max() + .unwrap_or(0u64) + } } impl Scorer diff --git a/src/query/intersection.rs b/src/query/intersection.rs index 61c8ca8f3..10e257c43 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -1,4 +1,5 @@ use crate::docset::{DocSet, TERMINATED}; +use crate::query::size_hint::estimate_intersection; use crate::query::term_query::TermScorer; use crate::query::{EmptyScorer, Scorer}; use crate::{DocId, Score}; @@ -11,14 +12,18 @@ use crate::{DocId, Score}; /// For better performance, the function uses a /// specialized implementation if the two /// shortest scorers are `TermScorer`s. -pub fn intersect_scorers(mut scorers: Vec>) -> Box { +pub fn intersect_scorers( + mut scorers: Vec>, + num_docs_segment: u32, +) -> Box { if scorers.is_empty() { return Box::new(EmptyScorer); } if scorers.len() == 1 { return scorers.pop().unwrap(); } - scorers.sort_by_key(|scorer| scorer.size_hint()); + // Order by estimated cost to drive each scorer. 
+ scorers.sort_by_key(|scorer| scorer.cost()); let doc = go_to_first_doc(&mut scorers[..]); if doc == TERMINATED { return Box::new(EmptyScorer); @@ -34,12 +39,14 @@ pub fn intersect_scorers(mut scorers: Vec>) -> Box { left: *(left.downcast::().map_err(|_| ()).unwrap()), right: *(right.downcast::().map_err(|_| ()).unwrap()), others: scorers, + num_docs: num_docs_segment, }); } Box::new(Intersection { left, right, others: scorers, + num_docs: num_docs_segment, }) } @@ -48,6 +55,7 @@ pub struct Intersection> left: TDocSet, right: TDocSet, others: Vec, + num_docs: u32, } fn go_to_first_doc(docsets: &mut [TDocSet]) -> DocId { @@ -66,10 +74,11 @@ fn go_to_first_doc(docsets: &mut [TDocSet]) -> DocId { } impl Intersection { - pub(crate) fn new(mut docsets: Vec) -> Intersection { + /// num_docs is the number of documents in the segment. + pub(crate) fn new(mut docsets: Vec, num_docs: u32) -> Intersection { let num_docsets = docsets.len(); assert!(num_docsets >= 2); - docsets.sort_by_key(|docset| docset.size_hint()); + docsets.sort_by_key(|docset| docset.cost()); go_to_first_doc(&mut docsets); let left = docsets.remove(0); let right = docsets.remove(0); @@ -77,6 +86,7 @@ impl Intersection { left, right, others: docsets, + num_docs, } } } @@ -141,7 +151,19 @@ impl DocSet for Intersection u32 { - self.left.size_hint() + estimate_intersection( + [self.left.size_hint(), self.right.size_hint()] + .into_iter() + .chain(self.others.iter().map(DocSet::size_hint)), + self.num_docs, + ) + } + + fn cost(&self) -> u64 { + // What's the best way to compute the cost of an intersection? + // For now we take the cost of the docset driver, which is the first docset. + // If there are docsets that are bad at skipping, they should also influence the cost. 
+ self.left.cost() } } @@ -169,7 +191,7 @@ mod tests { { let left = VecDocSet::from(vec![1, 3, 9]); let right = VecDocSet::from(vec![3, 4, 9, 18]); - let mut intersection = Intersection::new(vec![left, right]); + let mut intersection = Intersection::new(vec![left, right], 10); assert_eq!(intersection.doc(), 3); assert_eq!(intersection.advance(), 9); assert_eq!(intersection.doc(), 9); @@ -179,7 +201,7 @@ mod tests { let a = VecDocSet::from(vec![1, 3, 9]); let b = VecDocSet::from(vec![3, 4, 9, 18]); let c = VecDocSet::from(vec![1, 5, 9, 111]); - let mut intersection = Intersection::new(vec![a, b, c]); + let mut intersection = Intersection::new(vec![a, b, c], 10); assert_eq!(intersection.doc(), 9); assert_eq!(intersection.advance(), TERMINATED); } @@ -189,7 +211,7 @@ mod tests { fn test_intersection_zero() { let left = VecDocSet::from(vec![0]); let right = VecDocSet::from(vec![0]); - let mut intersection = Intersection::new(vec![left, right]); + let mut intersection = Intersection::new(vec![left, right], 10); assert_eq!(intersection.doc(), 0); assert_eq!(intersection.advance(), TERMINATED); } @@ -198,7 +220,7 @@ mod tests { fn test_intersection_skip() { let left = VecDocSet::from(vec![0, 1, 2, 4]); let right = VecDocSet::from(vec![2, 5]); - let mut intersection = Intersection::new(vec![left, right]); + let mut intersection = Intersection::new(vec![left, right], 10); assert_eq!(intersection.seek(2), 2); assert_eq!(intersection.doc(), 2); } @@ -209,7 +231,7 @@ mod tests { || { let left = VecDocSet::from(vec![4]); let right = VecDocSet::from(vec![2, 5]); - Box::new(Intersection::new(vec![left, right])) + Box::new(Intersection::new(vec![left, right], 10)) }, vec![0, 2, 4, 5, 6], ); @@ -219,19 +241,22 @@ mod tests { let mut right = VecDocSet::from(vec![2, 5, 10]); left.advance(); right.advance(); - Box::new(Intersection::new(vec![left, right])) + Box::new(Intersection::new(vec![left, right], 10)) }, vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11], ); test_skip_against_unoptimized( || 
{ - Box::new(Intersection::new(vec![ - VecDocSet::from(vec![1, 4, 5, 6]), - VecDocSet::from(vec![1, 2, 5, 6]), - VecDocSet::from(vec![1, 4, 5, 6]), - VecDocSet::from(vec![1, 5, 6]), - VecDocSet::from(vec![2, 4, 5, 7, 8]), - ])) + Box::new(Intersection::new( + vec![ + VecDocSet::from(vec![1, 4, 5, 6]), + VecDocSet::from(vec![1, 2, 5, 6]), + VecDocSet::from(vec![1, 4, 5, 6]), + VecDocSet::from(vec![1, 5, 6]), + VecDocSet::from(vec![2, 4, 5, 7, 8]), + ], + 10, + )) }, vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11], ); @@ -242,7 +267,7 @@ mod tests { let a = VecDocSet::from(vec![1, 3]); let b = VecDocSet::from(vec![1, 4]); let c = VecDocSet::from(vec![3, 9]); - let intersection = Intersection::new(vec![a, b, c]); + let intersection = Intersection::new(vec![a, b, c], 10); assert_eq!(intersection.doc(), TERMINATED); } } diff --git a/src/query/mod.rs b/src/query/mod.rs index 2ba2f4def..d609a0402 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -23,6 +23,7 @@ mod regex_query; mod reqopt_scorer; mod scorer; mod set_query; +mod size_hint; mod term_query; mod union; mod weight; diff --git a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs index 09cf6c5bd..14933f3ae 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs @@ -200,6 +200,10 @@ impl DocSet for PhrasePrefixScorer { fn size_hint(&self) -> u32 { self.phrase_scorer.size_hint() } + + fn cost(&self) -> u64 { + self.phrase_scorer.cost() + } } impl Scorer for PhrasePrefixScorer { diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 5c67f4e27..12a94dce3 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -368,6 +368,7 @@ impl PhraseScorer { slop: u32, offset: usize, ) -> PhraseScorer { + let num_docs = fieldnorm_reader.num_docs(); let max_offset = term_postings_with_offset .iter() .map(|&(offset, 
_)| offset) @@ -382,7 +383,7 @@ impl PhraseScorer { }) .collect::>(); let mut scorer = PhraseScorer { - intersection_docset: Intersection::new(postings_with_offsets), + intersection_docset: Intersection::new(postings_with_offsets, num_docs), num_terms: num_docsets, left_positions: Vec::with_capacity(100), right_positions: Vec::with_capacity(100), @@ -535,6 +536,15 @@ impl DocSet for PhraseScorer { fn size_hint(&self) -> u32 { self.intersection_docset.size_hint() } + + /// Returns a best-effort hint of the + /// cost to drive the docset. + fn cost(&self) -> u64 { + // Evaluating phrase matches is generally more expensive than simple term matches, + // as it requires loading and comparing positions. Use a conservative multiplier + // based on the number of terms. + self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64 + } } impl Scorer for PhraseScorer { diff --git a/src/query/range_query/fast_field_range_doc_set.rs b/src/query/range_query/fast_field_range_doc_set.rs index 779269069..dd4b8fe68 100644 --- a/src/query/range_query/fast_field_range_doc_set.rs +++ b/src/query/range_query/fast_field_range_doc_set.rs @@ -176,6 +176,14 @@ impl DocSet for RangeDocSe fn size_hint(&self) -> u32 { self.column.num_docs() } + + /// Returns a best-effort hint of the + /// cost to drive the docset. + fn cost(&self) -> u64 { + // Advancing the docset is relatively expensive since it scans the column. + // Keep cost relative to a term query driver; use num_docs as baseline. 
/// Computes the estimated number of documents in the intersection of multiple docsets
/// given their sizes.
///
/// The estimate starts from the size of the first docset and multiplies it, for every further
/// docset, by that docset's hit ratio `size / max_docs` and by a co-location factor.
/// The result is never larger than the smallest docset, nor larger than `max_docs`.
///
/// # Arguments
/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
///   segment.
///
/// # Returns
/// The estimated number of documents in the intersection. Returns `0` if `max_docs` is `0` or
/// if no docset size is provided.
pub fn estimate_intersection<I>(mut docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
    // Guard against the division by zero below.
    if max_docs == 0u32 {
        return 0u32;
    }
    let max_docs_f64 = max_docs as f64;

    // Terms tend to be not really randomly distributed.
    // This factor is used to adjust the estimate upwards.
    let mut co_loc_factor: f64 = 1.3;

    let mut intersection_estimate = match docset_sizes.next() {
        Some(first_size) => first_size as f64,
        None => return 0, // No docsets provided, so return 0.
    };

    let mut smallest_docset_size = intersection_estimate;
    // Assuming a (nearly) random distribution of terms, every additional docset keeps only the
    // fraction `size / max_docs` of the documents that matched so far.
    for size in docset_sizes {
        // Diminish the co-location factor for each additional set, or we will overestimate.
        co_loc_factor = (co_loc_factor - 0.1).max(1.0);
        intersection_estimate *= (size as f64 / max_docs_f64) * co_loc_factor;
        smallest_docset_size = smallest_docset_size.min(size as f64);
    }

    // An intersection can hold neither more documents than its smallest member nor more than
    // `max_docs`; the latter matters when the provided sizes exceed `max_docs`.
    intersection_estimate
        .round()
        .min(smallest_docset_size)
        .min(max_docs_f64) as u32
}

/// Computes the estimated number of documents in the union of multiple docsets
/// given their sizes.
///
/// The estimate is `max_docs` times the probability that a document is in at least one set,
/// where each set's hit ratio `size / max_docs` is damped by a co-location factor.
/// The result is never smaller than the largest docset (capped at `max_docs`) and never
/// larger than `max_docs`.
///
/// # Arguments
/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
///   segment.
///
/// # Returns
/// The estimated number of documents in the union. Returns `0` if `max_docs` is `0` or if no
/// docset size is provided.
pub fn estimate_union<I>(docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
    // Guard against the division by zero below (mirrors `estimate_intersection`).
    if max_docs == 0u32 {
        return 0u32;
    }
    let max_docs_f64 = max_docs as f64;

    // Terms tend to be not really randomly distributed.
    // This factor is used to adjust the estimate.
    // Unlike intersection, the co-location reduces the estimate.
    let co_loc_factor: f64 = 0.8;

    // The approach for union is to compute the probability of a document not being in any of the
    // sets.
    let mut not_in_any_set_prob: f64 = 1.0;
    let mut largest_docset_size = 0u32;

    // Assuming random distribution of terms, the probability of a document being in the
    // union is the complement of the probability of it not being in any of the sets.
    for size in docset_sizes {
        // Cap at 1.0: a size above `max_docs / co_loc_factor` would otherwise make
        // `1.0 - prob_in_set` negative and flip the sign of the running product.
        let prob_in_set = ((size as f64 / max_docs_f64) * co_loc_factor).min(1.0);
        not_in_any_set_prob *= 1.0 - prob_in_set;
        largest_docset_size = largest_docset_size.max(size);
    }

    let union_estimate = (max_docs_f64 * (1.0 - not_in_any_set_prob)).round();
    // A union holds at least as many documents as its largest member.
    let lower_bound = largest_docset_size.min(max_docs) as f64;

    union_estimate.max(lower_bound).min(max_docs_f64) as u32
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_intersection_small1() {
        let docset_sizes = &[500, 1000];
        let n = 10_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        assert_eq!(result, 60);
    }

    #[test]
    fn test_estimate_intersection_small2() {
        let docset_sizes = &[500, 1000, 1500];
        let n = 10_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        assert_eq!(result, 10);
    }

    #[test]
    fn test_estimate_intersection_large_values() {
        let docset_sizes = &[100_000, 50_000, 30_000];
        let n = 1_000_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        assert_eq!(result, 198);
    }

    #[test]
    fn test_estimate_union_small() {
        let docset_sizes = &[500, 1000, 1500];
        let n = 10000;
        let result = estimate_union(docset_sizes.iter().copied(), n);
        assert_eq!(result, 2228);
    }

    #[test]
    fn test_estimate_union_large_values() {
        let docset_sizes = &[100000, 50000, 30000];
        let n = 1000000;
        let result = estimate_union(docset_sizes.iter().copied(), n);
        assert_eq!(result, 137997);
    }

    #[test]
    fn test_estimate_intersection_large() {
        let docset_sizes: Vec<_> = (0..10).map(|_| 4_000_000).collect();
        let n = 5_000_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        // Check that it doesn't overflow and returns a reasonable result
        assert_eq!(result, 708_670);
    }

    #[test]
    fn test_estimate_intersection_overflow_safety() {
        let docset_sizes: Vec<_> = (0..100).map(|_| 4_000_000).collect();
        let n = 5_000_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        // Check that it doesn't overflow and returns a reasonable result
        assert_eq!(result, 0);
    }

    #[test]
    fn test_estimate_union_overflow_safety() {
        let docset_sizes: Vec<_> = (0..100).map(|_| 1_000_000).collect();
        let n = 20_000_000;
        let result = estimate_union(docset_sizes.iter().copied(), n);
        // Check that it doesn't overflow and returns a reasonable result
        assert_eq!(result, 19_662_594);
    }

    #[test]
    fn test_estimate_zero_max_docs() {
        let docset_sizes = &[5, 5];
        assert_eq!(estimate_intersection(docset_sizes.iter().copied(), 0), 0);
        assert_eq!(estimate_union(docset_sizes.iter().copied(), 0), 0);
    }

    #[test]
    fn test_estimate_sizes_larger_than_max_docs() {
        // Sizes may exceed `max_docs`; the estimates must stay within `max_docs`.
        let docset_sizes = &[200, 200];
        assert_eq!(estimate_intersection(docset_sizes.iter().copied(), 100), 100);
        assert_eq!(estimate_union(docset_sizes.iter().copied(), 100), 100);
    }

    #[test]
    fn test_estimate_union_not_below_largest_set() {
        let docset_sizes = &[1000];
        assert_eq!(estimate_union(docset_sizes.iter().copied(), 10_000), 1000);
    }
}
pub(crate) fn build( docsets: Vec, score_combiner_fn: impl FnOnce() -> TScoreCombiner, + num_docs: u32, ) -> BufferedUnionScorer { let non_empty_docsets: Vec = docsets .into_iter() @@ -94,6 +99,7 @@ impl BufferedUnionScorer u32 { - self.docsets - .iter() - .map(|docset| docset.size_hint()) - .max() - .unwrap_or(0u32) + estimate_union(self.docsets.iter().map(DocSet::size_hint), self.num_docs) + } + + fn cost(&self) -> u64 { + self.docsets.iter().map(|docset| docset.cost()).sum() } fn count_including_deleted(&mut self) -> u32 { diff --git a/src/query/union/mod.rs b/src/query/union/mod.rs index 84153e272..539c6c387 100644 --- a/src/query/union/mod.rs +++ b/src/query/union/mod.rs @@ -27,11 +27,17 @@ mod tests { docs_list.iter().cloned().map(VecDocSet::from) } fn union_from_docs_list(docs_list: &[Vec]) -> Box { + let max_doc = docs_list + .iter() + .flat_map(|docs| docs.iter().copied()) + .max() + .unwrap_or(0); Box::new(BufferedUnionScorer::build( vec_doc_set_from_docs_list(docs_list) .map(|docset| ConstScorer::new(docset, 1.0)) .collect::>>(), DoNothingCombiner::default, + max_doc, )) } @@ -273,6 +279,7 @@ mod bench { .map(|docset| ConstScorer::new(docset, 1.0)) .collect::>(), DoNothingCombiner::default, + 100_000, ); while v.doc() != TERMINATED { v.advance(); @@ -294,6 +301,7 @@ mod bench { .map(|docset| ConstScorer::new(docset, 1.0)) .collect::>(), DoNothingCombiner::default, + 100_000, ); while v.doc() != TERMINATED { v.advance(); diff --git a/src/query/union/simple_union.rs b/src/query/union/simple_union.rs index 041d4c90e..61cbb94b6 100644 --- a/src/query/union/simple_union.rs +++ b/src/query/union/simple_union.rs @@ -99,6 +99,10 @@ impl DocSet for SimpleUnion { .unwrap_or(0u32) } + fn cost(&self) -> u64 { + self.docsets.iter().map(|docset| docset.cost()).sum() + } + fn count_including_deleted(&mut self) -> u32 { if self.doc == TERMINATED { return 0u32;