mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2025-12-23 02:29:57 +00:00
Add DocSet::cost() (#2707)
* query: add DocSet cost hint and use it for intersection ordering
  - Add DocSet::cost()
  - Use cost() instead of size_hint() to order scorers in intersect_scorers

  This isolates cost-related changes without the new seek APIs from PR #2538
* add comments

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
This commit is contained in:
@@ -87,6 +87,17 @@ pub trait DocSet: Send {
|
||||
/// length of the docset.
|
||||
fn size_hint(&self) -> u32;
|
||||
|
||||
/// Returns a best-effort hint of the cost to consume the entire docset.
|
||||
///
|
||||
/// Consuming means calling advance until [`TERMINATED`] is returned.
|
||||
/// The cost should be relative to the cost of driving a Term query,
|
||||
/// which would be the number of documents in the DocSet.
|
||||
///
|
||||
/// By default this returns `size_hint()`.
|
||||
fn cost(&self) -> u64 {
|
||||
self.size_hint() as u64
|
||||
}
|
||||
|
||||
/// Returns the number of documents matching.
|
||||
/// Calling this method consumes the `DocSet`.
|
||||
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
|
||||
@@ -134,6 +145,10 @@ impl DocSet for &mut dyn DocSet {
|
||||
(**self).size_hint()
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
(**self).cost()
|
||||
}
|
||||
|
||||
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
|
||||
(**self).count(alive_bitset)
|
||||
}
|
||||
@@ -169,6 +184,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
|
||||
unboxed.size_hint()
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
let unboxed: &TDocSet = self.borrow();
|
||||
unboxed.cost()
|
||||
}
|
||||
|
||||
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
|
||||
let unboxed: &mut TDocSet = self.borrow_mut();
|
||||
unboxed.count(alive_bitset)
|
||||
|
||||
@@ -667,12 +667,15 @@ mod bench {
|
||||
.read_postings(&TERM_D, IndexRecordOption::Basic)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let mut intersection = Intersection::new(vec![
|
||||
segment_postings_a,
|
||||
segment_postings_b,
|
||||
segment_postings_c,
|
||||
segment_postings_d,
|
||||
]);
|
||||
let mut intersection = Intersection::new(
|
||||
vec![
|
||||
segment_postings_a,
|
||||
segment_postings_b,
|
||||
segment_postings_c,
|
||||
segment_postings_d,
|
||||
],
|
||||
reader.searcher().num_docs() as u32,
|
||||
);
|
||||
while intersection.advance() != TERMINATED {}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -367,10 +367,14 @@ mod tests {
|
||||
checkpoints
|
||||
}
|
||||
|
||||
fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
|
||||
fn compute_checkpoints_manual(
|
||||
term_scorers: Vec<TermScorer>,
|
||||
n: usize,
|
||||
max_doc: u32,
|
||||
) -> Vec<(DocId, Score)> {
|
||||
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
|
||||
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
|
||||
let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default);
|
||||
let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default, max_doc);
|
||||
|
||||
let mut limit = Score::MIN;
|
||||
loop {
|
||||
@@ -478,7 +482,8 @@ mod tests {
|
||||
for top_k in 1..4 {
|
||||
let checkpoints_for_each_pruning =
|
||||
compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k);
|
||||
let checkpoints_manual = compute_checkpoints_manual(term_scorers.clone(), top_k);
|
||||
let checkpoints_manual =
|
||||
compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000);
|
||||
assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len());
|
||||
for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning
|
||||
.iter()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use core::num;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
|
||||
@@ -39,9 +40,11 @@ where
|
||||
))
|
||||
}
|
||||
|
||||
/// num_docs is the number of documents in the segment.
|
||||
fn scorer_union<TScoreCombiner>(
|
||||
scorers: Vec<Box<dyn Scorer>>,
|
||||
score_combiner_fn: impl Fn() -> TScoreCombiner,
|
||||
num_docs: u32,
|
||||
) -> SpecializedScorer
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
@@ -68,6 +71,7 @@ where
|
||||
return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
|
||||
scorers,
|
||||
score_combiner_fn,
|
||||
num_docs,
|
||||
)));
|
||||
}
|
||||
}
|
||||
@@ -75,16 +79,19 @@ where
|
||||
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
|
||||
scorers,
|
||||
score_combiner_fn,
|
||||
num_docs,
|
||||
)))
|
||||
}
|
||||
|
||||
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
scorer: SpecializedScorer,
|
||||
score_combiner_fn: impl Fn() -> TScoreCombiner,
|
||||
num_docs: u32,
|
||||
) -> Box<dyn Scorer> {
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let union_scorer = BufferedUnionScorer::build(term_scorers, score_combiner_fn);
|
||||
let union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
|
||||
Box::new(union_scorer)
|
||||
}
|
||||
SpecializedScorer::Other(scorer) => scorer,
|
||||
@@ -151,6 +158,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
boost: Score,
|
||||
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
|
||||
) -> crate::Result<SpecializedScorer> {
|
||||
let num_docs = reader.num_docs();
|
||||
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
|
||||
// Indicates how `Should` clauses are combined with the other clauses.
|
||||
enum CombinationMethod {
|
||||
@@ -167,11 +175,16 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
}
|
||||
match self.minimum_number_should_match {
|
||||
0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
|
||||
1 => {
|
||||
let scorer_union = scorer_union(should_scorers, &score_combiner_fn);
|
||||
CombinationMethod::Required(scorer_union)
|
||||
}
|
||||
0 => CombinationMethod::Optional(scorer_union(
|
||||
should_scorers,
|
||||
&score_combiner_fn,
|
||||
num_docs,
|
||||
)),
|
||||
1 => CombinationMethod::Required(scorer_union(
|
||||
should_scorers,
|
||||
&score_combiner_fn,
|
||||
num_docs,
|
||||
)),
|
||||
n if num_of_should_scorers == n => {
|
||||
// When num_of_should_scorers equals the number of should clauses,
|
||||
// they are no different from must clauses.
|
||||
@@ -200,21 +213,21 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
};
|
||||
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::MustNot)
|
||||
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
|
||||
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs))
|
||||
.map(|specialized_scorer: SpecializedScorer| {
|
||||
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
|
||||
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
|
||||
});
|
||||
let positive_scorer = match (should_opt, must_scorers) {
|
||||
(CombinationMethod::Ignored, Some(must_scorers)) => {
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers))
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
|
||||
}
|
||||
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
|
||||
let must_scorer = intersect_scorers(must_scorers);
|
||||
let must_scorer = intersect_scorers(must_scorers, num_docs);
|
||||
if self.scoring_enabled {
|
||||
SpecializedScorer::Other(Box::new(
|
||||
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
|
||||
must_scorer,
|
||||
into_box_scorer(should_scorer, &score_combiner_fn),
|
||||
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
|
||||
),
|
||||
))
|
||||
} else {
|
||||
@@ -222,8 +235,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
}
|
||||
}
|
||||
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
|
||||
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn));
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers))
|
||||
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
|
||||
}
|
||||
(CombinationMethod::Ignored, None) => {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
|
||||
@@ -233,7 +246,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
|
||||
};
|
||||
if let Some(exclude_scorer) = exclude_scorer_opt {
|
||||
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
|
||||
let positive_scorer_boxed =
|
||||
into_box_scorer(positive_scorer, &score_combiner_fn, num_docs);
|
||||
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
|
||||
positive_scorer_boxed,
|
||||
exclude_scorer,
|
||||
@@ -246,6 +260,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
|
||||
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
|
||||
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
||||
let num_docs = reader.num_docs();
|
||||
if self.weights.is_empty() {
|
||||
Ok(Box::new(EmptyScorer))
|
||||
} else if self.weights.len() == 1 {
|
||||
@@ -258,12 +273,12 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
|
||||
} else if self.scoring_enabled {
|
||||
self.complex_scorer(reader, boost, &self.score_combiner_fn)
|
||||
.map(|specialized_scorer| {
|
||||
into_box_scorer(specialized_scorer, &self.score_combiner_fn)
|
||||
into_box_scorer(specialized_scorer, &self.score_combiner_fn, num_docs)
|
||||
})
|
||||
} else {
|
||||
self.complex_scorer(reader, boost, DoNothingCombiner::default)
|
||||
.map(|specialized_scorer| {
|
||||
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
|
||||
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -296,8 +311,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
|
||||
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let mut union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
|
||||
let mut union_scorer = BufferedUnionScorer::build(
|
||||
term_scorers,
|
||||
&self.score_combiner_fn,
|
||||
reader.num_docs(),
|
||||
);
|
||||
for_each_scorer(&mut union_scorer, callback);
|
||||
}
|
||||
SpecializedScorer::Other(mut scorer) => {
|
||||
@@ -317,8 +335,11 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
|
||||
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let mut union_scorer =
|
||||
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
|
||||
let mut union_scorer = BufferedUnionScorer::build(
|
||||
term_scorers,
|
||||
&self.score_combiner_fn,
|
||||
reader.num_docs(),
|
||||
);
|
||||
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
|
||||
}
|
||||
SpecializedScorer::Other(mut scorer) => {
|
||||
|
||||
@@ -117,6 +117,10 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
|
||||
self.underlying.size_hint()
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.underlying.cost()
|
||||
}
|
||||
|
||||
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
|
||||
self.underlying.count(alive_bitset)
|
||||
}
|
||||
|
||||
@@ -130,6 +130,10 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.docset.size_hint()
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.docset.cost()
|
||||
}
|
||||
}
|
||||
|
||||
impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
|
||||
|
||||
@@ -70,6 +70,10 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.scorer.size_hint()
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.scorer.cost()
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
|
||||
@@ -146,6 +150,14 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
|
||||
.max()
|
||||
.unwrap_or(0u32)
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.chains
|
||||
.iter()
|
||||
.map(|docset| docset.cost())
|
||||
.max()
|
||||
.unwrap_or(0u64)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::docset::{DocSet, TERMINATED};
|
||||
use crate::query::size_hint::estimate_intersection;
|
||||
use crate::query::term_query::TermScorer;
|
||||
use crate::query::{EmptyScorer, Scorer};
|
||||
use crate::{DocId, Score};
|
||||
@@ -11,14 +12,18 @@ use crate::{DocId, Score};
|
||||
/// For better performance, the function uses a
|
||||
/// specialized implementation if the two
|
||||
/// shortest scorers are `TermScorer`s.
|
||||
pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
|
||||
pub fn intersect_scorers(
|
||||
mut scorers: Vec<Box<dyn Scorer>>,
|
||||
num_docs_segment: u32,
|
||||
) -> Box<dyn Scorer> {
|
||||
if scorers.is_empty() {
|
||||
return Box::new(EmptyScorer);
|
||||
}
|
||||
if scorers.len() == 1 {
|
||||
return scorers.pop().unwrap();
|
||||
}
|
||||
scorers.sort_by_key(|scorer| scorer.size_hint());
|
||||
// Order by estimated cost to drive each scorer.
|
||||
scorers.sort_by_key(|scorer| scorer.cost());
|
||||
let doc = go_to_first_doc(&mut scorers[..]);
|
||||
if doc == TERMINATED {
|
||||
return Box::new(EmptyScorer);
|
||||
@@ -34,12 +39,14 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
|
||||
left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
|
||||
right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
|
||||
others: scorers,
|
||||
num_docs: num_docs_segment,
|
||||
});
|
||||
}
|
||||
Box::new(Intersection {
|
||||
left,
|
||||
right,
|
||||
others: scorers,
|
||||
num_docs: num_docs_segment,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -48,6 +55,7 @@ pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>>
|
||||
left: TDocSet,
|
||||
right: TDocSet,
|
||||
others: Vec<TOtherDocSet>,
|
||||
num_docs: u32,
|
||||
}
|
||||
|
||||
fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
|
||||
@@ -66,10 +74,11 @@ fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
|
||||
}
|
||||
|
||||
impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
|
||||
pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> {
|
||||
/// num_docs is the number of documents in the segment.
|
||||
pub(crate) fn new(mut docsets: Vec<TDocSet>, num_docs: u32) -> Intersection<TDocSet, TDocSet> {
|
||||
let num_docsets = docsets.len();
|
||||
assert!(num_docsets >= 2);
|
||||
docsets.sort_by_key(|docset| docset.size_hint());
|
||||
docsets.sort_by_key(|docset| docset.cost());
|
||||
go_to_first_doc(&mut docsets);
|
||||
let left = docsets.remove(0);
|
||||
let right = docsets.remove(0);
|
||||
@@ -77,6 +86,7 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
|
||||
left,
|
||||
right,
|
||||
others: docsets,
|
||||
num_docs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -141,7 +151,19 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.left.size_hint()
|
||||
estimate_intersection(
|
||||
[self.left.size_hint(), self.right.size_hint()]
|
||||
.into_iter()
|
||||
.chain(self.others.iter().map(DocSet::size_hint)),
|
||||
self.num_docs,
|
||||
)
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
// What's the best way to compute the cost of an intersection?
|
||||
// For now we take the cost of the docset driver, which is the first docset.
|
||||
// If there are docsets that are bad at skipping, they should also influence the cost.
|
||||
self.left.cost()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,7 +191,7 @@ mod tests {
|
||||
{
|
||||
let left = VecDocSet::from(vec![1, 3, 9]);
|
||||
let right = VecDocSet::from(vec![3, 4, 9, 18]);
|
||||
let mut intersection = Intersection::new(vec![left, right]);
|
||||
let mut intersection = Intersection::new(vec![left, right], 10);
|
||||
assert_eq!(intersection.doc(), 3);
|
||||
assert_eq!(intersection.advance(), 9);
|
||||
assert_eq!(intersection.doc(), 9);
|
||||
@@ -179,7 +201,7 @@ mod tests {
|
||||
let a = VecDocSet::from(vec![1, 3, 9]);
|
||||
let b = VecDocSet::from(vec![3, 4, 9, 18]);
|
||||
let c = VecDocSet::from(vec![1, 5, 9, 111]);
|
||||
let mut intersection = Intersection::new(vec![a, b, c]);
|
||||
let mut intersection = Intersection::new(vec![a, b, c], 10);
|
||||
assert_eq!(intersection.doc(), 9);
|
||||
assert_eq!(intersection.advance(), TERMINATED);
|
||||
}
|
||||
@@ -189,7 +211,7 @@ mod tests {
|
||||
fn test_intersection_zero() {
|
||||
let left = VecDocSet::from(vec![0]);
|
||||
let right = VecDocSet::from(vec![0]);
|
||||
let mut intersection = Intersection::new(vec![left, right]);
|
||||
let mut intersection = Intersection::new(vec![left, right], 10);
|
||||
assert_eq!(intersection.doc(), 0);
|
||||
assert_eq!(intersection.advance(), TERMINATED);
|
||||
}
|
||||
@@ -198,7 +220,7 @@ mod tests {
|
||||
fn test_intersection_skip() {
|
||||
let left = VecDocSet::from(vec![0, 1, 2, 4]);
|
||||
let right = VecDocSet::from(vec![2, 5]);
|
||||
let mut intersection = Intersection::new(vec![left, right]);
|
||||
let mut intersection = Intersection::new(vec![left, right], 10);
|
||||
assert_eq!(intersection.seek(2), 2);
|
||||
assert_eq!(intersection.doc(), 2);
|
||||
}
|
||||
@@ -209,7 +231,7 @@ mod tests {
|
||||
|| {
|
||||
let left = VecDocSet::from(vec![4]);
|
||||
let right = VecDocSet::from(vec![2, 5]);
|
||||
Box::new(Intersection::new(vec![left, right]))
|
||||
Box::new(Intersection::new(vec![left, right], 10))
|
||||
},
|
||||
vec![0, 2, 4, 5, 6],
|
||||
);
|
||||
@@ -219,19 +241,22 @@ mod tests {
|
||||
let mut right = VecDocSet::from(vec![2, 5, 10]);
|
||||
left.advance();
|
||||
right.advance();
|
||||
Box::new(Intersection::new(vec![left, right]))
|
||||
Box::new(Intersection::new(vec![left, right], 10))
|
||||
},
|
||||
vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
|
||||
);
|
||||
test_skip_against_unoptimized(
|
||||
|| {
|
||||
Box::new(Intersection::new(vec![
|
||||
VecDocSet::from(vec![1, 4, 5, 6]),
|
||||
VecDocSet::from(vec![1, 2, 5, 6]),
|
||||
VecDocSet::from(vec![1, 4, 5, 6]),
|
||||
VecDocSet::from(vec![1, 5, 6]),
|
||||
VecDocSet::from(vec![2, 4, 5, 7, 8]),
|
||||
]))
|
||||
Box::new(Intersection::new(
|
||||
vec![
|
||||
VecDocSet::from(vec![1, 4, 5, 6]),
|
||||
VecDocSet::from(vec![1, 2, 5, 6]),
|
||||
VecDocSet::from(vec![1, 4, 5, 6]),
|
||||
VecDocSet::from(vec![1, 5, 6]),
|
||||
VecDocSet::from(vec![2, 4, 5, 7, 8]),
|
||||
],
|
||||
10,
|
||||
))
|
||||
},
|
||||
vec![0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
|
||||
);
|
||||
@@ -242,7 +267,7 @@ mod tests {
|
||||
let a = VecDocSet::from(vec![1, 3]);
|
||||
let b = VecDocSet::from(vec![1, 4]);
|
||||
let c = VecDocSet::from(vec![3, 9]);
|
||||
let intersection = Intersection::new(vec![a, b, c]);
|
||||
let intersection = Intersection::new(vec![a, b, c], 10);
|
||||
assert_eq!(intersection.doc(), TERMINATED);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ mod regex_query;
|
||||
mod reqopt_scorer;
|
||||
mod scorer;
|
||||
mod set_query;
|
||||
mod size_hint;
|
||||
mod term_query;
|
||||
mod union;
|
||||
mod weight;
|
||||
|
||||
@@ -200,6 +200,10 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.phrase_scorer.size_hint()
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.phrase_scorer.cost()
|
||||
}
|
||||
}
|
||||
|
||||
impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> {
|
||||
|
||||
@@ -368,6 +368,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
|
||||
slop: u32,
|
||||
offset: usize,
|
||||
) -> PhraseScorer<TPostings> {
|
||||
let num_docs = fieldnorm_reader.num_docs();
|
||||
let max_offset = term_postings_with_offset
|
||||
.iter()
|
||||
.map(|&(offset, _)| offset)
|
||||
@@ -382,7 +383,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut scorer = PhraseScorer {
|
||||
intersection_docset: Intersection::new(postings_with_offsets),
|
||||
intersection_docset: Intersection::new(postings_with_offsets, num_docs),
|
||||
num_terms: num_docsets,
|
||||
left_positions: Vec::with_capacity(100),
|
||||
right_positions: Vec::with_capacity(100),
|
||||
@@ -535,6 +536,15 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.intersection_docset.size_hint()
|
||||
}
|
||||
|
||||
/// Returns a best-effort hint of the
|
||||
/// cost to drive the docset.
|
||||
fn cost(&self) -> u64 {
|
||||
// Evaluating phrase matches is generally more expensive than simple term matches,
|
||||
// as it requires loading and comparing positions. Use a conservative multiplier
|
||||
// based on the number of terms.
|
||||
self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64
|
||||
}
|
||||
}
|
||||
|
||||
impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
|
||||
|
||||
@@ -176,6 +176,14 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.column.num_docs()
|
||||
}
|
||||
|
||||
/// Returns a best-effort hint of the
|
||||
/// cost to drive the docset.
|
||||
fn cost(&self) -> u64 {
|
||||
// Advancing the docset is relatively expensive since it scans the column.
|
||||
// Keep cost relative to a term query driver; use num_docs as baseline.
|
||||
self.column.num_docs() as u64
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -63,6 +63,10 @@ where
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.req_scorer.size_hint()
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.req_scorer.cost()
|
||||
}
|
||||
}
|
||||
|
||||
impl<TReqScorer, TOptScorer, TScoreCombiner> Scorer
|
||||
|
||||
141
src/query/size_hint.rs
Normal file
141
src/query/size_hint.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
/// Estimates how many documents the intersection of several docsets contains, using only
/// the size of each docset.
///
/// # Arguments
/// * `docset_sizes` - Iterator yielding the number of documents of each docset.
/// * `max_docs` - Upper bound on the number of documents that can match, usually the number
///   of documents in the segment.
///
/// # Returns
/// The estimated size of the intersection. It never exceeds the smallest docset size, and it
/// is `0` when `max_docs` is `0` or when no sizes are provided.
pub fn estimate_intersection<I>(mut docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
    if max_docs == 0u32 {
        return 0u32;
    }

    // The first docset seeds both the running estimate and the smallest size seen so far.
    let first_size = match docset_sizes.next() {
        Some(size) => f64::from(size),
        None => return 0u32,
    };
    let total_docs = f64::from(max_docs);

    // Fold state: (co-location factor, running estimate, smallest docset size).
    //
    // Under a random distribution every further docset would scale the estimate by
    // `size / max_docs`. Terms are not really randomly distributed, so that ratio is boosted
    // by a co-location factor starting at 1.3. The factor shrinks by 0.1 for each additional
    // docset (never dropping below 1.0), otherwise the result would be overestimated.
    let (_, estimate, smallest_size) = docset_sizes.fold(
        (1.3f64, first_size, first_size),
        |(factor, estimate, smallest), size| {
            let size = f64::from(size);
            let factor = (factor - 0.1).max(1.0);
            (
                factor,
                estimate * ((size / total_docs) * factor),
                smallest.min(size),
            )
        },
    );

    // An intersection can never be larger than its smallest member.
    estimate.round().min(smallest_size) as u32
}
|
||||
|
||||
/// Computes the estimated number of documents in the union of multiple docsets
/// given their sizes.
///
/// The estimate is `max_docs * (1 - P(doc is in none of the sets))`, where each set is
/// treated as an independent draw that contains a document with probability
/// `size / max_docs`, scaled by a co-location factor of 0.8.
///
/// Because of that factor the estimate can be smaller than the largest input size
/// (a single docset of size `s` yields roughly `0.8 * s`).
///
/// # Arguments
/// * `docset_sizes` - An iterator over the sizes of the docsets (number of documents in each set).
/// * `max_docs` - The maximum number of docs that can hit, usually number of documents in the
///   segment.
///
/// # Returns
/// The estimated number of documents in the union, never more than `max_docs`.
/// Returns `0` when `max_docs` is `0` or when no sizes are provided.
pub fn estimate_union<I>(docset_sizes: I, max_docs: u32) -> u32
where I: Iterator<Item = u32> {
    // Nothing can match in an empty segment. Returning early also avoids dividing by zero
    // below (mirrors the guard in `estimate_intersection`).
    if max_docs == 0u32 {
        return 0u32;
    }
    let max_docs = f64::from(max_docs);

    // Terms tend to be not really randomly distributed.
    // This factor is used to adjust the estimate.
    // Unlike intersection, the co-location reduces the estimate.
    let co_loc_factor = 0.8;

    // The approach for union is to compute the probability of a document not being in any of
    // the sets; the probability of being in the union is its complement.
    let mut not_in_any_set_prob = 1.0;
    for size in docset_sizes {
        // Size hints are best-effort and may exceed `max_docs`. Clamp the probability to 1 so
        // that `1 - prob_in_set` never turns negative and flips the sign of the product.
        let prob_in_set = ((f64::from(size) / max_docs) * co_loc_factor).min(1.0);
        not_in_any_set_prob *= 1.0 - prob_in_set;
    }

    let union_estimate = (max_docs * (1.0 - not_in_any_set_prob)).round();

    union_estimate.min(max_docs) as u32
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `estimate_intersection` over `sizes` for a segment holding `max_docs` documents.
    fn intersection_of(sizes: &[u32], max_docs: u32) -> u32 {
        estimate_intersection(sizes.iter().copied(), max_docs)
    }

    /// Runs `estimate_union` over `sizes` for a segment holding `max_docs` documents.
    fn union_of(sizes: &[u32], max_docs: u32) -> u32 {
        estimate_union(sizes.iter().copied(), max_docs)
    }

    #[test]
    fn test_estimate_intersection_small1() {
        assert_eq!(intersection_of(&[500, 1000], 10_000), 60);
    }

    #[test]
    fn test_estimate_intersection_small2() {
        assert_eq!(intersection_of(&[500, 1000, 1500], 10_000), 10);
    }

    #[test]
    fn test_estimate_intersection_large_values() {
        assert_eq!(intersection_of(&[100_000, 50_000, 30_000], 1_000_000), 198);
    }

    #[test]
    fn test_estimate_union_small() {
        assert_eq!(union_of(&[500, 1000, 1500], 10_000), 2228);
    }

    #[test]
    fn test_estimate_union_large_values() {
        assert_eq!(union_of(&[100_000, 50_000, 30_000], 1_000_000), 137_997);
    }

    #[test]
    fn test_estimate_intersection_large() {
        // Ten large docsets: must not overflow and must return a reasonable result.
        assert_eq!(intersection_of(&[4_000_000; 10], 5_000_000), 708_670);
    }

    #[test]
    fn test_estimate_intersection_overflow_safety() {
        // A hundred large docsets: must not overflow and must return a reasonable result.
        assert_eq!(intersection_of(&[4_000_000; 100], 5_000_000), 0);
    }

    #[test]
    fn test_estimate_union_overflow_safety() {
        // A hundred large docsets: must not overflow and must return a reasonable result.
        assert_eq!(union_of(&[1_000_000; 100], 20_000_000), 19_662_594);
    }
}
|
||||
@@ -2,6 +2,7 @@ use common::TinySet;
|
||||
|
||||
use crate::docset::{DocSet, TERMINATED};
|
||||
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
|
||||
use crate::query::size_hint::estimate_union;
|
||||
use crate::query::Scorer;
|
||||
use crate::{DocId, Score};
|
||||
|
||||
@@ -50,6 +51,8 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
doc: DocId,
|
||||
/// Combined score for current `doc` as produced by `TScoreCombiner`.
|
||||
score: Score,
|
||||
/// Number of documents in the segment.
|
||||
num_docs: u32,
|
||||
}
|
||||
|
||||
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
@@ -78,9 +81,11 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
|
||||
/// num_docs is the number of documents in the segment.
|
||||
pub(crate) fn build(
|
||||
docsets: Vec<TScorer>,
|
||||
score_combiner_fn: impl FnOnce() -> TScoreCombiner,
|
||||
num_docs: u32,
|
||||
) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
|
||||
let non_empty_docsets: Vec<TScorer> = docsets
|
||||
.into_iter()
|
||||
@@ -94,6 +99,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
|
||||
window_start_doc: 0,
|
||||
doc: 0,
|
||||
score: 0.0,
|
||||
num_docs,
|
||||
};
|
||||
if union.refill() {
|
||||
union.advance();
|
||||
@@ -218,11 +224,11 @@ where
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.docsets
|
||||
.iter()
|
||||
.map(|docset| docset.size_hint())
|
||||
.max()
|
||||
.unwrap_or(0u32)
|
||||
estimate_union(self.docsets.iter().map(DocSet::size_hint), self.num_docs)
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.docsets.iter().map(|docset| docset.cost()).sum()
|
||||
}
|
||||
|
||||
fn count_including_deleted(&mut self) -> u32 {
|
||||
|
||||
@@ -27,11 +27,17 @@ mod tests {
|
||||
docs_list.iter().cloned().map(VecDocSet::from)
|
||||
}
|
||||
fn union_from_docs_list(docs_list: &[Vec<DocId>]) -> Box<dyn DocSet> {
|
||||
let max_doc = docs_list
|
||||
.iter()
|
||||
.flat_map(|docs| docs.iter().copied())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
Box::new(BufferedUnionScorer::build(
|
||||
vec_doc_set_from_docs_list(docs_list)
|
||||
.map(|docset| ConstScorer::new(docset, 1.0))
|
||||
.collect::<Vec<ConstScorer<VecDocSet>>>(),
|
||||
DoNothingCombiner::default,
|
||||
max_doc,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -273,6 +279,7 @@ mod bench {
|
||||
.map(|docset| ConstScorer::new(docset, 1.0))
|
||||
.collect::<Vec<_>>(),
|
||||
DoNothingCombiner::default,
|
||||
100_000,
|
||||
);
|
||||
while v.doc() != TERMINATED {
|
||||
v.advance();
|
||||
@@ -294,6 +301,7 @@ mod bench {
|
||||
.map(|docset| ConstScorer::new(docset, 1.0))
|
||||
.collect::<Vec<_>>(),
|
||||
DoNothingCombiner::default,
|
||||
100_000,
|
||||
);
|
||||
while v.doc() != TERMINATED {
|
||||
v.advance();
|
||||
|
||||
@@ -99,6 +99,10 @@ impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
|
||||
.unwrap_or(0u32)
|
||||
}
|
||||
|
||||
fn cost(&self) -> u64 {
|
||||
self.docsets.iter().map(|docset| docset.cost()).sum()
|
||||
}
|
||||
|
||||
fn count_including_deleted(&mut self) -> u32 {
|
||||
if self.doc == TERMINATED {
|
||||
return 0u32;
|
||||
|
||||
Reference in New Issue
Block a user