[feat] Implement DisjunctionMaxQuery and refactor ScoreCombiner

This commit is contained in:
Pasha Podolsky
2022-07-28 20:47:20 +03:00
parent da0f78e06c
commit 09aae134e6
8 changed files with 265 additions and 65 deletions

View File

@@ -371,7 +371,7 @@ mod tests {
fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut scorer: Union<TermScorer, SumCombiner> = Union::from(term_scorers);
let mut scorer = Union::build(term_scorers, SumCombiner::default);
let mut limit = Score::MIN;
loop {

View File

@@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use super::boolean_weight::BooleanWeight;
use crate::query::{Occur, Query, TermQuery, Weight};
use crate::query::{Occur, Query, SumWithCoordsCombiner, TermQuery, Weight};
use crate::schema::{IndexRecordOption, Term};
use crate::Searcher;
@@ -153,7 +153,11 @@ impl Query for BooleanQuery {
Ok((*occur, subquery.weight(searcher, scoring_enabled)?))
})
.collect::<crate::Result<_>>()?;
Ok(Box::new(BooleanWeight::new(sub_weights, scoring_enabled)))
Ok(Box::new(BooleanWeight::new(
sub_weights,
scoring_enabled,
Box::new(SumWithCoordsCombiner::default),
)))
}
fn query_terms(&self, terms: &mut BTreeMap<Term, bool>) {

View File

@@ -3,7 +3,7 @@ use std::collections::HashMap;
use crate::core::SegmentReader;
use crate::postings::FreqReadingOption;
use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner};
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_pruning_scorer, for_each_scorer};
use crate::query::{
@@ -17,8 +17,13 @@ enum SpecializedScorer {
Other(Box<dyn Scorer>),
}
fn scorer_union<TScoreCombiner>(scorers: Vec<Box<dyn Scorer>>) -> SpecializedScorer
where TScoreCombiner: ScoreCombiner {
fn scorer_union<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner_fn: impl Fn() -> TScoreCombiner,
) -> SpecializedScorer
where
TScoreCombiner: ScoreCombiner,
{
assert!(!scorers.is_empty());
if scorers.len() == 1 {
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehands
@@ -38,35 +43,45 @@ where TScoreCombiner: ScoreCombiner {
// Block wand is only available if we read frequencies.
return SpecializedScorer::TermUnion(scorers);
} else {
return SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(
return SpecializedScorer::Other(Box::new(Union::build(
scorers,
score_combiner_fn,
)));
}
}
}
SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers)))
SpecializedScorer::Other(Box::new(Union::build(scorers, score_combiner_fn)))
}
/// Converts a [`SpecializedScorer`] into a plain `Box<dyn Scorer>`.
///
/// * `scorer` - the specialized scorer to convert.
/// * `score_combiner_fn` - a callable returning a `TScoreCombiner`; it is
///   handed to `Union::build` when a union has to be materialized.
///
/// NOTE(review): this listing appears to interleave the pre-change and
/// post-change lines of the commit. The single-line signature and the
/// `Union::<TermScorer, TScoreCombiner>::from(..)` call look like the removed
/// versions of the lines that follow them -- confirm against the repository.
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(scorer: SpecializedScorer) -> Box<dyn Scorer> {
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
scorer: SpecializedScorer,
score_combiner_fn: impl Fn() -> TScoreCombiner,
) -> Box<dyn Scorer> {
match scorer {
// A bare list of term scorers is wrapped in a `Union` and then boxed.
SpecializedScorer::TermUnion(term_scorers) => {
let union_scorer = Union::<TermScorer, TScoreCombiner>::from(term_scorers);
let union_scorer = Union::build(term_scorers, score_combiner_fn);
Box::new(union_scorer)
}
// Already a boxed scorer: returned unchanged.
SpecializedScorer::Other(scorer) => scorer,
}
}
/// `Weight` implementation holding the per-clause sub-weights of a boolean
/// query together with the scoring configuration.
///
/// `TScoreCombiner` is the combiner type produced by `score_combiner_fn`.
///
/// NOTE(review): this listing appears to show both the pre-change
/// (non-generic) and post-change (generic) struct headers of the commit --
/// confirm against the repository.
pub struct BooleanWeight {
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
// Sub-weights, each paired with the `Occur` of its clause.
weights: Vec<(Occur, Box<dyn Weight>)>,
// When true, `scorer` passes `score_combiner_fn` to `complex_scorer`;
// otherwise it passes `DoNothingCombiner::default`.
scoring_enabled: bool,
// Boxed callable producing a `TScoreCombiner`; bounded by `Sync + Send`.
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
}
impl BooleanWeight {
pub fn new(weights: Vec<(Occur, Box<dyn Weight>)>, scoring_enabled: bool) -> BooleanWeight {
impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
/// Creates a `BooleanWeight` from its parts.
///
/// * `weights` - sub-weights, each paired with the `Occur` of its clause.
/// * `scoring_enabled` - stored as-is in the field of the same name.
/// * `score_combiner_fn` - boxed callable returning a `TScoreCombiner`;
///   must be `Sync + Send + 'static`.
///
/// The arguments are moved into the struct unchanged; no validation is done.
pub fn new(
weights: Vec<(Occur, Box<dyn Weight>)>,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
) -> BooleanWeight<TScoreCombiner> {
BooleanWeight {
weights,
scoring_enabled,
score_combiner_fn,
}
}
@@ -86,21 +101,23 @@ impl BooleanWeight {
Ok(per_occur_scorers)
}
fn complex_scorer<TScoreCombiner: ScoreCombiner>(
fn complex_scorer<TComplexScoreCombiner: ScoreCombiner>(
&self,
reader: &SegmentReader,
boost: Score,
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<SpecializedScorer> {
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
.remove(&Occur::Should)
.map(scorer_union::<TScoreCombiner>);
.map(|scorers| scorer_union(scorers, &score_combiner_fn));
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(scorer_union::<DoNothingCombiner>)
.map(into_box_scorer::<DoNothingCombiner>);
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
});
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::Must)
@@ -112,10 +129,10 @@ impl BooleanWeight {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
Box<dyn Scorer>,
Box<dyn Scorer>,
TScoreCombiner,
TComplexScoreCombiner,
>::new(
must_scorer,
into_box_scorer::<TScoreCombiner>(should_scorer),
into_box_scorer(should_scorer, &score_combiner_fn),
)))
} else {
SpecializedScorer::Other(must_scorer)
@@ -129,8 +146,7 @@ impl BooleanWeight {
};
if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed: Box<dyn Scorer> =
into_box_scorer::<TScoreCombiner>(positive_scorer);
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
positive_scorer_boxed,
exclude_scorer,
@@ -141,7 +157,7 @@ impl BooleanWeight {
}
}
impl Weight for BooleanWeight {
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
if self.weights.is_empty() {
Ok(Box::new(EmptyScorer))
@@ -153,13 +169,15 @@ impl Weight for BooleanWeight {
weight.scorer(reader, boost)
}
} else if self.scoring_enabled {
self.complex_scorer::<SumWithCoordsCombiner>(reader, boost)
self.complex_scorer(reader, boost, &self.score_combiner_fn)
.map(|specialized_scorer| {
into_box_scorer::<SumWithCoordsCombiner>(specialized_scorer)
into_box_scorer(specialized_scorer, &self.score_combiner_fn)
})
} else {
self.complex_scorer::<DoNothingCombiner>(reader, boost)
.map(into_box_scorer::<DoNothingCombiner>)
self.complex_scorer(reader, boost, &DoNothingCombiner::default)
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, &DoNothingCombiner::default)
})
}
}
@@ -188,11 +206,10 @@ impl Weight for BooleanWeight {
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
Union::<TermScorer, SumWithCoordsCombiner>::from(term_scorers);
let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::Other(mut scorer) => {
@@ -218,7 +235,7 @@ impl Weight for BooleanWeight {
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);

View File

@@ -4,6 +4,7 @@ mod boolean_weight;
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
pub use self::boolean_query::BooleanQuery;
pub(crate) use self::boolean_weight::BooleanWeight;
#[cfg(test)]
mod tests {