Mirror of https://github.com/quickwit-oss/tantivy.git, synced 2025-12-23 02:29:57 +00:00
Add DocSet::cost() (#2707)
* query: add DocSet cost hint and use it for intersection ordering
  - Add DocSet::cost()
  - Use cost() instead of size_hint() to order scorers in intersect_scorers
  This isolates cost-related changes without the new seek APIs from PR #2538
* add comments
---------
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
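The core idea: size_hint() keeps estimating how many documents a DocSet will match, while the new cost() estimates how much work it takes to consume the docset, and intersect_scorers now orders its inputs by the latter. The sketch below is a standalone toy illustration of that distinction, shown before the diff. The DocSetHint trait and the TermLike/PhraseLike types are invented for this example and are not tantivy's API; only the default cost() == size_hint() and the phrase-style multiplier mirror what the diff actually does.

// Standalone sketch (toy types): why ordering an intersection by cost() can differ
// from ordering it by size_hint().
trait DocSetHint {
    /// Estimated number of matching documents.
    fn size_hint(&self) -> u32;

    /// Best-effort cost to consume the whole docset, relative to a term query;
    /// defaults to the number of matching documents, like the new trait method.
    fn cost(&self) -> u64 {
        self.size_hint() as u64
    }
}

/// Behaves like a plain term posting list: cost is roughly the number of matches.
struct TermLike {
    docs: u32,
}

impl DocSetHint for TermLike {
    fn size_hint(&self) -> u32 {
        self.docs
    }
}

/// Behaves like a phrase scorer: few matches, but each advance checks positions,
/// so it reports a cost well above its size_hint (mirroring the 10 * num_terms
/// multiplier used in the diff).
struct PhraseLike {
    docs: u32,
    num_terms: u32,
}

impl DocSetHint for PhraseLike {
    fn size_hint(&self) -> u32 {
        self.docs
    }

    fn cost(&self) -> u64 {
        self.docs as u64 * 10 * self.num_terms as u64
    }
}

fn order_for_intersection(mut scorers: Vec<Box<dyn DocSetHint>>) -> Vec<Box<dyn DocSetHint>> {
    // Sorting by cost() instead of size_hint() lets the cheapest docset drive the
    // intersection even when it is not the smallest one.
    scorers.sort_by_key(|scorer| scorer.cost());
    scorers
}

fn main() {
    let scorers: Vec<Box<dyn DocSetHint>> = vec![
        Box::new(PhraseLike { docs: 50, num_terms: 3 }),
        Box::new(TermLike { docs: 400 }),
    ];
    for scorer in order_for_intersection(scorers) {
        println!("size_hint={} cost={}", scorer.size_hint(), scorer.cost());
    }
}

In this toy setup the phrase-like docset matches only 50 documents but reports a cost of 1500, so the 400-document term-like docset is placed first, which is exactly the kind of reordering the PR enables.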
@@ -87,6 +87,17 @@ pub trait DocSet: Send {
     /// length of the docset.
     fn size_hint(&self) -> u32;
 
+    /// Returns a best-effort hint of the cost to consume the entire docset.
+    ///
+    /// Consuming means calling advance until [`TERMINATED`] is returned.
+    /// The cost should be relative to the cost of driving a Term query,
+    /// which would be the number of documents in the DocSet.
+    ///
+    /// By default this returns `size_hint()`.
+    fn cost(&self) -> u64 {
+        self.size_hint() as u64
+    }
+
     /// Returns the number documents matching.
     /// Calling this method consumes the `DocSet`.
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
@@ -134,6 +145,10 @@ impl DocSet for &mut dyn DocSet {
         (**self).size_hint()
     }
 
+    fn cost(&self) -> u64 {
+        (**self).cost()
+    }
+
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
         (**self).count(alive_bitset)
     }
@@ -169,6 +184,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
         unboxed.size_hint()
     }
 
+    fn cost(&self) -> u64 {
+        let unboxed: &TDocSet = self.borrow();
+        unboxed.cost()
+    }
+
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
         let unboxed: &mut TDocSet = self.borrow_mut();
         unboxed.count(alive_bitset)
@@ -667,12 +667,15 @@ mod bench {
                 .read_postings(&TERM_D, IndexRecordOption::Basic)
                 .unwrap()
                 .unwrap();
-            let mut intersection = Intersection::new(vec![
-                segment_postings_a,
-                segment_postings_b,
-                segment_postings_c,
-                segment_postings_d,
-            ]);
+            let mut intersection = Intersection::new(
+                vec![
+                    segment_postings_a,
+                    segment_postings_b,
+                    segment_postings_c,
+                    segment_postings_d,
+                ],
+                reader.searcher().num_docs() as u32,
+            );
             while intersection.advance() != TERMINATED {}
         });
     }
@@ -367,10 +367,14 @@ mod tests {
         checkpoints
     }
 
-    fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
+    fn compute_checkpoints_manual(
+        term_scorers: Vec<TermScorer>,
+        n: usize,
+        max_doc: u32,
+    ) -> Vec<(DocId, Score)> {
         let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
         let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
-        let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default);
+        let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default, max_doc);
 
         let mut limit = Score::MIN;
         loop {
@@ -478,7 +482,8 @@ mod tests {
         for top_k in 1..4 {
             let checkpoints_for_each_pruning =
                 compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k);
-            let checkpoints_manual = compute_checkpoints_manual(term_scorers.clone(), top_k);
+            let checkpoints_manual =
+                compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000);
             assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len());
             for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning
                 .iter()
@@ -1,3 +1,4 @@
+use core::num;
 use std::collections::HashMap;
 
 use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
@@ -39,9 +40,11 @@ where
     ))
 }
 
+/// num_docs is the number of documents in the segment.
 fn scorer_union<TScoreCombiner>(
     scorers: Vec<Box<dyn Scorer>>,
     score_combiner_fn: impl Fn() -> TScoreCombiner,
+    num_docs: u32,
 ) -> SpecializedScorer
 where
     TScoreCombiner: ScoreCombiner,
@@ -68,6 +71,7 @@ where
             return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
                 scorers,
                 score_combiner_fn,
+                num_docs,
             )));
         }
     }
@@ -75,16 +79,19 @@ where
     SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
         scorers,
         score_combiner_fn,
+        num_docs,
     )))
 }
 
 fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
     scorer: SpecializedScorer,
     score_combiner_fn: impl Fn() -> TScoreCombiner,
+    num_docs: u32,
 ) -> Box<dyn Scorer> {
     match scorer {
         SpecializedScorer::TermUnion(term_scorers) => {
-            let union_scorer = BufferedUnionScorer::build(term_scorers, score_combiner_fn);
+            let union_scorer =
+                BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
             Box::new(union_scorer)
         }
         SpecializedScorer::Other(scorer) => scorer,
@@ -151,6 +158,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
         boost: Score,
         score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
     ) -> crate::Result<SpecializedScorer> {
+        let num_docs = reader.num_docs();
         let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
         // Indicate how should clauses are combined with other clauses.
         enum CombinationMethod {
@@ -167,11 +175,16 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
             return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
         }
         match self.minimum_number_should_match {
-            0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
-            1 => {
-                let scorer_union = scorer_union(should_scorers, &score_combiner_fn);
-                CombinationMethod::Required(scorer_union)
-            }
+            0 => CombinationMethod::Optional(scorer_union(
+                should_scorers,
+                &score_combiner_fn,
+                num_docs,
+            )),
+            1 => CombinationMethod::Required(scorer_union(
+                should_scorers,
+                &score_combiner_fn,
+                num_docs,
+            )),
             n if num_of_should_scorers == n => {
                 // When num_of_should_scorers equals the number of should clauses,
                 // they are no different from must clauses.
@@ -200,21 +213,21 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
         };
         let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
             .remove(&Occur::MustNot)
-            .map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
+            .map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs))
             .map(|specialized_scorer: SpecializedScorer| {
-                into_box_scorer(specialized_scorer, DoNothingCombiner::default)
+                into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
             });
         let positive_scorer = match (should_opt, must_scorers) {
             (CombinationMethod::Ignored, Some(must_scorers)) => {
-                SpecializedScorer::Other(intersect_scorers(must_scorers))
+                SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
             }
             (CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
-                let must_scorer = intersect_scorers(must_scorers);
+                let must_scorer = intersect_scorers(must_scorers, num_docs);
                 if self.scoring_enabled {
                     SpecializedScorer::Other(Box::new(
                         RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
                             must_scorer,
-                            into_box_scorer(should_scorer, &score_combiner_fn),
+                            into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
                         ),
                     ))
                 } else {
@@ -222,8 +235,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
                 }
             }
             (CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
-                must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn));
-                SpecializedScorer::Other(intersect_scorers(must_scorers))
+                must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
+                SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
             }
             (CombinationMethod::Ignored, None) => {
                 return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
@@ -233,7 +246,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
             (CombinationMethod::Optional(should_scorer), None) => should_scorer,
         };
         if let Some(exclude_scorer) = exclude_scorer_opt {
-            let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
+            let positive_scorer_boxed =
+                into_box_scorer(positive_scorer, &score_combiner_fn, num_docs);
             Ok(SpecializedScorer::Other(Box::new(Exclude::new(
                 positive_scorer_boxed,
                 exclude_scorer,
@@ -246,6 +260,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
 
 impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
     fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
+        let num_docs = reader.num_docs();
         if self.weights.is_empty() {
             Ok(Box::new(EmptyScorer))
         } else if self.weights.len() == 1 {
@@ -258,12 +273,12 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
         } else if self.scoring_enabled {
             self.complex_scorer(reader, boost, &self.score_combiner_fn)
                 .map(|specialized_scorer| {
-                    into_box_scorer(specialized_scorer, &self.score_combiner_fn)
+                    into_box_scorer(specialized_scorer, &self.score_combiner_fn, num_docs)
                 })
         } else {
             self.complex_scorer(reader, boost, DoNothingCombiner::default)
                 .map(|specialized_scorer| {
-                    into_box_scorer(specialized_scorer, DoNothingCombiner::default)
+                    into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
                 })
         }
     }
@@ -296,8 +311,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
         let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
         match scorer {
             SpecializedScorer::TermUnion(term_scorers) => {
-                let mut union_scorer =
-                    BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
+                let mut union_scorer = BufferedUnionScorer::build(
+                    term_scorers,
+                    &self.score_combiner_fn,
+                    reader.num_docs(),
+                );
                 for_each_scorer(&mut union_scorer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {
@@ -317,8 +335,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
 
         match scorer {
             SpecializedScorer::TermUnion(term_scorers) => {
-                let mut union_scorer =
-                    BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
+                let mut union_scorer = BufferedUnionScorer::build(
+                    term_scorers,
+                    &self.score_combiner_fn,
+                    reader.num_docs(),
+                );
                 for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {
@@ -117,6 +117,10 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
         self.underlying.size_hint()
     }
 
+    fn cost(&self) -> u64 {
+        self.underlying.cost()
+    }
+
     fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
         self.underlying.count(alive_bitset)
     }
@@ -130,6 +130,10 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
     fn size_hint(&self) -> u32 {
         self.docset.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.docset.cost()
+    }
 }
 
 impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
@@ -70,6 +70,10 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
     fn size_hint(&self) -> u32 {
         self.scorer.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.scorer.cost()
+    }
 }
 
 impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
@@ -146,6 +150,14 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
             .max()
             .unwrap_or(0u32)
     }
+
+    fn cost(&self) -> u64 {
+        self.chains
+            .iter()
+            .map(|docset| docset.cost())
+            .max()
+            .unwrap_or(0u64)
+    }
 }
 
 impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
@@ -1,4 +1,5 @@
 use crate::docset::{DocSet, TERMINATED};
+use crate::query::size_hint::estimate_intersection;
 use crate::query::term_query::TermScorer;
 use crate::query::{EmptyScorer, Scorer};
 use crate::{DocId, Score};
@@ -11,14 +12,18 @@ use crate::{DocId, Score};
 /// For better performance, the function uses a
 /// specialized implementation if the two
 /// shortest scorers are `TermScorer`s.
-pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
+pub fn intersect_scorers(
+    mut scorers: Vec<Box<dyn Scorer>>,
+    num_docs_segment: u32,
+) -> Box<dyn Scorer> {
     if scorers.is_empty() {
         return Box::new(EmptyScorer);
     }
     if scorers.len() == 1 {
         return scorers.pop().unwrap();
     }
-    scorers.sort_by_key(|scorer| scorer.size_hint());
+    // Order by estimated cost to drive each scorer.
+    scorers.sort_by_key(|scorer| scorer.cost());
     let doc = go_to_first_doc(&mut scorers[..]);
     if doc == TERMINATED {
         return Box::new(EmptyScorer);
@@ -34,12 +39,14 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
             left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
             right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
             others: scorers,
+            num_docs: num_docs_segment,
         });
     }
     Box::new(Intersection {
         left,
         right,
         others: scorers,
+        num_docs: num_docs_segment,
     })
 }
 
@@ -48,6 +55,7 @@ pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>>
     left: TDocSet,
     right: TDocSet,
     others: Vec<TOtherDocSet>,
+    num_docs: u32,
 }
 
 fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
@@ -66,10 +74,11 @@ fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
 }
 
 impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
-    pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> {
+    /// num_docs is the number of documents in the segment.
+    pub(crate) fn new(mut docsets: Vec<TDocSet>, num_docs: u32) -> Intersection<TDocSet, TDocSet> {
         let num_docsets = docsets.len();
         assert!(num_docsets >= 2);
-        docsets.sort_by_key(|docset| docset.size_hint());
+        docsets.sort_by_key(|docset| docset.cost());
         go_to_first_doc(&mut docsets);
         let left = docsets.remove(0);
         let right = docsets.remove(0);
@@ -77,6 +86,7 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
             left,
             right,
             others: docsets,
+            num_docs,
         }
     }
 }
@@ -141,7 +151,19 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
     }
 
     fn size_hint(&self) -> u32 {
-        self.left.size_hint()
+        estimate_intersection(
+            [self.left.size_hint(), self.right.size_hint()]
+                .into_iter()
+                .chain(self.others.iter().map(DocSet::size_hint)),
+            self.num_docs,
+        )
+    }
+
+    fn cost(&self) -> u64 {
+        // What's the best way to compute the cost of an intersection?
+        // For now we take the cost of the docset driver, which is the first docset.
+        // If there are docsets that are bad at skipping, they should also influence the cost.
+        self.left.cost()
     }
 }
 
@@ -169,7 +191,7 @@ mod tests {
         {
             let left = VecDocSet::from(vec![1, 3, 9]);
             let right = VecDocSet::from(vec![3, 4, 9, 18]);
-            let mut intersection = Intersection::new(vec![left, right]);
+            let mut intersection = Intersection::new(vec![left, right], 10);
             assert_eq!(intersection.doc(), 3);
             assert_eq!(intersection.advance(), 9);
             assert_eq!(intersection.doc(), 9);
@@ -179,7 +201,7 @@ mod tests {
             let a = VecDocSet::from(vec![1, 3, 9]);
             let b = VecDocSet::from(vec![3, 4, 9, 18]);
             let c = VecDocSet::from(vec![1, 5, 9, 111]);
-            let mut intersection = Intersection::new(vec![a, b, c]);
+            let mut intersection = Intersection::new(vec![a, b, c], 10);
             assert_eq!(intersection.doc(), 9);
             assert_eq!(intersection.advance(), TERMINATED);
         }
@@ -189,7 +211,7 @@ mod tests {
     fn test_intersection_zero() {
         let left = VecDocSet::from(vec![0]);
         let right = VecDocSet::from(vec![0]);
-        let mut intersection = Intersection::new(vec![left, right]);
+        let mut intersection = Intersection::new(vec![left, right], 10);
         assert_eq!(intersection.doc(), 0);
         assert_eq!(intersection.advance(), TERMINATED);
     }
@@ -198,7 +220,7 @@ mod tests {
     fn test_intersection_skip() {
         let left = VecDocSet::from(vec![0, 1, 2, 4]);
         let right = VecDocSet::from(vec![2, 5]);
-        let mut intersection = Intersection::new(vec![left, right]);
+        let mut intersection = Intersection::new(vec![left, right], 10);
         assert_eq!(intersection.seek(2), 2);
         assert_eq!(intersection.doc(), 2);
     }
@@ -209,7 +231,7 @@ mod tests {
             || {
                 let left = VecDocSet::from(vec![4]);
                 let right = VecDocSet::from(vec![2, 5]);
-                Box::new(Intersection::new(vec![left, right]))
+                Box::new(Intersection::new(vec![left, right], 10))
             },
             vec![0, 2, 4, 5, 6],
         );
@@ -219,19 +241,22 @@ mod tests {
                 let mut right = VecDocSet::from(vec![2, 5, 10]);
                 left.advance();
                 right.advance();
-                Box::new(Intersection::new(vec![left, right]))
+                Box::new(Intersection::new(vec![left, right], 10))
             },
             vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
        );
         test_skip_against_unoptimized(
             || {
-                Box::new(Intersection::new(vec![
-                    VecDocSet::from(vec![1, 4, 5, 6]),
-                    VecDocSet::from(vec![1, 2, 5, 6]),
-                    VecDocSet::from(vec![1, 4, 5, 6]),
-                    VecDocSet::from(vec![1, 5, 6]),
-                    VecDocSet::from(vec![2, 4, 5, 7, 8]),
-                ]))
+                Box::new(Intersection::new(
+                    vec![
+                        VecDocSet::from(vec![1, 4, 5, 6]),
+                        VecDocSet::from(vec![1, 2, 5, 6]),
+                        VecDocSet::from(vec![1, 4, 5, 6]),
+                        VecDocSet::from(vec![1, 5, 6]),
+                        VecDocSet::from(vec![2, 4, 5, 7, 8]),
+                    ],
+                    10,
+                ))
             },
             vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
         );
@@ -242,7 +267,7 @@ mod tests {
         let a = VecDocSet::from(vec![1, 3]);
         let b = VecDocSet::from(vec![1, 4]);
         let c = VecDocSet::from(vec![3, 9]);
-        let intersection = Intersection::new(vec![a, b, c]);
+        let intersection = Intersection::new(vec![a, b, c], 10);
         assert_eq!(intersection.doc(), TERMINATED);
     }
 }
@@ -23,6 +23,7 @@ mod regex_query;
 mod reqopt_scorer;
 mod scorer;
 mod set_query;
+mod size_hint;
 mod term_query;
 mod union;
 mod weight;
@@ -200,6 +200,10 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
     fn size_hint(&self) -> u32 {
         self.phrase_scorer.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.phrase_scorer.cost()
+    }
 }
 
 impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> {
@@ -368,6 +368,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
         slop: u32,
         offset: usize,
     ) -> PhraseScorer<TPostings> {
+        let num_docs = fieldnorm_reader.num_docs();
         let max_offset = term_postings_with_offset
             .iter()
             .map(|&(offset, _)| offset)
@@ -382,7 +383,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
             })
             .collect::<Vec<_>>();
         let mut scorer = PhraseScorer {
-            intersection_docset: Intersection::new(postings_with_offsets),
+            intersection_docset: Intersection::new(postings_with_offsets, num_docs),
             num_terms: num_docsets,
             left_positions: Vec::with_capacity(100),
             right_positions: Vec::with_capacity(100),
@@ -535,6 +536,15 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
     fn size_hint(&self) -> u32 {
         self.intersection_docset.size_hint()
     }
+
+    /// Returns a best-effort hint of the
+    /// cost to drive the docset.
+    fn cost(&self) -> u64 {
+        // Evaluating phrase matches is generally more expensive than simple term matches,
+        // as it requires loading and comparing positions. Use a conservative multiplier
+        // based on the number of terms.
+        self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64
+    }
 }
 
 impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
@@ -176,6 +176,14 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
     fn size_hint(&self) -> u32 {
         self.column.num_docs()
     }
+
+    /// Returns a best-effort hint of the
+    /// cost to drive the docset.
+    fn cost(&self) -> u64 {
+        // Advancing the docset is relatively expensive since it scans the column.
+        // Keep cost relative to a term query driver; use num_docs as baseline.
+        self.column.num_docs() as u64
+    }
 }
 
 #[cfg(test)]
@@ -63,6 +63,10 @@ where
     fn size_hint(&self) -> u32 {
         self.req_scorer.size_hint()
     }
+
+    fn cost(&self) -> u64 {
+        self.req_scorer.cost()
+    }
 }
 
 impl<TReqScorer, TOptScorer, TScoreCombiner> Scorer
src/query/size_hint.rs (new file, 141 lines)
@@ -0,0 +1,141 @@
+/// Computes the estimated number of documents in the intersection of multiple docsets
+/// given their sizes.
+///
+/// # Arguments
+/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
+/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
+///   segment.
+///
+/// # Returns
+/// The estimated number of documents in the intersection.
+pub fn estimate_intersection<I>(mut docset_sizes: I, max_docs: u32) -> u32
+where I: Iterator<Item = u32> {
+    if max_docs == 0u32 {
+        return 0u32;
+    }
+    // Terms tend to be not really randomly distributed.
+    // This factor is used to adjust the estimate.
+    let mut co_loc_factor: f64 = 1.3;
+
+    let mut intersection_estimate = match docset_sizes.next() {
+        Some(first_size) => first_size as f64,
+        None => return 0, // No docsets provided, so return 0.
+    };
+
+    let mut smallest_docset_size = intersection_estimate;
+    // Assuming random distribution of terms, the probability of a document being in the
+    // intersection
+    for size in docset_sizes {
+        // Diminish the co-location factor for each additional set, or we will overestimate.
+        co_loc_factor = (co_loc_factor - 0.1).max(1.0);
+        intersection_estimate *= (size as f64 / max_docs as f64) * co_loc_factor;
+        smallest_docset_size = smallest_docset_size.min(size as f64);
+    }
+
+    intersection_estimate.round().min(smallest_docset_size) as u32
+}
+
+/// Computes the estimated number of documents in the union of multiple docsets
+/// given their sizes.
+///
+/// # Arguments
+/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
+/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
+///   segment.
+///
+/// # Returns
+/// The estimated number of documents in the union.
+pub fn estimate_union<I>(docset_sizes: I, max_docs: u32) -> u32
+where I: Iterator<Item = u32> {
+    // Terms tend to be not really randomly distributed.
+    // This factor is used to adjust the estimate.
+    // Unlike intersection, the co-location reduces the estimate.
+    let co_loc_factor = 0.8;
+
+    // The approach for union is to compute the probability of a document not being in any of the
+    // sets
+    let mut not_in_any_set_prob = 1.0;
+
+    // Assuming random distribution of terms, the probability of a document being in the
+    // union is the complement of the probability of it not being in any of the sets.
+    for size in docset_sizes {
+        let prob_in_set = (size as f64 / max_docs as f64) * co_loc_factor;
+        not_in_any_set_prob *= 1.0 - prob_in_set;
+    }
+
+    let union_estimate = (max_docs as f64 * (1.0 - not_in_any_set_prob)).round();
+
+    union_estimate.min(max_docs as f64) as u32
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_estimate_intersection_small1() {
+        let docset_sizes = &[500, 1000];
+        let n = 10_000;
+        let result = estimate_intersection(docset_sizes.iter().copied(), n);
+        assert_eq!(result, 60);
+    }
+
+    #[test]
+    fn test_estimate_intersection_small2() {
+        let docset_sizes = &[500, 1000, 1500];
+        let n = 10_000;
+        let result = estimate_intersection(docset_sizes.iter().copied(), n);
+        assert_eq!(result, 10);
+    }
+
+    #[test]
+    fn test_estimate_intersection_large_values() {
+        let docset_sizes = &[100_000, 50_000, 30_000];
+        let n = 1_000_000;
+        let result = estimate_intersection(docset_sizes.iter().copied(), n);
+        assert_eq!(result, 198);
+    }
+
+    #[test]
+    fn test_estimate_union_small() {
+        let docset_sizes = &[500, 1000, 1500];
+        let n = 10000;
+        let result = estimate_union(docset_sizes.iter().copied(), n);
+        assert_eq!(result, 2228);
+    }
+
+    #[test]
+    fn test_estimate_union_large_values() {
+        let docset_sizes = &[100000, 50000, 30000];
+        let n = 1000000;
+        let result = estimate_union(docset_sizes.iter().copied(), n);
+        assert_eq!(result, 137997);
+    }
+
+    #[test]
+    fn test_estimate_intersection_large() {
+        let docset_sizes: Vec<_> = (0..10).map(|_| 4_000_000).collect();
+        let n = 5_000_000;
+        let result = estimate_intersection(docset_sizes.iter().copied(), n);
+        // Check that it doesn't overflow and returns a reasonable result
+        assert_eq!(result, 708_670);
+    }
+
+    #[test]
+    fn test_estimate_intersection_overflow_safety() {
+        let docset_sizes: Vec<_> = (0..100).map(|_| 4_000_000).collect();
+        let n = 5_000_000;
+        let result = estimate_intersection(docset_sizes.iter().copied(), n);
+        // Check that it doesn't overflow and returns a reasonable result
+        assert_eq!(result, 0);
+    }
+
+    #[test]
+    fn test_estimate_union_overflow_safety() {
+        let docset_sizes: Vec<_> = (0..100).map(|_| 1_000_000).collect();
+        let n = 20_000_000;
+        let result = estimate_union(docset_sizes.iter().copied(), n);
+        // Check that it doesn't overflow and returns a reasonable result
+        assert_eq!(result, 19_662_594);
+    }
+}
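For intuition on estimate_intersection: with docset sizes 500 and 1,000 in a segment of 10,000 documents, the co-location factor for the second set is 1.3 - 0.1 = 1.2, so the estimate is 500 * (1000 / 10000 * 1.2) = 60, which is exactly the value asserted by test_estimate_intersection_small1. The factor never drops below 1.0, and the result is always capped by the smallest docset size.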
@@ -2,6 +2,7 @@ use common::TinySet;
 
 use crate::docset::{DocSet, TERMINATED};
 use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
+use crate::query::size_hint::estimate_union;
 use crate::query::Scorer;
 use crate::{DocId, Score};
 
@@ -50,6 +51,8 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
     doc: DocId,
     /// Combined score for current `doc` as produced by `TScoreCombiner`.
     score: Score,
+    /// Number of documents in the segment.
+    num_docs: u32,
 }
 
 fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
@@ -78,9 +81,11 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
 }
 
 impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
+    /// num_docs is the number of documents in the segment.
     pub(crate) fn build(
         docsets: Vec<TScorer>,
         score_combiner_fn: impl FnOnce() -> TScoreCombiner,
+        num_docs: u32,
     ) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
         let non_empty_docsets: Vec<TScorer> = docsets
             .into_iter()
@@ -94,6 +99,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
             window_start_doc: 0,
             doc: 0,
             score: 0.0,
+            num_docs,
         };
         if union.refill() {
             union.advance();
@@ -218,11 +224,11 @@ where
     }
 
     fn size_hint(&self) -> u32 {
-        self.docsets
-            .iter()
-            .map(|docset| docset.size_hint())
-            .max()
-            .unwrap_or(0u32)
+        estimate_union(self.docsets.iter().map(DocSet::size_hint), self.num_docs)
+    }
+
+    fn cost(&self) -> u64 {
+        self.docsets.iter().map(|docset| docset.cost()).sum()
     }
 
     fn count_including_deleted(&mut self) -> u32 {
@@ -27,11 +27,17 @@ mod tests {
         docs_list.iter().cloned().map(VecDocSet::from)
     }
     fn union_from_docs_list(docs_list: &[Vec<DocId>]) -> Box<dyn DocSet> {
+        let max_doc = docs_list
+            .iter()
+            .flat_map(|docs| docs.iter().copied())
+            .max()
+            .unwrap_or(0);
         Box::new(BufferedUnionScorer::build(
             vec_doc_set_from_docs_list(docs_list)
                 .map(|docset| ConstScorer::new(docset, 1.0))
                 .collect::<Vec<ConstScorer<VecDocSet>>>(),
             DoNothingCombiner::default,
+            max_doc,
         ))
     }
 
@@ -273,6 +279,7 @@ mod bench {
                 .map(|docset| ConstScorer::new(docset, 1.0))
                 .collect::<Vec<_>>(),
             DoNothingCombiner::default,
+            100_000,
         );
         while v.doc() != TERMINATED {
             v.advance();
@@ -294,6 +301,7 @@ mod bench {
                 .map(|docset| ConstScorer::new(docset, 1.0))
                 .collect::<Vec<_>>(),
             DoNothingCombiner::default,
+            100_000,
         );
         while v.doc() != TERMINATED {
             v.advance();
@@ -99,6 +99,10 @@ impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
             .unwrap_or(0u32)
     }
 
+    fn cost(&self) -> u64 {
+        self.docsets.iter().map(|docset| docset.cost()).sum()
+    }
+
     fn count_including_deleted(&mut self) -> u32 {
         if self.doc == TERMINATED {
             return 0u32;