diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand.rs index 92bb2e8b1..9a882424b 100644 --- a/src/query/boolean_query/block_wand.rs +++ b/src/query/boolean_query/block_wand.rs @@ -371,7 +371,7 @@ mod tests { fn compute_checkpoints_manual(term_scorers: Vec, n: usize) -> Vec<(DocId, Score)> { let mut heap: BinaryHeap = BinaryHeap::with_capacity(n); let mut checkpoints: Vec<(DocId, Score)> = Vec::new(); - let mut scorer: Union = Union::from(term_scorers); + let mut scorer = Union::build(term_scorers, SumCombiner::default); let mut limit = Score::MIN; loop { diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs index 9b30eec04..2e19b80ed 100644 --- a/src/query/boolean_query/boolean_query.rs +++ b/src/query/boolean_query/boolean_query.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use super::boolean_weight::BooleanWeight; -use crate::query::{Occur, Query, TermQuery, Weight}; +use crate::query::{Occur, Query, SumWithCoordsCombiner, TermQuery, Weight}; use crate::schema::{IndexRecordOption, Term}; use crate::Searcher; @@ -153,7 +153,11 @@ impl Query for BooleanQuery { Ok((*occur, subquery.weight(searcher, scoring_enabled)?)) }) .collect::>()?; - Ok(Box::new(BooleanWeight::new(sub_weights, scoring_enabled))) + Ok(Box::new(BooleanWeight::new( + sub_weights, + scoring_enabled, + Box::new(SumWithCoordsCombiner::default), + ))) } fn query_terms(&self, terms: &mut BTreeMap) { diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index f2a5fd376..2b7406a65 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use crate::core::SegmentReader; use crate::postings::FreqReadingOption; use crate::query::explanation::does_not_match; -use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner}; +use 
crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::term_query::TermScorer; use crate::query::weight::{for_each_pruning_scorer, for_each_scorer}; use crate::query::{ @@ -17,8 +17,13 @@ enum SpecializedScorer { Other(Box), } -fn scorer_union(scorers: Vec>) -> SpecializedScorer -where TScoreCombiner: ScoreCombiner { +fn scorer_union( + scorers: Vec>, + score_combiner_fn: impl Fn() -> TScoreCombiner, +) -> SpecializedScorer +where + TScoreCombiner: ScoreCombiner, +{ assert!(!scorers.is_empty()); if scorers.len() == 1 { return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehands @@ -38,35 +43,45 @@ where TScoreCombiner: ScoreCombiner { // Block wand is only available if we read frequencies. return SpecializedScorer::TermUnion(scorers); } else { - return SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from( + return SpecializedScorer::Other(Box::new(Union::build( scorers, + score_combiner_fn, ))); } } } - SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers))) + SpecializedScorer::Other(Box::new(Union::build(scorers, score_combiner_fn))) } -fn into_box_scorer(scorer: SpecializedScorer) -> Box { +fn into_box_scorer( + scorer: SpecializedScorer, + score_combiner_fn: impl Fn() -> TScoreCombiner, +) -> Box { match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let union_scorer = Union::::from(term_scorers); + let union_scorer = Union::build(term_scorers, score_combiner_fn); Box::new(union_scorer) } SpecializedScorer::Other(scorer) => scorer, } } -pub struct BooleanWeight { +pub struct BooleanWeight { weights: Vec<(Occur, Box)>, scoring_enabled: bool, + score_combiner_fn: Box TScoreCombiner + Sync + Send>, } -impl BooleanWeight { - pub fn new(weights: Vec<(Occur, Box)>, scoring_enabled: bool) -> BooleanWeight { +impl BooleanWeight { + pub fn new( + weights: Vec<(Occur, Box)>, + scoring_enabled: bool, + score_combiner_fn: Box TScoreCombiner + Sync + 
Send + 'static>, + ) -> BooleanWeight { BooleanWeight { weights, scoring_enabled, + score_combiner_fn, } } @@ -86,21 +101,23 @@ impl BooleanWeight { Ok(per_occur_scorers) } - fn complex_scorer( + fn complex_scorer( &self, reader: &SegmentReader, boost: Score, + score_combiner_fn: impl Fn() -> TComplexScoreCombiner, ) -> crate::Result { let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?; let should_scorer_opt: Option = per_occur_scorers .remove(&Occur::Should) - .map(scorer_union::); - + .map(|scorers| scorer_union(scorers, &score_combiner_fn)); let exclude_scorer_opt: Option> = per_occur_scorers .remove(&Occur::MustNot) - .map(scorer_union::) - .map(into_box_scorer::); + .map(|scorers| scorer_union(scorers, DoNothingCombiner::default)) + .map(|specialized_scorer| { + into_box_scorer(specialized_scorer, DoNothingCombiner::default) + }); let must_scorer_opt: Option> = per_occur_scorers .remove(&Occur::Must) @@ -112,10 +129,10 @@ impl BooleanWeight { SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< Box, Box, - TScoreCombiner, + TComplexScoreCombiner, >::new( must_scorer, - into_box_scorer::(should_scorer), + into_box_scorer(should_scorer, &score_combiner_fn), ))) } else { SpecializedScorer::Other(must_scorer) @@ -129,8 +146,7 @@ impl BooleanWeight { }; if let Some(exclude_scorer) = exclude_scorer_opt { - let positive_scorer_boxed: Box = - into_box_scorer::(positive_scorer); + let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn); Ok(SpecializedScorer::Other(Box::new(Exclude::new( positive_scorer_boxed, exclude_scorer, @@ -141,7 +157,7 @@ impl BooleanWeight { } } -impl Weight for BooleanWeight { +impl Weight for BooleanWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { if self.weights.is_empty() { Ok(Box::new(EmptyScorer)) @@ -153,13 +169,15 @@ impl Weight for BooleanWeight { weight.scorer(reader, boost) } } else if self.scoring_enabled { - self.complex_scorer::(reader, boost) + 
self.complex_scorer(reader, boost, &self.score_combiner_fn) .map(|specialized_scorer| { - into_box_scorer::(specialized_scorer) + into_box_scorer(specialized_scorer, &self.score_combiner_fn) }) } else { - self.complex_scorer::(reader, boost) - .map(into_box_scorer::) + self.complex_scorer(reader, boost, &DoNothingCombiner::default) + .map(|specialized_scorer| { + into_box_scorer(specialized_scorer, &DoNothingCombiner::default) + }) } } @@ -188,11 +206,10 @@ impl Weight for BooleanWeight { reader: &SegmentReader, callback: &mut dyn FnMut(DocId, Score), ) -> crate::Result<()> { - let scorer = self.complex_scorer::(reader, 1.0)?; + let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?; match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let mut union_scorer = - Union::::from(term_scorers); + let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn); for_each_scorer(&mut union_scorer, callback); } SpecializedScorer::Other(mut scorer) => { @@ -218,7 +235,7 @@ impl Weight for BooleanWeight { reader: &SegmentReader, callback: &mut dyn FnMut(DocId, Score) -> Score, ) -> crate::Result<()> { - let scorer = self.complex_scorer::(reader, 1.0)?; + let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?; match scorer { SpecializedScorer::TermUnion(term_scorers) => { super::block_wand(term_scorers, threshold, callback); diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index a5a1c710b..f596a4b66 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -4,6 +4,7 @@ mod boolean_weight; pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer}; pub use self::boolean_query::BooleanQuery; +pub(crate) use self::boolean_weight::BooleanWeight; #[cfg(test)] mod tests { diff --git a/src/query/disjunction_max_query.rs b/src/query/disjunction_max_query.rs new file mode 100644 index 000000000..5a73ddf82 --- /dev/null +++ b/src/query/disjunction_max_query.rs @@ 
-0,0 +1,133 @@ +use std::collections::BTreeMap; + +use tantivy_query_grammar::Occur; + +use crate::query::{BooleanWeight, DisjunctionMaxCombiner, Query, Weight}; +use crate::{Score, Searcher, Term}; + +/// The disjunction max query returns documents matching one or more wrapped queries, +/// called query clauses or clauses. +/// +/// If a returned document matches multiple query clauses, +/// the `DisjunctionMaxQuery` assigns the document the highest relevance score from any matching +/// clause, plus a tie breaking increment for any additional matching subqueries. +/// +/// ```rust +/// use tantivy::collector::TopDocs; +/// use tantivy::doc; +/// use tantivy::query::{DisjunctionMaxQuery, Query, QueryClone, TermQuery}; +/// use tantivy::schema::{IndexRecordOption, Schema, TEXT}; +/// use tantivy::Term; +/// use tantivy::Index; +/// +/// fn main() -> tantivy::Result<()> { +/// let mut schema_builder = Schema::builder(); +/// let title = schema_builder.add_text_field("title", TEXT); +/// let body = schema_builder.add_text_field("body", TEXT); +/// let schema = schema_builder.build(); +/// let index = Index::create_in_ram(schema); +/// { +/// let mut index_writer = index.writer(3_000_000)?; +/// index_writer.add_document(doc!( +/// title => "The Name of Girl", +/// ))?; +/// index_writer.add_document(doc!( +/// title => "The Diary of Muadib", +/// ))?; +/// index_writer.add_document(doc!( +/// title => "The Diary of Girl", +/// ))?; +/// index_writer.commit()?; +/// } +/// +/// let reader = index.reader()?; +/// let searcher = reader.searcher(); +/// +/// // Make TermQuery's for "girl" and "diary" in the title +/// let girl_term_query: Box = Box::new(TermQuery::new( +/// Term::from_field_text(title, "girl"), +/// IndexRecordOption::Basic, +/// )); +/// let diary_term_query: Box = Box::new(TermQuery::new( +/// Term::from_field_text(title, "diary"), +/// IndexRecordOption::Basic, +/// )); +/// +/// // TermQuery "diary" and "girl" should be present and only one should be
accounted in score +/// let queries1 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()]; +/// let diary_and_girl = DisjunctionMaxQuery::new(queries1); +/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3))?; +/// assert_eq!(documents[0].0, documents[1].0); +/// assert_eq!(documents[1].0, documents[2].0); +/// +/// // TermQuery "diary" and "girl" should be present +/// // and one should be accounted with multiplier 0.7 +/// let queries2 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()]; +/// let tie_breaker = 0.7; +/// let diary_and_girl_with_tie_breaker = DisjunctionMaxQuery::with_tie_breaker(queries2, tie_breaker); +/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3))?; +/// assert_eq!(documents[1].0, documents[2].0); +/// // For this test all terms bring the same score. So we can do easy math and assume that +/// // `DisjunctionMaxQuery` with tie breakers score should be equal +/// // to term1 score + `tie_breaker` * term2 score or (1.0 + tie_breaker) * term score +/// assert!(f32::abs(documents[0].0 - documents[1].0 * (1.0 + tie_breaker)) < 0.001); +/// Ok(()) +/// } +/// ``` +#[derive(Debug)] +pub struct DisjunctionMaxQuery { + disjuncts: Vec>, + tie_breaker: Score, +} + +impl Clone for DisjunctionMaxQuery { + fn clone(&self) -> Self { + DisjunctionMaxQuery::with_tie_breaker( + self.disjuncts + .iter() + .map(|disjunct| disjunct.box_clone()) + .collect::>(), + self.tie_breaker, + ) + } +} + +impl Query for DisjunctionMaxQuery { + fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result> { + let disjuncts = self + .disjuncts + .iter() + .map(|disjunct| Ok((Occur::Should, disjunct.weight(searcher, scoring_enabled)?))) + .collect::>()?; + let tie_breaker = self.tie_breaker; + Ok(Box::new(BooleanWeight::new( + disjuncts, + scoring_enabled, + Box::new(move || DisjunctionMaxCombiner::with_tie_breaker(tie_breaker)), + ))) + } + + fn query_terms(&self,
terms: &mut BTreeMap) { + for disjunct in &self.disjuncts { + disjunct.query_terms(terms); + } + } +} + +impl DisjunctionMaxQuery { + /// Creates a new `DisjunctionMaxQuery` with tie breaker. + pub fn with_tie_breaker( + disjuncts: Vec>, + tie_breaker: Score, + ) -> DisjunctionMaxQuery { + DisjunctionMaxQuery { + disjuncts, + tie_breaker, + } + } + + /// Creates a new `DisjunctionMaxQuery` with no tie breaker. + pub fn new(disjuncts: Vec>) -> DisjunctionMaxQuery { + DisjunctionMaxQuery::with_tie_breaker(disjuncts, 0.0) + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs index 759bc2e72..a585503e0 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -6,6 +6,7 @@ mod bitset; mod bm25; mod boolean_query; mod boost_query; +mod disjunction_max_query; mod empty_query; mod exclude; mod explanation; @@ -34,7 +35,9 @@ pub use self::automaton_weight::AutomatonWeight; pub use self::bitset::BitSetDocSet; pub(crate) use self::bm25::Bm25Weight; pub use self::boolean_query::BooleanQuery; +pub(crate) use self::boolean_query::BooleanWeight; pub use self::boost_query::BoostQuery; +pub use self::disjunction_max_query::DisjunctionMaxQuery; pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight}; pub use self::exclude::Exclude; pub use self::explanation::Explanation; @@ -49,6 +52,9 @@ pub use self::query_parser::{QueryParser, QueryParserError}; pub use self::range_query::RangeQuery; pub use self::regex_query::RegexQuery; pub use self::reqopt_scorer::RequiredOptionalScorer; +pub use self::score_combiner::{ + DisjunctionMaxCombiner, ScoreCombiner, SumCombiner, SumWithCoordsCombiner, +}; pub use self::scorer::{ConstScorer, Scorer}; pub use self::term_query::TermQuery; pub use self::union::Union; diff --git a/src/query/score_combiner.rs b/src/query/score_combiner.rs index 5fd68af45..449badbad 100644 --- a/src/query/score_combiner.rs +++ b/src/query/score_combiner.rs @@ -77,3 +77,40 @@ impl ScoreCombiner for SumWithCoordsCombiner { self.score } } + +/// Take max score of 
different scorers +/// and optionally sum it with other matches multiplied by `tie_breaker` +#[derive(Default, Clone, Copy)] +pub struct DisjunctionMaxCombiner { + max: Score, + sum: Score, + tie_breaker: Score, +} + +impl DisjunctionMaxCombiner { + /// Creates `DisjunctionMaxCombiner` with tie breaker + pub fn with_tie_breaker(tie_breaker: Score) -> DisjunctionMaxCombiner { + DisjunctionMaxCombiner { + max: 0.0, + sum: 0.0, + tie_breaker, + } + } +} + +impl ScoreCombiner for DisjunctionMaxCombiner { + fn update(&mut self, scorer: &mut TScorer) { + let score = scorer.score(); + self.max = Score::max(score, self.max); + self.sum += score; + } + + fn clear(&mut self) { + self.max = 0.0; + self.sum = 0.0; + } + + fn score(&self) -> Score { + self.max + (self.sum - self.max) * self.tie_breaker + } +} diff --git a/src/query/union.rs b/src/query/union.rs index aa6ee75e8..5a7bc7048 100644 --- a/src/query/union.rs +++ b/src/query/union.rs @@ -36,34 +36,6 @@ pub struct Union { score: Score, } -impl From> for Union -where - TScoreCombiner: ScoreCombiner, - TScorer: Scorer, -{ - fn from(docsets: Vec) -> Union { - let non_empty_docsets: Vec = docsets - .into_iter() - .filter(|docset| docset.doc() != TERMINATED) - .collect(); - let mut union = Union { - docsets: non_empty_docsets, - bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]), - scores: Box::new([TScoreCombiner::default(); HORIZON as usize]), - cursor: HORIZON_NUM_TINYBITSETS, - offset: 0, - doc: 0, - score: 0.0, - }; - if union.refill() { - union.advance(); - } else { - union.doc = TERMINATED; - } - union - } -} - fn refill( scorers: &mut Vec, bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], @@ -90,6 +62,31 @@ fn refill( } impl Union { + pub(crate) fn build( + docsets: Vec, + score_combiner_fn: impl Fn() -> TScoreCombiner, + ) -> Union { + let non_empty_docsets: Vec = docsets + .into_iter() + .filter(|docset| docset.doc() != TERMINATED) + .collect(); + let mut union = Union { + docsets: non_empty_docsets, + 
bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]), + scores: Box::new([score_combiner_fn(); HORIZON as usize]), + cursor: HORIZON_NUM_TINYBITSETS, + offset: 0, + doc: 0, + score: 0.0, + }; + if union.refill() { + union.advance(); + } else { + union.doc = TERMINATED; + } + union + } + fn refill(&mut self) -> bool { if let Some(min_doc) = self.docsets.iter().map(DocSet::doc).min() { self.offset = min_doc; @@ -266,12 +263,13 @@ mod tests { let union_vals: Vec = val_set.into_iter().collect(); let mut union_expected = VecDocSet::from(union_vals); let make_union = || { - Union::from( + Union::build( vals.iter() .cloned() .map(VecDocSet::from) .map(|docset| ConstScorer::new(docset, 1.0)) .collect::>>(), + DoNothingCombiner::default, ) }; let mut union: Union<_, DoNothingCombiner> = make_union(); @@ -312,13 +310,14 @@ mod tests { btree_set.extend(docs.iter().cloned()); } let docset_factory = || { - let res: Box = Box::new(Union::<_, DoNothingCombiner>::from( + let res: Box = Box::new(Union::build( docs_list .iter() .cloned() .map(VecDocSet::from) .map(|docset| ConstScorer::new(docset, 1.0)) .collect::>(), + DoNothingCombiner::default, )); res }; @@ -346,10 +345,13 @@ mod tests { #[test] fn test_union_skip_corner_case3() { - let mut docset = Union::<_, DoNothingCombiner>::from(vec![ - ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])), - ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])), - ]); + let mut docset = Union::build( + vec![ + ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])), + ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])), + ], + DoNothingCombiner::default, + ); assert_eq!(docset.doc(), 0u32); assert_eq!(docset.seek(0u32), 0u32); assert_eq!(docset.seek(0u32), 0u32);