Add DocSet::cost() (#2707)

* query: add DocSet cost hint and use it for intersection ordering

- Add DocSet::cost()
- Use cost() instead of size_hint() to order scorers in intersect_scorers

This isolates the cost-related changes from PR #2538, without pulling in its new seek APIs.

* add comments

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
PSeitz
2025-10-13 16:25:49 +02:00
committed by GitHub
parent 270ca5123c
commit 33835b6a01
17 changed files with 334 additions and 54 deletions

View File

@@ -87,6 +87,17 @@ pub trait DocSet: Send {
     /// length of the docset.
     fn size_hint(&self) -> u32;
 
+    /// Returns a best-effort hint of the cost to consume the entire docset.
+    ///
+    /// Consuming means calling advance until [`TERMINATED`] is returned.
+    /// The cost should be relative to the cost of driving a Term query,
+    /// which would be the number of documents in the DocSet.
+    ///
+    /// By default this returns `size_hint()`.
+    fn cost(&self) -> u64 {
+        self.size_hint() as u64
+    }
+
     /// Returns the number documents matching.
     /// Calling this method consumes the `DocSet`.
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
@@ -134,6 +145,10 @@ impl DocSet for &mut dyn DocSet {
         (**self).size_hint()
     }
 
+    fn cost(&self) -> u64 {
+        (**self).cost()
+    }
+
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
         (**self).count(alive_bitset)
     }
@@ -169,6 +184,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
         unboxed.size_hint()
     }
 
+    fn cost(&self) -> u64 {
+        let unboxed: &TDocSet = self.borrow();
+        unboxed.cost()
+    }
+
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
         let unboxed: &mut TDocSet = self.borrow_mut();
         unboxed.count(alive_bitset)
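
To illustrate the intent of the new method, here is a small standalone sketch (hypothetical types, not part of this commit): a docset that matches few documents but is expensive to advance can report a cost well above its size_hint, so that a cheaper driver gets preferred.

// Standalone sketch only; `MiniDocSet` and `ColumnScan` are made-up names that
// mirror the pattern above: `cost()` defaults to `size_hint()`, and docsets that
// are expensive to advance override it upward.
trait MiniDocSet {
    /// Approximate number of matching documents.
    fn size_hint(&self) -> u32;
    /// Relative cost of consuming the whole docset; defaults to `size_hint()`.
    fn cost(&self) -> u64 {
        self.size_hint() as u64
    }
}

/// A docset that matches few documents but must scan a whole column to find them.
struct ColumnScan {
    num_hits: u32,
    num_docs_in_segment: u32,
}

impl MiniDocSet for ColumnScan {
    fn size_hint(&self) -> u32 {
        self.num_hits
    }
    // Advancing scans the column, so the cost scales with the segment size,
    // not with the number of hits.
    fn cost(&self) -> u64 {
        self.num_docs_in_segment as u64
    }
}

fn main() {
    let scan = ColumnScan { num_hits: 100, num_docs_in_segment: 1_000_000 };
    assert_eq!(scan.size_hint(), 100);
    assert_eq!(scan.cost(), 1_000_000);
}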

View File

@@ -667,12 +667,15 @@ mod bench {
             .read_postings(&TERM_D, IndexRecordOption::Basic)
             .unwrap()
             .unwrap();
-        let mut intersection = Intersection::new(vec![
-            segment_postings_a,
-            segment_postings_b,
-            segment_postings_c,
-            segment_postings_d,
-        ]);
+        let mut intersection = Intersection::new(
+            vec![
+                segment_postings_a,
+                segment_postings_b,
+                segment_postings_c,
+                segment_postings_d,
+            ],
+            reader.searcher().num_docs() as u32,
+        );
         while intersection.advance() != TERMINATED {}
     });
 }

View File

@@ -367,10 +367,14 @@ mod tests {
         checkpoints
     }
 
-    fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
+    fn compute_checkpoints_manual(
+        term_scorers: Vec<TermScorer>,
+        n: usize,
+        max_doc: u32,
+    ) -> Vec<(DocId, Score)> {
         let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
         let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
-        let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default);
+        let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default, max_doc);
 
         let mut limit = Score::MIN;
         loop {
@@ -478,7 +482,8 @@ mod tests {
         for top_k in 1..4 {
             let checkpoints_for_each_pruning =
                 compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k);
-            let checkpoints_manual = compute_checkpoints_manual(term_scorers.clone(), top_k);
+            let checkpoints_manual =
+                compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000);
             assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len());
             for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning
                 .iter()

View File

@@ -1,3 +1,4 @@
+use core::num;
 use std::collections::HashMap;
 
 use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
@@ -39,9 +40,11 @@ where
     ))
 }
 
+/// num_docs is the number of documents in the segment.
 fn scorer_union<TScoreCombiner>(
     scorers: Vec<Box<dyn Scorer>>,
     score_combiner_fn: impl Fn() -> TScoreCombiner,
+    num_docs: u32,
 ) -> SpecializedScorer
 where
     TScoreCombiner: ScoreCombiner,
@@ -68,6 +71,7 @@ where
             return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
                 scorers,
                 score_combiner_fn,
+                num_docs,
             )));
         }
     }
@@ -75,16 +79,19 @@ where
     SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
         scorers,
         score_combiner_fn,
+        num_docs,
     )))
 }
 
 fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
     scorer: SpecializedScorer,
     score_combiner_fn: impl Fn() -> TScoreCombiner,
+    num_docs: u32,
 ) -> Box<dyn Scorer> {
     match scorer {
         SpecializedScorer::TermUnion(term_scorers) => {
-            let union_scorer = BufferedUnionScorer::build(term_scorers, score_combiner_fn);
+            let union_scorer =
+                BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
             Box::new(union_scorer)
         }
         SpecializedScorer::Other(scorer) => scorer,
@@ -151,6 +158,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
         boost: Score,
         score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
     ) -> crate::Result<SpecializedScorer> {
+        let num_docs = reader.num_docs();
         let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
         // Indicate how should clauses are combined with other clauses.
         enum CombinationMethod {
@@ -167,11 +175,16 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
                 return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
             }
             match self.minimum_number_should_match {
-                0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
-                1 => {
-                    let scorer_union = scorer_union(should_scorers, &score_combiner_fn);
-                    CombinationMethod::Required(scorer_union)
-                }
+                0 => CombinationMethod::Optional(scorer_union(
+                    should_scorers,
+                    &score_combiner_fn,
+                    num_docs,
+                )),
+                1 => CombinationMethod::Required(scorer_union(
+                    should_scorers,
+                    &score_combiner_fn,
+                    num_docs,
+                )),
                 n if num_of_should_scorers == n => {
                     // When num_of_should_scorers equals the number of should clauses,
                     // they are no different from must clauses.
@@ -200,21 +213,21 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
         };
         let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
             .remove(&Occur::MustNot)
-            .map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
+            .map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs))
             .map(|specialized_scorer: SpecializedScorer| {
-                into_box_scorer(specialized_scorer, DoNothingCombiner::default)
+                into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
             });
 
         let positive_scorer = match (should_opt, must_scorers) {
             (CombinationMethod::Ignored, Some(must_scorers)) => {
-                SpecializedScorer::Other(intersect_scorers(must_scorers))
+                SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
             }
             (CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
-                let must_scorer = intersect_scorers(must_scorers);
+                let must_scorer = intersect_scorers(must_scorers, num_docs);
                 if self.scoring_enabled {
                     SpecializedScorer::Other(Box::new(
                         RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
                             must_scorer,
-                            into_box_scorer(should_scorer, &score_combiner_fn),
+                            into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
                         ),
                     ))
                 } else {
@@ -222,8 +235,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
                 }
             }
             (CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
-                must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn));
-                SpecializedScorer::Other(intersect_scorers(must_scorers))
+                must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
+                SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
             }
             (CombinationMethod::Ignored, None) => {
                 return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
@@ -233,7 +246,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
             (CombinationMethod::Optional(should_scorer), None) => should_scorer,
         };
         if let Some(exclude_scorer) = exclude_scorer_opt {
-            let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
+            let positive_scorer_boxed =
+                into_box_scorer(positive_scorer, &score_combiner_fn, num_docs);
             Ok(SpecializedScorer::Other(Box::new(Exclude::new(
                 positive_scorer_boxed,
                 exclude_scorer,
@@ -246,6 +260,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
 impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
     fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
+        let num_docs = reader.num_docs();
         if self.weights.is_empty() {
             Ok(Box::new(EmptyScorer))
         } else if self.weights.len() == 1 {
@@ -258,12 +273,12 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
         } else if self.scoring_enabled {
             self.complex_scorer(reader, boost, &self.score_combiner_fn)
                 .map(|specialized_scorer| {
-                    into_box_scorer(specialized_scorer, &self.score_combiner_fn)
+                    into_box_scorer(specialized_scorer, &self.score_combiner_fn, num_docs)
                 })
         } else {
             self.complex_scorer(reader, boost, DoNothingCombiner::default)
                 .map(|specialized_scorer| {
-                    into_box_scorer(specialized_scorer, DoNothingCombiner::default)
+                    into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
                 })
         }
     }
@@ -296,8 +311,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
         let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
         match scorer {
             SpecializedScorer::TermUnion(term_scorers) => {
-                let mut union_scorer =
-                    BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
+                let mut union_scorer = BufferedUnionScorer::build(
+                    term_scorers,
+                    &self.score_combiner_fn,
+                    reader.num_docs(),
+                );
                 for_each_scorer(&mut union_scorer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {
@@ -317,8 +335,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
         match scorer {
             SpecializedScorer::TermUnion(term_scorers) => {
-                let mut union_scorer =
-                    BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
+                let mut union_scorer = BufferedUnionScorer::build(
+                    term_scorers,
+                    &self.score_combiner_fn,
+                    reader.num_docs(),
+                );
                 for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {

View File

@@ -117,6 +117,10 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
         self.underlying.size_hint()
     }
 
+    fn cost(&self) -> u64 {
+        self.underlying.cost()
+    }
+
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
         self.underlying.count(alive_bitset)
     }

View File

@@ -130,6 +130,10 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
     fn size_hint(&self) -> u32 {
         self.docset.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.docset.cost()
+    }
 }
 
 impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {

View File

@@ -70,6 +70,10 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
     fn size_hint(&self) -> u32 {
         self.scorer.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.scorer.cost()
+    }
 }
 
 impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
@@ -146,6 +150,14 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
             .max()
             .unwrap_or(0u32)
     }
+
+    fn cost(&self) -> u64 {
+        self.chains
+            .iter()
+            .map(|docset| docset.cost())
+            .max()
+            .unwrap_or(0u64)
+    }
 }
 
 impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer

View File

@@ -1,4 +1,5 @@
 use crate::docset::{DocSet, TERMINATED};
+use crate::query::size_hint::estimate_intersection;
 use crate::query::term_query::TermScorer;
 use crate::query::{EmptyScorer, Scorer};
 use crate::{DocId, Score};
@@ -11,14 +12,18 @@ use crate::{DocId, Score};
 /// For better performance, the function uses a
 /// specialized implementation if the two
 /// shortest scorers are `TermScorer`s.
-pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
+pub fn intersect_scorers(
+    mut scorers: Vec<Box<dyn Scorer>>,
+    num_docs_segment: u32,
+) -> Box<dyn Scorer> {
     if scorers.is_empty() {
         return Box::new(EmptyScorer);
     }
     if scorers.len() == 1 {
         return scorers.pop().unwrap();
     }
-    scorers.sort_by_key(|scorer| scorer.size_hint());
+    // Order by estimated cost to drive each scorer.
+    scorers.sort_by_key(|scorer| scorer.cost());
     let doc = go_to_first_doc(&mut scorers[..]);
     if doc == TERMINATED {
         return Box::new(EmptyScorer);
@@ -34,12 +39,14 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
             left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
             right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
             others: scorers,
+            num_docs: num_docs_segment,
         });
     }
     Box::new(Intersection {
         left,
         right,
         others: scorers,
+        num_docs: num_docs_segment,
     })
 }
@@ -48,6 +55,7 @@ pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>>
     left: TDocSet,
     right: TDocSet,
     others: Vec<TOtherDocSet>,
+    num_docs: u32,
 }
 
 fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
@@ -66,10 +74,11 @@ fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
 }
 
 impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
-    pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> {
+    /// num_docs is the number of documents in the segment.
+    pub(crate) fn new(mut docsets: Vec<TDocSet>, num_docs: u32) -> Intersection<TDocSet, TDocSet> {
         let num_docsets = docsets.len();
         assert!(num_docsets >= 2);
-        docsets.sort_by_key(|docset| docset.size_hint());
+        docsets.sort_by_key(|docset| docset.cost());
         go_to_first_doc(&mut docsets);
         let left = docsets.remove(0);
         let right = docsets.remove(0);
@@ -77,6 +86,7 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
             left,
             right,
             others: docsets,
+            num_docs,
         }
     }
 }
@@ -141,7 +151,19 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
     }
 
     fn size_hint(&self) -> u32 {
-        self.left.size_hint()
+        estimate_intersection(
+            [self.left.size_hint(), self.right.size_hint()]
+                .into_iter()
+                .chain(self.others.iter().map(DocSet::size_hint)),
+            self.num_docs,
+        )
+    }
+
+    fn cost(&self) -> u64 {
+        // What's the best way to compute the cost of an intersection?
+        // For now we take the cost of the docset driver, which is the first docset.
+        // If there are docsets that are bad at skipping, they should also influence the cost.
+        self.left.cost()
     }
 }
@@ -169,7 +191,7 @@ mod tests {
         {
             let left = VecDocSet::from(vec![1, 3, 9]);
             let right = VecDocSet::from(vec![3, 4, 9, 18]);
-            let mut intersection = Intersection::new(vec![left, right]);
+            let mut intersection = Intersection::new(vec![left, right], 10);
             assert_eq!(intersection.doc(), 3);
             assert_eq!(intersection.advance(), 9);
             assert_eq!(intersection.doc(), 9);
@@ -179,7 +201,7 @@ mod tests {
             let a = VecDocSet::from(vec![1, 3, 9]);
             let b = VecDocSet::from(vec![3, 4, 9, 18]);
             let c = VecDocSet::from(vec![1, 5, 9, 111]);
-            let mut intersection = Intersection::new(vec![a, b, c]);
+            let mut intersection = Intersection::new(vec![a, b, c], 10);
             assert_eq!(intersection.doc(), 9);
             assert_eq!(intersection.advance(), TERMINATED);
         }
@@ -189,7 +211,7 @@ mod tests {
     fn test_intersection_zero() {
         let left = VecDocSet::from(vec![0]);
         let right = VecDocSet::from(vec![0]);
-        let mut intersection = Intersection::new(vec![left, right]);
+        let mut intersection = Intersection::new(vec![left, right], 10);
         assert_eq!(intersection.doc(), 0);
         assert_eq!(intersection.advance(), TERMINATED);
     }
@@ -198,7 +220,7 @@ mod tests {
     fn test_intersection_skip() {
         let left = VecDocSet::from(vec![0, 1, 2, 4]);
         let right = VecDocSet::from(vec![2, 5]);
-        let mut intersection = Intersection::new(vec![left, right]);
+        let mut intersection = Intersection::new(vec![left, right], 10);
         assert_eq!(intersection.seek(2), 2);
         assert_eq!(intersection.doc(), 2);
     }
@@ -209,7 +231,7 @@ mod tests {
             || {
                 let left = VecDocSet::from(vec![4]);
                 let right = VecDocSet::from(vec![2, 5]);
-                Box::new(Intersection::new(vec![left, right]))
+                Box::new(Intersection::new(vec![left, right], 10))
             },
             vec![0, 2, 4, 5, 6],
         );
@@ -219,19 +241,22 @@ mod tests {
                 let mut right = VecDocSet::from(vec![2, 5, 10]);
                 left.advance();
                 right.advance();
-                Box::new(Intersection::new(vec![left, right]))
+                Box::new(Intersection::new(vec![left, right], 10))
             },
             vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
         );
         test_skip_against_unoptimized(
             || {
-                Box::new(Intersection::new(vec![
-                    VecDocSet::from(vec![1, 4, 5, 6]),
-                    VecDocSet::from(vec![1, 2, 5, 6]),
-                    VecDocSet::from(vec![1, 4, 5, 6]),
-                    VecDocSet::from(vec![1, 5, 6]),
-                    VecDocSet::from(vec![2, 4, 5, 7, 8]),
-                ]))
+                Box::new(Intersection::new(
+                    vec![
+                        VecDocSet::from(vec![1, 4, 5, 6]),
+                        VecDocSet::from(vec![1, 2, 5, 6]),
+                        VecDocSet::from(vec![1, 4, 5, 6]),
+                        VecDocSet::from(vec![1, 5, 6]),
+                        VecDocSet::from(vec![2, 4, 5, 7, 8]),
+                    ],
+                    10,
+                ))
             },
             vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
         );
@@ -242,7 +267,7 @@ mod tests {
         let a = VecDocSet::from(vec![1, 3]);
         let b = VecDocSet::from(vec![1, 4]);
         let c = VecDocSet::from(vec![3, 9]);
-        let intersection = Intersection::new(vec![a, b, c]);
+        let intersection = Intersection::new(vec![a, b, c], 10);
         assert_eq!(intersection.doc(), TERMINATED);
     }
 }

View File

@@ -23,6 +23,7 @@ mod regex_query;
 mod reqopt_scorer;
 mod scorer;
 mod set_query;
+mod size_hint;
 mod term_query;
 mod union;
 mod weight;

View File

@@ -200,6 +200,10 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
     fn size_hint(&self) -> u32 {
         self.phrase_scorer.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.phrase_scorer.cost()
+    }
 }
 
 impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> {

View File

@@ -368,6 +368,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
         slop: u32,
         offset: usize,
     ) -> PhraseScorer<TPostings> {
+        let num_docs = fieldnorm_reader.num_docs();
         let max_offset = term_postings_with_offset
             .iter()
             .map(|&(offset, _)| offset)
@@ -382,7 +383,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
             })
             .collect::<Vec<_>>();
         let mut scorer = PhraseScorer {
-            intersection_docset: Intersection::new(postings_with_offsets),
+            intersection_docset: Intersection::new(postings_with_offsets, num_docs),
             num_terms: num_docsets,
             left_positions: Vec::with_capacity(100),
             right_positions: Vec::with_capacity(100),
@@ -535,6 +536,15 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
     fn size_hint(&self) -> u32 {
         self.intersection_docset.size_hint()
     }
+
+    /// Returns a best-effort hint of the
+    /// cost to drive the docset.
+    fn cost(&self) -> u64 {
+        // Evaluating phrase matches is generally more expensive than simple term matches,
+        // as it requires loading and comparing positions. Use a conservative multiplier
+        // based on the number of terms.
+        self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64
+    }
 }
 
 impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
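
As a worked example of the multiplier above (illustrative numbers, not from the diff): a three-term phrase whose intersection size_hint is 1_000 reports a cost of 1_000 * 10 * 3 = 30_000. A plain term scorer matching 20_000 documents keeps cost == size_hint == 20_000, so with cost-based ordering it would be preferred as the intersection driver even though it matches more documents.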

View File

@@ -176,6 +176,14 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
     fn size_hint(&self) -> u32 {
         self.column.num_docs()
     }
+
+    /// Returns a best-effort hint of the
+    /// cost to drive the docset.
+    fn cost(&self) -> u64 {
+        // Advancing the docset is relatively expensive since it scans the column.
+        // Keep cost relative to a term query driver; use num_docs as baseline.
+        self.column.num_docs() as u64
+    }
 }
 
 #[cfg(test)]

View File

@@ -63,6 +63,10 @@ where
     fn size_hint(&self) -> u32 {
         self.req_scorer.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.req_scorer.cost()
+    }
 }
 
 impl<TReqScorer, TOptScorer, TScoreCombiner> Scorer

src/query/size_hint.rs (new file, 141 lines)
View File

@@ -0,0 +1,141 @@
/// Computes the estimated number of documents in the intersection of multiple docsets
/// given their sizes.
///
/// # Arguments
/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
///   segment.
///
/// # Returns
/// The estimated number of documents in the intersection.
pub fn estimate_intersection<I>(mut docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
    if max_docs == 0u32 {
        return 0u32;
    }
    // Terms tend to be not really randomly distributed.
    // This factor is used to adjust the estimate.
    let mut co_loc_factor: f64 = 1.3;

    let mut intersection_estimate = match docset_sizes.next() {
        Some(first_size) => first_size as f64,
        None => return 0, // No docsets provided, so return 0.
    };

    let mut smallest_docset_size = intersection_estimate;

    // Assuming random distribution of terms, the probability of a document being in the
    // intersection
    for size in docset_sizes {
        // Diminish the co-location factor for each additional set, or we will overestimate.
        co_loc_factor = (co_loc_factor - 0.1).max(1.0);
        intersection_estimate *= (size as f64 / max_docs as f64) * co_loc_factor;
        smallest_docset_size = smallest_docset_size.min(size as f64);
    }

    intersection_estimate.round().min(smallest_docset_size) as u32
}

/// Computes the estimated number of documents in the union of multiple docsets
/// given their sizes.
///
/// # Arguments
/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
///   segment.
///
/// # Returns
/// The estimated number of documents in the union.
pub fn estimate_union<I>(docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
    // Terms tend to be not really randomly distributed.
    // This factor is used to adjust the estimate.
    // Unlike intersection, the co-location reduces the estimate.
    let co_loc_factor = 0.8;

    // The approach for union is to compute the probability of a document not being in any of the
    // sets
    let mut not_in_any_set_prob = 1.0;

    // Assuming random distribution of terms, the probability of a document being in the
    // union is the complement of the probability of it not being in any of the sets.
    for size in docset_sizes {
        let prob_in_set = (size as f64 / max_docs as f64) * co_loc_factor;
        not_in_any_set_prob *= 1.0 - prob_in_set;
    }

    let union_estimate = (max_docs as f64 * (1.0 - not_in_any_set_prob)).round();
    union_estimate.min(max_docs as f64) as u32
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_intersection_small1() {
        let docset_sizes = &[500, 1000];
        let n = 10_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        assert_eq!(result, 60);
    }

    #[test]
    fn test_estimate_intersection_small2() {
        let docset_sizes = &[500, 1000, 1500];
        let n = 10_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        assert_eq!(result, 10);
    }

    #[test]
    fn test_estimate_intersection_large_values() {
        let docset_sizes = &[100_000, 50_000, 30_000];
        let n = 1_000_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        assert_eq!(result, 198);
    }

    #[test]
    fn test_estimate_union_small() {
        let docset_sizes = &[500, 1000, 1500];
        let n = 10000;
        let result = estimate_union(docset_sizes.iter().copied(), n);
        assert_eq!(result, 2228);
    }

    #[test]
    fn test_estimate_union_large_values() {
        let docset_sizes = &[100000, 50000, 30000];
        let n = 1000000;
        let result = estimate_union(docset_sizes.iter().copied(), n);
        assert_eq!(result, 137997);
    }

    #[test]
    fn test_estimate_intersection_large() {
        let docset_sizes: Vec<_> = (0..10).map(|_| 4_000_000).collect();
        let n = 5_000_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        // Check that it doesn't overflow and returns a reasonable result
        assert_eq!(result, 708_670);
    }

    #[test]
    fn test_estimate_intersection_overflow_safety() {
        let docset_sizes: Vec<_> = (0..100).map(|_| 4_000_000).collect();
        let n = 5_000_000;
        let result = estimate_intersection(docset_sizes.iter().copied(), n);
        // Check that it doesn't overflow and returns a reasonable result
        assert_eq!(result, 0);
    }

    #[test]
    fn test_estimate_union_overflow_safety() {
        let docset_sizes: Vec<_> = (0..100).map(|_| 1_000_000).collect();
        let n = 20_000_000;
        let result = estimate_union(docset_sizes.iter().copied(), n);
        // Check that it doesn't overflow and returns a reasonable result
        assert_eq!(result, 19_662_594);
    }
}
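
To see the estimators at work on the first tests above: estimate_intersection([500, 1000], 10_000) starts from 500, lowers the co-location factor from 1.3 to 1.2, and multiplies by (1000 / 10_000) * 1.2 = 0.12, giving 500 * 0.12 = 60. estimate_union([500, 1000, 1500], 10_000) scales each size by 0.8, giving per-set probabilities 0.04, 0.08 and 0.12, so the estimate is 10_000 * (1 - 0.96 * 0.92 * 0.88) ≈ 2228. Both match the asserted values.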

View File

@@ -2,6 +2,7 @@ use common::TinySet;
 use crate::docset::{DocSet, TERMINATED};
 use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
+use crate::query::size_hint::estimate_union;
 use crate::query::Scorer;
 use crate::{DocId, Score};
@@ -50,6 +51,8 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
     doc: DocId,
     /// Combined score for current `doc` as produced by `TScoreCombiner`.
     score: Score,
+    /// Number of documents in the segment.
+    num_docs: u32,
 }
 
 fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
@@ -78,9 +81,11 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
 }
 
 impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
+    /// num_docs is the number of documents in the segment.
     pub(crate) fn build(
         docsets: Vec<TScorer>,
         score_combiner_fn: impl FnOnce() -> TScoreCombiner,
+        num_docs: u32,
     ) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
         let non_empty_docsets: Vec<TScorer> = docsets
             .into_iter()
@@ -94,6 +99,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
             window_start_doc: 0,
             doc: 0,
             score: 0.0,
+            num_docs,
         };
         if union.refill() {
             union.advance();
@@ -218,11 +224,11 @@ where
     }
 
     fn size_hint(&self) -> u32 {
-        self.docsets
-            .iter()
-            .map(|docset| docset.size_hint())
-            .max()
-            .unwrap_or(0u32)
+        estimate_union(self.docsets.iter().map(DocSet::size_hint), self.num_docs)
+    }
+
+    fn cost(&self) -> u64 {
+        self.docsets.iter().map(|docset| docset.cost()).sum()
+    }
 
     fn count_including_deleted(&mut self) -> u32 {
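
Note the asymmetry with the Disjunction docset earlier in this diff: there cost() takes the max over the child docsets, while the buffered union sums them. With three child scorers of cost 100, 200 and 300 (illustrative numbers), the buffered union reports 600, reflecting that every child is advanced to drive the union.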

View File

@@ -27,11 +27,17 @@ mod tests {
         docs_list.iter().cloned().map(VecDocSet::from)
     }
 
     fn union_from_docs_list(docs_list: &[Vec<DocId>]) -> Box<dyn DocSet> {
+        let max_doc = docs_list
+            .iter()
+            .flat_map(|docs| docs.iter().copied())
+            .max()
+            .unwrap_or(0);
         Box::new(BufferedUnionScorer::build(
             vec_doc_set_from_docs_list(docs_list)
                 .map(|docset| ConstScorer::new(docset, 1.0))
                 .collect::<Vec<ConstScorer<VecDocSet>>>(),
             DoNothingCombiner::default,
+            max_doc,
         ))
     }
@@ -273,6 +279,7 @@ mod bench {
                 .map(|docset| ConstScorer::new(docset, 1.0))
                 .collect::<Vec<_>>(),
             DoNothingCombiner::default,
+            100_000,
         );
         while v.doc() != TERMINATED {
             v.advance();
@@ -294,6 +301,7 @@ mod bench {
                 .map(|docset| ConstScorer::new(docset, 1.0))
                 .collect::<Vec<_>>(),
             DoNothingCombiner::default,
+            100_000,
         );
         while v.doc() != TERMINATED {
             v.advance();

View File

@@ -99,6 +99,10 @@ impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
             .unwrap_or(0u32)
     }
 
+    fn cost(&self) -> u64 {
+        self.docsets.iter().map(|docset| docset.cost()).sum()
+    }
+
     fn count_including_deleted(&mut self) -> u32 {
         if self.doc == TERMINATED {
             return 0u32;