Merge pull request #1428 from izihawa/feature/dismax

[feat] Implement `DisjunctionMaxQuery` and refactor `ScoreCombiner`
This commit is contained in:
PSeitz
2022-08-22 06:15:30 -07:00
committed by GitHub
8 changed files with 269 additions and 67 deletions

View File

@@ -371,7 +371,7 @@ mod tests {
fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut scorer: Union<TermScorer, SumCombiner> = Union::from(term_scorers);
let mut scorer = Union::build(term_scorers, SumCombiner::default);
let mut limit = Score::MIN;
loop {

View File

@@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use super::boolean_weight::BooleanWeight;
use crate::query::{Occur, Query, TermQuery, Weight};
use crate::query::{Occur, Query, SumWithCoordsCombiner, TermQuery, Weight};
use crate::schema::{IndexRecordOption, Term};
use crate::Searcher;
@@ -153,7 +153,11 @@ impl Query for BooleanQuery {
Ok((*occur, subquery.weight(searcher, scoring_enabled)?))
})
.collect::<crate::Result<_>>()?;
Ok(Box::new(BooleanWeight::new(sub_weights, scoring_enabled)))
Ok(Box::new(BooleanWeight::new(
sub_weights,
scoring_enabled,
Box::new(SumWithCoordsCombiner::default),
)))
}
fn query_terms(&self, terms: &mut BTreeMap<Term, bool>) {

View File

@@ -3,7 +3,7 @@ use std::collections::HashMap;
use crate::core::SegmentReader;
use crate::postings::FreqReadingOption;
use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner};
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_pruning_scorer, for_each_scorer};
use crate::query::{
@@ -17,8 +17,13 @@ enum SpecializedScorer {
Other(Box<dyn Scorer>),
}
fn scorer_union<TScoreCombiner>(scorers: Vec<Box<dyn Scorer>>) -> SpecializedScorer
where TScoreCombiner: ScoreCombiner {
fn scorer_union<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner_fn: impl Fn() -> TScoreCombiner,
) -> SpecializedScorer
where
TScoreCombiner: ScoreCombiner,
{
assert!(!scorers.is_empty());
if scorers.len() == 1 {
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehand
@@ -38,35 +43,45 @@ where TScoreCombiner: ScoreCombiner {
// Block wand is only available if we read frequencies.
return SpecializedScorer::TermUnion(scorers);
} else {
return SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(
return SpecializedScorer::Other(Box::new(Union::build(
scorers,
score_combiner_fn,
)));
}
}
}
SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers)))
SpecializedScorer::Other(Box::new(Union::build(scorers, score_combiner_fn)))
}
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(scorer: SpecializedScorer) -> Box<dyn Scorer> {
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
scorer: SpecializedScorer,
score_combiner_fn: impl Fn() -> TScoreCombiner,
) -> Box<dyn Scorer> {
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let union_scorer = Union::<TermScorer, TScoreCombiner>::from(term_scorers);
let union_scorer = Union::build(term_scorers, score_combiner_fn);
Box::new(union_scorer)
}
SpecializedScorer::Other(scorer) => scorer,
}
}
pub struct BooleanWeight {
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
weights: Vec<(Occur, Box<dyn Weight>)>,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
}
impl BooleanWeight {
pub fn new(weights: Vec<(Occur, Box<dyn Weight>)>, scoring_enabled: bool) -> BooleanWeight {
impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
pub fn new(
weights: Vec<(Occur, Box<dyn Weight>)>,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
) -> BooleanWeight<TScoreCombiner> {
BooleanWeight {
weights,
scoring_enabled,
score_combiner_fn,
}
}
@@ -86,21 +101,23 @@ impl BooleanWeight {
Ok(per_occur_scorers)
}
fn complex_scorer<TScoreCombiner: ScoreCombiner>(
fn complex_scorer<TComplexScoreCombiner: ScoreCombiner>(
&self,
reader: &SegmentReader,
boost: Score,
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<SpecializedScorer> {
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
.remove(&Occur::Should)
.map(scorer_union::<TScoreCombiner>);
.map(|scorers| scorer_union(scorers, &score_combiner_fn));
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(scorer_union::<DoNothingCombiner>)
.map(into_box_scorer::<DoNothingCombiner>);
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
});
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::Must)
@@ -112,10 +129,10 @@ impl BooleanWeight {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
Box<dyn Scorer>,
Box<dyn Scorer>,
TScoreCombiner,
TComplexScoreCombiner,
>::new(
must_scorer,
into_box_scorer::<TScoreCombiner>(should_scorer),
into_box_scorer(should_scorer, &score_combiner_fn),
)))
} else {
SpecializedScorer::Other(must_scorer)
@@ -129,8 +146,7 @@ impl BooleanWeight {
};
if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed: Box<dyn Scorer> =
into_box_scorer::<TScoreCombiner>(positive_scorer);
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
positive_scorer_boxed,
exclude_scorer,
@@ -141,7 +157,7 @@ impl BooleanWeight {
}
}
impl Weight for BooleanWeight {
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
if self.weights.is_empty() {
Ok(Box::new(EmptyScorer))
@@ -153,13 +169,15 @@ impl Weight for BooleanWeight {
weight.scorer(reader, boost)
}
} else if self.scoring_enabled {
self.complex_scorer::<SumWithCoordsCombiner>(reader, boost)
self.complex_scorer(reader, boost, &self.score_combiner_fn)
.map(|specialized_scorer| {
into_box_scorer::<SumWithCoordsCombiner>(specialized_scorer)
into_box_scorer(specialized_scorer, &self.score_combiner_fn)
})
} else {
self.complex_scorer::<DoNothingCombiner>(reader, boost)
.map(into_box_scorer::<DoNothingCombiner>)
self.complex_scorer(reader, boost, &DoNothingCombiner::default)
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, &DoNothingCombiner::default)
})
}
}
@@ -188,11 +206,10 @@ impl Weight for BooleanWeight {
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
Union::<TermScorer, SumWithCoordsCombiner>::from(term_scorers);
let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::Other(mut scorer) => {
@@ -218,7 +235,7 @@ impl Weight for BooleanWeight {
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);

View File

@@ -4,6 +4,7 @@ mod boolean_weight;
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
pub use self::boolean_query::BooleanQuery;
pub(crate) use self::boolean_weight::BooleanWeight;
#[cfg(test)]
mod tests {

View File

@@ -0,0 +1,133 @@
use std::collections::BTreeMap;
use tantivy_query_grammar::Occur;
use crate::query::{BooleanWeight, DisjunctionMaxCombiner, Query, Weight};
use crate::{Score, Searcher, Term};
/// The disjunction max query returns documents matching one or more wrapped queries,
/// called query clauses or clauses.
///
/// If a returned document matches multiple query clauses,
/// the `DisjunctionMaxQuery` assigns the document the highest relevance score from any matching
/// clause, plus a tie breaking increment for any additional matching subqueries.
///
/// ```rust
/// use tantivy::collector::TopDocs;
/// use tantivy::doc;
/// use tantivy::query::{DisjunctionMaxQuery, Query, QueryClone, TermQuery};
/// use tantivy::schema::{IndexRecordOption, Schema, TEXT};
/// use tantivy::Term;
/// use tantivy::Index;
///
/// fn main() -> tantivy::Result<()> {
/// let mut schema_builder = Schema::builder();
/// let title = schema_builder.add_text_field("title", TEXT);
/// let body = schema_builder.add_text_field("body", TEXT);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
/// {
/// let mut index_writer = index.writer(3_000_000)?;
/// index_writer.add_document(doc!(
/// title => "The Name of Girl",
/// ))?;
/// index_writer.add_document(doc!(
/// title => "The Diary of Muadib",
/// ))?;
/// index_writer.add_document(doc!(
/// title => "The Diary of Girl",
/// ))?;
/// index_writer.commit()?;
/// }
///
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
///
/// // Make TermQuery's for "girl" and "diary" in the title
/// let girl_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "girl"),
/// IndexRecordOption::Basic,
/// ));
/// let diary_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "diary"),
/// IndexRecordOption::Basic,
/// ));
///
/// // TermQuery "diary" and "girl" should be present and only one should be accounted in score
/// let queries1 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
/// let diary_and_girl = DisjunctionMaxQuery::new(queries1);
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3))?;
/// assert_eq!(documents[0].0, documents[1].0);
/// assert_eq!(documents[1].0, documents[2].0);
///
/// // TermQuery "diary" and "girl" should be present
/// // and one should be accounted with multiplier 0.7
/// let queries2 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
/// let tie_breaker = 0.7;
/// let diary_and_girl_with_tie_breaker = DisjunctionMaxQuery::with_tie_breaker(queries2, tie_breaker);
/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3))?;
/// assert_eq!(documents[1].0, documents[2].0);
/// // In this test every term yields the same score, so the math is simple:
/// // the `DisjunctionMaxQuery` score with a tie breaker should be equal
/// // to term1 score + `tie_breaker` * term2 score, i.e. (1.0 + tie_breaker) * term score
/// assert!(f32::abs(documents[0].0 - documents[1].0 * (1.0 + tie_breaker)) < 0.001);
/// Ok(())
/// }
/// ```
#[derive(Debug)]
pub struct DisjunctionMaxQuery {
// The wrapped query clauses; each one is turned into an `Occur::Should`
// weight when the query weight is built.
disjuncts: Vec<Box<dyn Query>>,
// Value handed to `DisjunctionMaxCombiner::with_tie_breaker` when the weight
// is built; `new` sets it to 0.0.
tie_breaker: Score,
}
impl Clone for DisjunctionMaxQuery {
    /// Produces an independent copy of the query: every boxed clause is
    /// duplicated through `box_clone` and the tie breaker is carried over.
    fn clone(&self) -> Self {
        let mut cloned_disjuncts: Vec<_> = Vec::with_capacity(self.disjuncts.len());
        for disjunct in self.disjuncts.iter() {
            cloned_disjuncts.push(disjunct.box_clone());
        }
        DisjunctionMaxQuery::with_tie_breaker(cloned_disjuncts, self.tie_breaker)
    }
}
impl Query for DisjunctionMaxQuery {
    /// Builds a `BooleanWeight` in which every clause is registered as
    /// `Occur::Should` and scores are combined by a `DisjunctionMaxCombiner`
    /// configured with this query's tie breaker.
    ///
    /// Fails with the first error returned while building a clause weight.
    fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> {
        let mut clause_weights = Vec::with_capacity(self.disjuncts.len());
        for disjunct in &self.disjuncts {
            let clause_weight = disjunct.weight(searcher, scoring_enabled)?;
            clause_weights.push((Occur::Should, clause_weight));
        }
        // Copy the tie breaker out so the combiner factory does not borrow `self`.
        let tie_breaker = self.tie_breaker;
        let boolean_weight = BooleanWeight::new(
            clause_weights,
            scoring_enabled,
            Box::new(move || DisjunctionMaxCombiner::with_tie_breaker(tie_breaker)),
        );
        Ok(Box::new(boolean_weight))
    }

    /// Lets every clause register its terms into `terms`.
    fn query_terms(&self, terms: &mut BTreeMap<Term, bool>) {
        self.disjuncts
            .iter()
            .for_each(|disjunct| disjunct.query_terms(terms));
    }
}
impl DisjunctionMaxQuery {
    /// Creates a new `DisjunctionMaxQuery` over `disjuncts` with the given
    /// `tie_breaker`.
    pub fn with_tie_breaker(
        disjuncts: Vec<Box<dyn Query>>,
        tie_breaker: Score,
    ) -> DisjunctionMaxQuery {
        Self {
            disjuncts,
            tie_breaker,
        }
    }

    /// Creates a new `DisjunctionMaxQuery` with no tie breaker, i.e. a tie
    /// breaker of `0.0`.
    pub fn new(disjuncts: Vec<Box<dyn Query>>) -> DisjunctionMaxQuery {
        let no_tie_breaker = 0.0;
        Self::with_tie_breaker(disjuncts, no_tie_breaker)
    }
}

View File

@@ -6,6 +6,7 @@ mod bitset;
mod bm25;
mod boolean_query;
mod boost_query;
mod disjunction_max_query;
mod empty_query;
mod exclude;
mod explanation;
@@ -34,7 +35,9 @@ pub use self::automaton_weight::AutomatonWeight;
pub use self::bitset::BitSetDocSet;
pub(crate) use self::bm25::Bm25Weight;
pub use self::boolean_query::BooleanQuery;
pub(crate) use self::boolean_query::BooleanWeight;
pub use self::boost_query::BoostQuery;
pub use self::disjunction_max_query::DisjunctionMaxQuery;
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
pub use self::exclude::Exclude;
pub use self::explanation::Explanation;
@@ -49,6 +52,9 @@ pub use self::query_parser::{QueryParser, QueryParserError};
pub use self::range_query::RangeQuery;
pub use self::regex_query::RegexQuery;
pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{
DisjunctionMaxCombiner, ScoreCombiner, SumCombiner, SumWithCoordsCombiner,
};
pub use self::scorer::{ConstScorer, Scorer};
pub use self::term_query::TermQuery;
pub use self::union::Union;

View File

@@ -77,3 +77,40 @@ impl ScoreCombiner for SumWithCoordsCombiner {
self.score
}
}
/// Score combiner that keeps the best score among the scorers it is fed and
/// optionally blends in the remaining scores.
///
/// The combined score is `max + (sum - max) * tie_breaker`.
#[derive(Default, Clone, Copy)]
pub struct DisjunctionMaxCombiner {
    // Largest score seen since the last `clear` (never below the initial 0.0).
    max: Score,
    // Running total of every score seen since the last `clear`.
    sum: Score,
    // Factor applied to the non-maximal part of the total.
    tie_breaker: Score,
}

impl DisjunctionMaxCombiner {
    /// Creates a `DisjunctionMaxCombiner` with the given tie breaker and
    /// zeroed accumulators.
    pub fn with_tie_breaker(tie_breaker: Score) -> DisjunctionMaxCombiner {
        Self {
            tie_breaker,
            ..Self::default()
        }
    }
}

impl ScoreCombiner for DisjunctionMaxCombiner {
    /// Folds the current score of `scorer` into the running total and maximum.
    fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
        let current = scorer.score();
        self.sum += current;
        self.max = current.max(self.max);
    }

    /// Resets the accumulators while keeping the configured tie breaker.
    fn clear(&mut self) {
        *self = Self::with_tie_breaker(self.tie_breaker);
    }

    /// Returns the maximum plus the tie-breaker-weighted rest of the total.
    fn score(&self) -> Score {
        let non_max_total = self.sum - self.max;
        self.max + non_max_total * self.tie_breaker
    }
}

View File

@@ -36,34 +36,6 @@ pub struct Union<TScorer, TScoreCombiner = DoNothingCombiner> {
score: Score,
}
impl<TScorer, TScoreCombiner> From<Vec<TScorer>> for Union<TScorer, TScoreCombiner>
where
    TScoreCombiner: ScoreCombiner,
    TScorer: Scorer,
{
    /// Builds a `Union` from `docsets`, using the default value of
    /// `TScoreCombiner` for every score slot.
    ///
    /// Docsets already positioned on `TERMINATED` are dropped. The union is
    /// then refilled and advanced once, or marked `TERMINATED` right away
    /// when the refill reports nothing to read.
    fn from(docsets: Vec<TScorer>) -> Union<TScorer, TScoreCombiner> {
        let mut live_docsets = docsets;
        live_docsets.retain(|docset| docset.doc() != TERMINATED);
        let mut union = Union {
            docsets: live_docsets,
            bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]),
            scores: Box::new([TScoreCombiner::default(); HORIZON as usize]),
            cursor: HORIZON_NUM_TINYBITSETS,
            offset: 0,
            doc: 0,
            score: 0.0,
        };
        if !union.refill() {
            union.doc = TERMINATED;
            return union;
        }
        union.advance();
        union
    }
}
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorers: &mut Vec<TScorer>,
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
@@ -90,6 +62,31 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Union<TScorer, TScoreCombiner> {
/// Builds a `Union` over `docsets`, seeding every score slot with the
/// combiner produced by `score_combiner_fn`.
///
/// Docsets already positioned on `TERMINATED` are dropped. The union is then
/// refilled and advanced once, or marked `TERMINATED` right away when the
/// refill reports nothing to read.
pub(crate) fn build(
    docsets: Vec<TScorer>,
    score_combiner_fn: impl FnOnce() -> TScoreCombiner,
) -> Union<TScorer, TScoreCombiner> {
    let mut live_docsets = docsets;
    live_docsets.retain(|docset| docset.doc() != TERMINATED);
    // The factory runs exactly once; its result is copied into every slot.
    let combiner_template = score_combiner_fn();
    let mut union = Union {
        docsets: live_docsets,
        bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]),
        scores: Box::new([combiner_template; HORIZON as usize]),
        cursor: HORIZON_NUM_TINYBITSETS,
        offset: 0,
        doc: 0,
        score: 0.0,
    };
    if !union.refill() {
        union.doc = TERMINATED;
        return union;
    }
    union.advance();
    union
}
fn refill(&mut self) -> bool {
if let Some(min_doc) = self.docsets.iter().map(DocSet::doc).min() {
self.offset = min_doc;
@@ -265,12 +262,13 @@ mod tests {
let union_vals: Vec<u32> = val_set.into_iter().collect();
let mut union_expected = VecDocSet::from(union_vals);
let make_union = || {
Union::from(
Union::build(
vals.iter()
.cloned()
.map(VecDocSet::from)
.map(|docset| ConstScorer::new(docset, 1.0))
.collect::<Vec<ConstScorer<VecDocSet>>>(),
DoNothingCombiner::default,
)
};
let mut union: Union<_, DoNothingCombiner> = make_union();
@@ -311,13 +309,14 @@ mod tests {
btree_set.extend(docs.iter().cloned());
}
let docset_factory = || {
let res: Box<dyn DocSet> = Box::new(Union::<_, DoNothingCombiner>::from(
let res: Box<dyn DocSet> = Box::new(Union::build(
docs_list
.iter()
.cloned()
.map(VecDocSet::from)
.map(|docset| ConstScorer::new(docset, 1.0))
.collect::<Vec<_>>(),
DoNothingCombiner::default,
));
res
};
@@ -345,10 +344,13 @@ mod tests {
#[test]
fn test_union_skip_corner_case3() {
let mut docset = Union::<_, DoNothingCombiner>::from(vec![
ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])),
ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])),
]);
let mut docset = Union::build(
vec![
ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])),
ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])),
],
DoNothingCombiner::default,
);
assert_eq!(docset.doc(), 0u32);
assert_eq!(docset.seek(0u32), 0u32);
assert_eq!(docset.seek(0u32), 0u32);
@@ -404,12 +406,13 @@ mod bench {
tests::sample_with_seed(100_000, 0.2, 1),
];
bench.iter(|| {
let mut v = Union::<_, DoNothingCombiner>::from(
let mut v = Union::build(
union_docset
.iter()
.map(|doc_ids| VecDocSet::from(doc_ids.clone()))
.map(|docset| ConstScorer::new(docset, 1.0))
.collect::<Vec<_>>(),
DoNothingCombiner::default,
);
while v.doc() != TERMINATED {
v.advance();
@@ -424,12 +427,13 @@ mod bench {
tests::sample_with_seed(100_000, 0.001, 2),
];
bench.iter(|| {
let mut v = Union::<_, DoNothingCombiner>::from(
let mut v = Union::build(
union_docset
.iter()
.map(|doc_ids| VecDocSet::from(doc_ids.clone()))
.map(|docset| ConstScorer::new(docset, 1.0))
.collect::<Vec<_>>(),
DoNothingCombiner::default,
);
while v.doc() != TERMINATED {
v.advance();