mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-06-04 09:30:42 +00:00
Merge pull request #1428 from izihawa/feature/dismax
[feat] Implement `DisjunctionMaxQuery` and refactor `ScoreCombiner`
This commit is contained in:
@@ -371,7 +371,7 @@ mod tests {
|
||||
fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
|
||||
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
|
||||
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
|
||||
let mut scorer: Union<TermScorer, SumCombiner> = Union::from(term_scorers);
|
||||
let mut scorer = Union::build(term_scorers, SumCombiner::default);
|
||||
|
||||
let mut limit = Score::MIN;
|
||||
loop {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use super::boolean_weight::BooleanWeight;
|
||||
use crate::query::{Occur, Query, TermQuery, Weight};
|
||||
use crate::query::{Occur, Query, SumWithCoordsCombiner, TermQuery, Weight};
|
||||
use crate::schema::{IndexRecordOption, Term};
|
||||
use crate::Searcher;
|
||||
|
||||
@@ -153,7 +153,11 @@ impl Query for BooleanQuery {
|
||||
Ok((*occur, subquery.weight(searcher, scoring_enabled)?))
|
||||
})
|
||||
.collect::<crate::Result<_>>()?;
|
||||
Ok(Box::new(BooleanWeight::new(sub_weights, scoring_enabled)))
|
||||
Ok(Box::new(BooleanWeight::new(
|
||||
sub_weights,
|
||||
scoring_enabled,
|
||||
Box::new(SumWithCoordsCombiner::default),
|
||||
)))
|
||||
}
|
||||
|
||||
fn query_terms(&self, terms: &mut BTreeMap<Term, bool>) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::collections::HashMap;
|
||||
use crate::core::SegmentReader;
|
||||
use crate::postings::FreqReadingOption;
|
||||
use crate::query::explanation::does_not_match;
|
||||
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner};
|
||||
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
|
||||
use crate::query::term_query::TermScorer;
|
||||
use crate::query::weight::{for_each_pruning_scorer, for_each_scorer};
|
||||
use crate::query::{
|
||||
@@ -17,8 +17,13 @@ enum SpecializedScorer {
|
||||
Other(Box<dyn Scorer>),
|
||||
}
|
||||
|
||||
fn scorer_union<TScoreCombiner>(scorers: Vec<Box<dyn Scorer>>) -> SpecializedScorer
|
||||
where TScoreCombiner: ScoreCombiner {
|
||||
fn scorer_union<TScoreCombiner>(
|
||||
scorers: Vec<Box<dyn Scorer>>,
|
||||
score_combiner_fn: impl Fn() -> TScoreCombiner,
|
||||
) -> SpecializedScorer
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
{
|
||||
assert!(!scorers.is_empty());
|
||||
if scorers.len() == 1 {
|
||||
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehand
|
||||
@@ -38,35 +43,45 @@ where TScoreCombiner: ScoreCombiner {
|
||||
// Block wand is only available if we read frequencies.
|
||||
return SpecializedScorer::TermUnion(scorers);
|
||||
} else {
|
||||
return SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(
|
||||
return SpecializedScorer::Other(Box::new(Union::build(
|
||||
scorers,
|
||||
score_combiner_fn,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers)))
|
||||
SpecializedScorer::Other(Box::new(Union::build(scorers, score_combiner_fn)))
|
||||
}
|
||||
|
||||
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(scorer: SpecializedScorer) -> Box<dyn Scorer> {
|
||||
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
scorer: SpecializedScorer,
|
||||
score_combiner_fn: impl Fn() -> TScoreCombiner,
|
||||
) -> Box<dyn Scorer> {
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let union_scorer = Union::<TermScorer, TScoreCombiner>::from(term_scorers);
|
||||
let union_scorer = Union::build(term_scorers, score_combiner_fn);
|
||||
Box::new(union_scorer)
|
||||
}
|
||||
SpecializedScorer::Other(scorer) => scorer,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BooleanWeight {
|
||||
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
|
||||
weights: Vec<(Occur, Box<dyn Weight>)>,
|
||||
scoring_enabled: bool,
|
||||
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
|
||||
}
|
||||
|
||||
impl BooleanWeight {
|
||||
pub fn new(weights: Vec<(Occur, Box<dyn Weight>)>, scoring_enabled: bool) -> BooleanWeight {
|
||||
impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
pub fn new(
|
||||
weights: Vec<(Occur, Box<dyn Weight>)>,
|
||||
scoring_enabled: bool,
|
||||
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
|
||||
) -> BooleanWeight<TScoreCombiner> {
|
||||
BooleanWeight {
|
||||
weights,
|
||||
scoring_enabled,
|
||||
score_combiner_fn,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,21 +101,23 @@ impl BooleanWeight {
|
||||
Ok(per_occur_scorers)
|
||||
}
|
||||
|
||||
fn complex_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
fn complex_scorer<TComplexScoreCombiner: ScoreCombiner>(
|
||||
&self,
|
||||
reader: &SegmentReader,
|
||||
boost: Score,
|
||||
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
|
||||
) -> crate::Result<SpecializedScorer> {
|
||||
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
|
||||
|
||||
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
|
||||
.remove(&Occur::Should)
|
||||
.map(scorer_union::<TScoreCombiner>);
|
||||
|
||||
.map(|scorers| scorer_union(scorers, &score_combiner_fn));
|
||||
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::MustNot)
|
||||
.map(scorer_union::<DoNothingCombiner>)
|
||||
.map(into_box_scorer::<DoNothingCombiner>);
|
||||
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
|
||||
.map(|specialized_scorer| {
|
||||
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
|
||||
});
|
||||
|
||||
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::Must)
|
||||
@@ -112,10 +129,10 @@ impl BooleanWeight {
|
||||
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
|
||||
Box<dyn Scorer>,
|
||||
Box<dyn Scorer>,
|
||||
TScoreCombiner,
|
||||
TComplexScoreCombiner,
|
||||
>::new(
|
||||
must_scorer,
|
||||
into_box_scorer::<TScoreCombiner>(should_scorer),
|
||||
into_box_scorer(should_scorer, &score_combiner_fn),
|
||||
)))
|
||||
} else {
|
||||
SpecializedScorer::Other(must_scorer)
|
||||
@@ -129,8 +146,7 @@ impl BooleanWeight {
|
||||
};
|
||||
|
||||
if let Some(exclude_scorer) = exclude_scorer_opt {
|
||||
let positive_scorer_boxed: Box<dyn Scorer> =
|
||||
into_box_scorer::<TScoreCombiner>(positive_scorer);
|
||||
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
|
||||
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
|
||||
positive_scorer_boxed,
|
||||
exclude_scorer,
|
||||
@@ -141,7 +157,7 @@ impl BooleanWeight {
|
||||
}
|
||||
}
|
||||
|
||||
impl Weight for BooleanWeight {
|
||||
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
|
||||
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
||||
if self.weights.is_empty() {
|
||||
Ok(Box::new(EmptyScorer))
|
||||
@@ -153,13 +169,15 @@ impl Weight for BooleanWeight {
|
||||
weight.scorer(reader, boost)
|
||||
}
|
||||
} else if self.scoring_enabled {
|
||||
self.complex_scorer::<SumWithCoordsCombiner>(reader, boost)
|
||||
self.complex_scorer(reader, boost, &self.score_combiner_fn)
|
||||
.map(|specialized_scorer| {
|
||||
into_box_scorer::<SumWithCoordsCombiner>(specialized_scorer)
|
||||
into_box_scorer(specialized_scorer, &self.score_combiner_fn)
|
||||
})
|
||||
} else {
|
||||
self.complex_scorer::<DoNothingCombiner>(reader, boost)
|
||||
.map(into_box_scorer::<DoNothingCombiner>)
|
||||
self.complex_scorer(reader, boost, &DoNothingCombiner::default)
|
||||
.map(|specialized_scorer| {
|
||||
into_box_scorer(specialized_scorer, &DoNothingCombiner::default)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,11 +206,10 @@ impl Weight for BooleanWeight {
|
||||
reader: &SegmentReader,
|
||||
callback: &mut dyn FnMut(DocId, Score),
|
||||
) -> crate::Result<()> {
|
||||
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
|
||||
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
let mut union_scorer =
|
||||
Union::<TermScorer, SumWithCoordsCombiner>::from(term_scorers);
|
||||
let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
|
||||
for_each_scorer(&mut union_scorer, callback);
|
||||
}
|
||||
SpecializedScorer::Other(mut scorer) => {
|
||||
@@ -218,7 +235,7 @@ impl Weight for BooleanWeight {
|
||||
reader: &SegmentReader,
|
||||
callback: &mut dyn FnMut(DocId, Score) -> Score,
|
||||
) -> crate::Result<()> {
|
||||
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
|
||||
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
|
||||
match scorer {
|
||||
SpecializedScorer::TermUnion(term_scorers) => {
|
||||
super::block_wand(term_scorers, threshold, callback);
|
||||
|
||||
@@ -4,6 +4,7 @@ mod boolean_weight;
|
||||
|
||||
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
|
||||
pub use self::boolean_query::BooleanQuery;
|
||||
pub(crate) use self::boolean_weight::BooleanWeight;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
133
src/query/disjunction_max_query.rs
Normal file
133
src/query/disjunction_max_query.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use tantivy_query_grammar::Occur;
|
||||
|
||||
use crate::query::{BooleanWeight, DisjunctionMaxCombiner, Query, Weight};
|
||||
use crate::{Score, Searcher, Term};
|
||||
|
||||
/// The disjunction max query returns documents matching one or more wrapped queries,
|
||||
/// called query clauses or clauses.
|
||||
///
|
||||
/// If a returned document matches multiple query clauses,
|
||||
/// the `DisjunctionMaxQuery` assigns the document the highest relevance score from any matching
|
||||
/// clause, plus a tie breaking increment for any additional matching subqueries.
|
||||
///
|
||||
/// ```rust
|
||||
/// use tantivy::collector::TopDocs;
|
||||
/// use tantivy::doc;
|
||||
/// use tantivy::query::{DisjunctionMaxQuery, Query, QueryClone, TermQuery};
|
||||
/// use tantivy::schema::{IndexRecordOption, Schema, TEXT};
|
||||
/// use tantivy::Term;
|
||||
/// use tantivy::Index;
|
||||
///
|
||||
/// fn main() -> tantivy::Result<()> {
|
||||
/// let mut schema_builder = Schema::builder();
|
||||
/// let title = schema_builder.add_text_field("title", TEXT);
|
||||
/// let body = schema_builder.add_text_field("body", TEXT);
|
||||
/// let schema = schema_builder.build();
|
||||
/// let index = Index::create_in_ram(schema);
|
||||
/// {
|
||||
/// let mut index_writer = index.writer(3_000_000)?;
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Name of Girl",
|
||||
/// ))?;
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of Muadib",
|
||||
/// ))?;
|
||||
/// index_writer.add_document(doc!(
|
||||
/// title => "The Diary of Girl",
|
||||
/// ))?;
|
||||
/// index_writer.commit()?;
|
||||
/// }
|
||||
///
|
||||
/// let reader = index.reader()?;
|
||||
/// let searcher = reader.searcher();
|
||||
///
|
||||
/// // Make TermQuery's for "girl" and "diary" in the title
|
||||
/// let girl_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(title, "girl"),
|
||||
/// IndexRecordOption::Basic,
|
||||
/// ));
|
||||
/// let diary_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(title, "diary"),
|
||||
/// IndexRecordOption::Basic,
|
||||
/// ));
|
||||
///
|
||||
/// // TermQuery "diary" and "girl" should be present and only one should be counted in the score
|
||||
/// let queries1 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
|
||||
/// let diary_and_girl = DisjunctionMaxQuery::new(queries1);
|
||||
/// let documents = searcher.search(&diary_and_girl, &TopDocs::with_limit(3))?;
|
||||
/// assert_eq!(documents[0].0, documents[1].0);
|
||||
/// assert_eq!(documents[1].0, documents[2].0);
|
||||
///
|
||||
/// // TermQuery "diary" and "girl" should be present
|
||||
/// // and one of them should be counted with a multiplier of 0.7
|
||||
/// let queries2 = vec![diary_term_query.box_clone(), girl_term_query.box_clone()];
|
||||
/// let tie_breaker = 0.7;
|
||||
/// let diary_and_girl_with_tie_breaker = DisjunctionMaxQuery::with_tie_breaker(queries2, tie_breaker);
|
||||
/// let documents = searcher.search(&diary_and_girl_with_tie_breaker, &TopDocs::with_limit(3))?;
|
||||
/// assert_eq!(documents[1].0, documents[2].0);
|
||||
/// // For this test all terms bring the same score. So we can do easy math and assume that
|
||||
/// // the score of a `DisjunctionMaxQuery` with a tie breaker should be equal
|
||||
/// // to term1 score + `tie_breaker` * term2 score or (1.0 + tie_breaker) * term score
|
||||
/// assert!(f32::abs(documents[0].0 - documents[1].0 * (1.0 + tie_breaker)) < 0.001);
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct DisjunctionMaxQuery {
|
||||
disjuncts: Vec<Box<dyn Query>>,
|
||||
tie_breaker: Score,
|
||||
}
|
||||
|
||||
impl Clone for DisjunctionMaxQuery {
|
||||
fn clone(&self) -> Self {
|
||||
DisjunctionMaxQuery::with_tie_breaker(
|
||||
self.disjuncts
|
||||
.iter()
|
||||
.map(|disjunct| disjunct.box_clone())
|
||||
.collect::<Vec<_>>(),
|
||||
self.tie_breaker,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Query for DisjunctionMaxQuery {
|
||||
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> {
|
||||
let disjuncts = self
|
||||
.disjuncts
|
||||
.iter()
|
||||
.map(|disjunct| Ok((Occur::Should, disjunct.weight(searcher, scoring_enabled)?)))
|
||||
.collect::<crate::Result<_>>()?;
|
||||
let tie_breaker = self.tie_breaker;
|
||||
Ok(Box::new(BooleanWeight::new(
|
||||
disjuncts,
|
||||
scoring_enabled,
|
||||
Box::new(move || DisjunctionMaxCombiner::with_tie_breaker(tie_breaker)),
|
||||
)))
|
||||
}
|
||||
|
||||
fn query_terms(&self, terms: &mut BTreeMap<Term, bool>) {
|
||||
for disjunct in &self.disjuncts {
|
||||
disjunct.query_terms(terms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DisjunctionMaxQuery {
|
||||
/// Creates a new `DisjunctionMaxQuery` with tie breaker.
|
||||
pub fn with_tie_breaker(
|
||||
disjuncts: Vec<Box<dyn Query>>,
|
||||
tie_breaker: Score,
|
||||
) -> DisjunctionMaxQuery {
|
||||
DisjunctionMaxQuery {
|
||||
disjuncts,
|
||||
tie_breaker,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new `DisjunctionMaxQuery` with no tie breaker.
|
||||
pub fn new(disjuncts: Vec<Box<dyn Query>>) -> DisjunctionMaxQuery {
|
||||
DisjunctionMaxQuery::with_tie_breaker(disjuncts, 0.0)
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ mod bitset;
|
||||
mod bm25;
|
||||
mod boolean_query;
|
||||
mod boost_query;
|
||||
mod disjunction_max_query;
|
||||
mod empty_query;
|
||||
mod exclude;
|
||||
mod explanation;
|
||||
@@ -34,7 +35,9 @@ pub use self::automaton_weight::AutomatonWeight;
|
||||
pub use self::bitset::BitSetDocSet;
|
||||
pub(crate) use self::bm25::Bm25Weight;
|
||||
pub use self::boolean_query::BooleanQuery;
|
||||
pub(crate) use self::boolean_query::BooleanWeight;
|
||||
pub use self::boost_query::BoostQuery;
|
||||
pub use self::disjunction_max_query::DisjunctionMaxQuery;
|
||||
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
|
||||
pub use self::exclude::Exclude;
|
||||
pub use self::explanation::Explanation;
|
||||
@@ -49,6 +52,9 @@ pub use self::query_parser::{QueryParser, QueryParserError};
|
||||
pub use self::range_query::RangeQuery;
|
||||
pub use self::regex_query::RegexQuery;
|
||||
pub use self::reqopt_scorer::RequiredOptionalScorer;
|
||||
pub use self::score_combiner::{
|
||||
DisjunctionMaxCombiner, ScoreCombiner, SumCombiner, SumWithCoordsCombiner,
|
||||
};
|
||||
pub use self::scorer::{ConstScorer, Scorer};
|
||||
pub use self::term_query::TermQuery;
|
||||
pub use self::union::Union;
|
||||
|
||||
@@ -77,3 +77,40 @@ impl ScoreCombiner for SumWithCoordsCombiner {
|
||||
self.score
|
||||
}
|
||||
}
|
||||
|
||||
/// Takes the maximum score among the matching scorers
|
||||
/// and adds the sum of the remaining matching scores multiplied by `tie_breaker`.
|
||||
#[derive(Default, Clone, Copy)]
|
||||
pub struct DisjunctionMaxCombiner {
|
||||
max: Score,
|
||||
sum: Score,
|
||||
tie_breaker: Score,
|
||||
}
|
||||
|
||||
impl DisjunctionMaxCombiner {
|
||||
/// Creates `DisjunctionMaxCombiner` with tie breaker
|
||||
pub fn with_tie_breaker(tie_breaker: Score) -> DisjunctionMaxCombiner {
|
||||
DisjunctionMaxCombiner {
|
||||
max: 0.0,
|
||||
sum: 0.0,
|
||||
tie_breaker,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScoreCombiner for DisjunctionMaxCombiner {
|
||||
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
|
||||
let score = scorer.score();
|
||||
self.max = Score::max(score, self.max);
|
||||
self.sum += score;
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.max = 0.0;
|
||||
self.sum = 0.0;
|
||||
}
|
||||
|
||||
fn score(&self) -> Score {
|
||||
self.max + (self.sum - self.max) * self.tie_breaker
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,34 +36,6 @@ pub struct Union<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
score: Score,
|
||||
}
|
||||
|
||||
impl<TScorer, TScoreCombiner> From<Vec<TScorer>> for Union<TScorer, TScoreCombiner>
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
TScorer: Scorer,
|
||||
{
|
||||
fn from(docsets: Vec<TScorer>) -> Union<TScorer, TScoreCombiner> {
|
||||
let non_empty_docsets: Vec<TScorer> = docsets
|
||||
.into_iter()
|
||||
.filter(|docset| docset.doc() != TERMINATED)
|
||||
.collect();
|
||||
let mut union = Union {
|
||||
docsets: non_empty_docsets,
|
||||
bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]),
|
||||
scores: Box::new([TScoreCombiner::default(); HORIZON as usize]),
|
||||
cursor: HORIZON_NUM_TINYBITSETS,
|
||||
offset: 0,
|
||||
doc: 0,
|
||||
score: 0.0,
|
||||
};
|
||||
if union.refill() {
|
||||
union.advance();
|
||||
} else {
|
||||
union.doc = TERMINATED;
|
||||
}
|
||||
union
|
||||
}
|
||||
}
|
||||
|
||||
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
scorers: &mut Vec<TScorer>,
|
||||
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
|
||||
@@ -90,6 +62,31 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Union<TScorer, TScoreCombiner> {
|
||||
pub(crate) fn build(
|
||||
docsets: Vec<TScorer>,
|
||||
score_combiner_fn: impl FnOnce() -> TScoreCombiner,
|
||||
) -> Union<TScorer, TScoreCombiner> {
|
||||
let non_empty_docsets: Vec<TScorer> = docsets
|
||||
.into_iter()
|
||||
.filter(|docset| docset.doc() != TERMINATED)
|
||||
.collect();
|
||||
let mut union = Union {
|
||||
docsets: non_empty_docsets,
|
||||
bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]),
|
||||
scores: Box::new([score_combiner_fn(); HORIZON as usize]),
|
||||
cursor: HORIZON_NUM_TINYBITSETS,
|
||||
offset: 0,
|
||||
doc: 0,
|
||||
score: 0.0,
|
||||
};
|
||||
if union.refill() {
|
||||
union.advance();
|
||||
} else {
|
||||
union.doc = TERMINATED;
|
||||
}
|
||||
union
|
||||
}
|
||||
|
||||
fn refill(&mut self) -> bool {
|
||||
if let Some(min_doc) = self.docsets.iter().map(DocSet::doc).min() {
|
||||
self.offset = min_doc;
|
||||
@@ -265,12 +262,13 @@ mod tests {
|
||||
let union_vals: Vec<u32> = val_set.into_iter().collect();
|
||||
let mut union_expected = VecDocSet::from(union_vals);
|
||||
let make_union = || {
|
||||
Union::from(
|
||||
Union::build(
|
||||
vals.iter()
|
||||
.cloned()
|
||||
.map(VecDocSet::from)
|
||||
.map(|docset| ConstScorer::new(docset, 1.0))
|
||||
.collect::<Vec<ConstScorer<VecDocSet>>>(),
|
||||
DoNothingCombiner::default,
|
||||
)
|
||||
};
|
||||
let mut union: Union<_, DoNothingCombiner> = make_union();
|
||||
@@ -311,13 +309,14 @@ mod tests {
|
||||
btree_set.extend(docs.iter().cloned());
|
||||
}
|
||||
let docset_factory = || {
|
||||
let res: Box<dyn DocSet> = Box::new(Union::<_, DoNothingCombiner>::from(
|
||||
let res: Box<dyn DocSet> = Box::new(Union::build(
|
||||
docs_list
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(VecDocSet::from)
|
||||
.map(|docset| ConstScorer::new(docset, 1.0))
|
||||
.collect::<Vec<_>>(),
|
||||
DoNothingCombiner::default,
|
||||
));
|
||||
res
|
||||
};
|
||||
@@ -345,10 +344,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_union_skip_corner_case3() {
|
||||
let mut docset = Union::<_, DoNothingCombiner>::from(vec![
|
||||
ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])),
|
||||
ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])),
|
||||
]);
|
||||
let mut docset = Union::build(
|
||||
vec![
|
||||
ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])),
|
||||
ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])),
|
||||
],
|
||||
DoNothingCombiner::default,
|
||||
);
|
||||
assert_eq!(docset.doc(), 0u32);
|
||||
assert_eq!(docset.seek(0u32), 0u32);
|
||||
assert_eq!(docset.seek(0u32), 0u32);
|
||||
@@ -404,12 +406,13 @@ mod bench {
|
||||
tests::sample_with_seed(100_000, 0.2, 1),
|
||||
];
|
||||
bench.iter(|| {
|
||||
let mut v = Union::<_, DoNothingCombiner>::from(
|
||||
let mut v = Union::build(
|
||||
union_docset
|
||||
.iter()
|
||||
.map(|doc_ids| VecDocSet::from(doc_ids.clone()))
|
||||
.map(|docset| ConstScorer::new(docset, 1.0))
|
||||
.collect::<Vec<_>>(),
|
||||
DoNothingCombiner::default,
|
||||
);
|
||||
while v.doc() != TERMINATED {
|
||||
v.advance();
|
||||
@@ -424,12 +427,13 @@ mod bench {
|
||||
tests::sample_with_seed(100_000, 0.001, 2),
|
||||
];
|
||||
bench.iter(|| {
|
||||
let mut v = Union::<_, DoNothingCombiner>::from(
|
||||
let mut v = Union::build(
|
||||
union_docset
|
||||
.iter()
|
||||
.map(|doc_ids| VecDocSet::from(doc_ids.clone()))
|
||||
.map(|docset| ConstScorer::new(docset, 1.0))
|
||||
.collect::<Vec<_>>(),
|
||||
DoNothingCombiner::default,
|
||||
);
|
||||
while v.doc() != TERMINATED {
|
||||
v.advance();
|
||||
|
||||
Reference in New Issue
Block a user