mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-04 16:22:55 +00:00
feat(query): Make BooleanQuery support minimum_number_should_match (#2405)
* feat(query): Make `BooleanQuery` support `minimum_number_should_match`. See issue #2398. This commit introduces a new scorer named DisjunctionScorer, which takes the union of inverted lists and keeps only the elements that reach the minimum required number of matches. It is implemented with a min-heap. The necessary modifications to `BooleanQuery` and `BooleanWeight` are made as well. * fixup! fix test * fixup!: refactor code. 1. More meaningful names. 2. Add Cache for `Disjunction`'s scorers, and fix bug. 3. Optimize `BooleanWeight::complex_scorer` Thanks Paul Masurel <paul@quickwit.io> * squash!: come up with better variable naming. * squash!: fix naming issues. * squash!: fix typo. * squash!: Remove CombinationMethod::FullIntersection
This commit is contained in:
@@ -66,6 +66,10 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// Term::from_field_text(title, "diary"),
|
||||
/// IndexRecordOption::Basic,
|
||||
/// ));
|
||||
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(title, "cow"),
|
||||
/// IndexRecordOption::Basic
|
||||
/// ));
|
||||
/// // A TermQuery with "found" in the body
|
||||
/// let body_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(body, "found"),
|
||||
@@ -74,7 +78,7 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// // TermQuery "diary" must and "girl" must not be present
|
||||
/// let queries_with_occurs1 = vec![
|
||||
/// (Occur::Must, diary_term_query.box_clone()),
|
||||
/// (Occur::MustNot, girl_term_query),
|
||||
/// (Occur::MustNot, girl_term_query.box_clone()),
|
||||
/// ];
|
||||
/// // Make a BooleanQuery equivalent to
|
||||
/// // title:+diary title:-girl
|
||||
@@ -82,15 +86,10 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// let count1 = searcher.search(&diary_must_and_girl_mustnot, &Count)?;
|
||||
/// assert_eq!(count1, 1);
|
||||
///
|
||||
/// // TermQuery for "cow" in the title
|
||||
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(title, "cow"),
|
||||
/// IndexRecordOption::Basic,
|
||||
/// ));
|
||||
/// // "title:diary OR title:cow"
|
||||
/// let title_diary_or_cow = BooleanQuery::new(vec![
|
||||
/// (Occur::Should, diary_term_query.box_clone()),
|
||||
/// (Occur::Should, cow_term_query),
|
||||
/// (Occur::Should, cow_term_query.box_clone()),
|
||||
/// ]);
|
||||
/// let count2 = searcher.search(&title_diary_or_cow, &Count)?;
|
||||
/// assert_eq!(count2, 4);
|
||||
@@ -118,21 +117,39 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// ]);
|
||||
/// let count4 = searcher.search(&nested_query, &Count)?;
|
||||
/// assert_eq!(count4, 1);
|
||||
///
|
||||
/// // You may call `with_minimum_required_clauses` to
|
||||
/// // specify the number of should clauses the returned documents must match.
|
||||
/// let minimum_required_query = BooleanQuery::with_minimum_required_clauses(vec![
|
||||
/// (Occur::Should, cow_term_query.box_clone()),
|
||||
/// (Occur::Should, girl_term_query.box_clone()),
|
||||
/// (Occur::Should, diary_term_query.box_clone()),
|
||||
/// ], 2);
|
||||
/// // Returns the documents that contain "Diary Cow", "Diary Girl" or "Cow Girl"
|
||||
/// // Notice: "Diary" isn't "Dairy". ;-)
|
||||
/// let count5 = searcher.search(&minimum_required_query, &Count)?;
|
||||
/// assert_eq!(count5, 1);
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct BooleanQuery {
|
||||
subqueries: Vec<(Occur, Box<dyn Query>)>,
|
||||
minimum_number_should_match: usize,
|
||||
}
|
||||
|
||||
impl Clone for BooleanQuery {
|
||||
fn clone(&self) -> Self {
|
||||
self.subqueries
|
||||
let subqueries = self
|
||||
.subqueries
|
||||
.iter()
|
||||
.map(|(occur, subquery)| (*occur, subquery.box_clone()))
|
||||
.collect::<Vec<_>>()
|
||||
.into()
|
||||
.into();
|
||||
Self {
|
||||
subqueries,
|
||||
minimum_number_should_match: self.minimum_number_should_match,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,8 +166,9 @@ impl Query for BooleanQuery {
|
||||
.iter()
|
||||
.map(|(occur, subquery)| Ok((*occur, subquery.weight(enable_scoring)?)))
|
||||
.collect::<crate::Result<_>>()?;
|
||||
Ok(Box::new(BooleanWeight::new(
|
||||
Ok(Box::new(BooleanWeight::with_minimum_number_should_match(
|
||||
sub_weights,
|
||||
self.minimum_number_should_match,
|
||||
enable_scoring.is_scoring_enabled(),
|
||||
Box::new(SumWithCoordsCombiner::default),
|
||||
)))
|
||||
@@ -166,7 +184,41 @@ impl Query for BooleanQuery {
|
||||
impl BooleanQuery {
|
||||
/// Creates a new boolean query.
|
||||
pub fn new(subqueries: Vec<(Occur, Box<dyn Query>)>) -> BooleanQuery {
|
||||
BooleanQuery { subqueries }
|
||||
// If the bool query includes at least one should clause
|
||||
// and no Must or MustNot clauses, the default value is 1. Otherwise, the default value is
|
||||
// 0. Keep pace with Elasticsearch.
|
||||
let mut minimum_required = 0;
|
||||
for (occur, _) in &subqueries {
|
||||
match occur {
|
||||
Occur::Should => minimum_required = 1,
|
||||
Occur::Must | Occur::MustNot => {
|
||||
minimum_required = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::with_minimum_required_clauses(subqueries, minimum_required)
|
||||
}
|
||||
|
||||
/// Create a new boolean query with minimum number of required should clauses specified.
|
||||
pub fn with_minimum_required_clauses(
|
||||
subqueries: Vec<(Occur, Box<dyn Query>)>,
|
||||
minimum_number_should_match: usize,
|
||||
) -> BooleanQuery {
|
||||
BooleanQuery {
|
||||
subqueries,
|
||||
minimum_number_should_match,
|
||||
}
|
||||
}
|
||||
|
||||
/// Getter for `minimum_number_should_match`
|
||||
pub fn get_minimum_number_should_match(&self) -> usize {
|
||||
self.minimum_number_should_match
|
||||
}
|
||||
|
||||
/// Setter for `minimum_number_should_match`
|
||||
pub fn set_minimum_number_should_match(&mut self, minimum_number_should_match: usize) {
|
||||
self.minimum_number_should_match = minimum_number_should_match;
|
||||
}
|
||||
|
||||
/// Returns the intersection of the queries.
|
||||
@@ -181,6 +233,18 @@ impl BooleanQuery {
|
||||
BooleanQuery::new(subqueries)
|
||||
}
|
||||
|
||||
/// Returns the union of the queries with minimum required clause.
|
||||
pub fn union_with_minimum_required_clauses(
|
||||
queries: Vec<Box<dyn Query>>,
|
||||
minimum_required_clauses: usize,
|
||||
) -> BooleanQuery {
|
||||
let subqueries = queries
|
||||
.into_iter()
|
||||
.map(|sub_query| (Occur::Should, sub_query))
|
||||
.collect();
|
||||
BooleanQuery::with_minimum_required_clauses(subqueries, minimum_required_clauses)
|
||||
}
|
||||
|
||||
/// Helper method to create a boolean query matching a given list of terms.
|
||||
/// The resulting query is a disjunction of the terms.
|
||||
pub fn new_multiterms_query(terms: Vec<Term>) -> BooleanQuery {
|
||||
@@ -203,11 +267,13 @@ impl BooleanQuery {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use super::BooleanQuery;
|
||||
use crate::collector::{Count, DocSetCollector};
|
||||
use crate::query::{QueryClone, QueryParser, TermQuery};
|
||||
use crate::schema::{IndexRecordOption, Schema, TEXT};
|
||||
use crate::{DocAddress, Index, Term};
|
||||
use crate::query::{Query, QueryClone, QueryParser, TermQuery};
|
||||
use crate::schema::{Field, IndexRecordOption, Schema, TEXT};
|
||||
use crate::{DocAddress, DocId, Index, Term};
|
||||
|
||||
fn create_test_index() -> crate::Result<Index> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
@@ -223,6 +289,73 @@ mod tests {
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
#[test]
fn test_minimum_required() -> crate::Result<()> {
    // Builds an in-RAM index with a single `text` field and one document per input string.
    fn build_index<T: IntoIterator<Item = &'static str>>(docs: T) -> crate::Result<Index> {
        let mut schema_builder = Schema::builder();
        let text = schema_builder.add_text_field("text", TEXT);
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);
        let mut writer = index.writer_for_tests()?;
        for doc in docs {
            writer.add_document(doc!(text => doc))?;
        }
        writer.commit()?;
        Ok(index)
    }

    // Builds a union of one `TermQuery` per word, with `mr` as the minimum number of
    // required clauses.
    fn should_match_query<T: IntoIterator<Item = &'static str>>(
        words: T,
        field: Field,
        mr: usize,
    ) -> BooleanQuery {
        let term_queries: Vec<Box<dyn Query>> = words
            .into_iter()
            .map(|word| {
                let term = Term::from_field_text(field, word);
                Box::new(TermQuery::new(term, IndexRecordOption::Basic)) as Box<dyn Query>
            })
            .collect();
        BooleanQuery::union_with_minimum_required_clauses(term_queries, mr)
    }

    // Asserts that `actual` is exactly the set of `expected` doc ids inside segment `seg`.
    fn assert_docs<T: IntoIterator<Item = DocId>>(
        expected: T,
        actual: HashSet<DocAddress>,
        seg: u32,
    ) {
        let expected: HashSet<DocAddress> = expected
            .into_iter()
            .map(|id| DocAddress::new(seg, id))
            .collect();
        assert_eq!(actual, expected);
    }

    let index = build_index(["a b c", "a c e", "d f g", "z z z", "c i b"])?;
    let searcher = index.reader()?.searcher();
    let text = index.schema().get_field("text").unwrap();

    // Documents containing at least two of 'a', 'c', 'z', 'i' shall be returned.
    let two_of_four = should_match_query(["a", "c", "z", "i"], text, 2);
    let hits = searcher.search(&two_of_four, &DocSetCollector)?;
    assert_docs([0, 1, 4], hits, 0);

    // Documents containing at least three of 'a', 'b', 'c', 'e' shall be returned.
    let three_of_four = should_match_query(["a", "b", "c", "e"], text, 3);
    let hits = searcher.search(&three_of_four, &DocSetCollector)?;
    assert_docs([0, 1], hits, 0);

    // Nothing is returned when the minimum exceeds the number of clauses.
    let too_many_required = should_match_query(["a", "b"], text, 3);
    let hits = searcher.search(&too_many_required, &DocSetCollector)?;
    assert!(hits.is_empty());

    // With a minimum of one or zero, the results are those of a plain union.
    let one_required = should_match_query(["a", "z"], text, 1);
    let hits = searcher.search(&one_required, &DocSetCollector)?;
    assert_docs([0, 1, 3], hits, 0);

    let none_required = should_match_query(["a", "b"], text, 0);
    let hits = searcher.search(&none_required, &DocSetCollector)?;
    assert_docs([0, 1, 4], hits, 0);

    Ok(())
}
|
||||
|
||||
#[test]
|
||||
fn test_union() -> crate::Result<()> {
|
||||
let index = create_test_index()?;
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::collections::HashMap;
|
||||
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
|
||||
use crate::index::SegmentReader;
|
||||
use crate::postings::FreqReadingOption;
|
||||
use crate::query::disjunction::Disjunction;
|
||||
use crate::query::explanation::does_not_match;
|
||||
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
|
||||
use crate::query::term_query::TermScorer;
|
||||
@@ -18,6 +19,26 @@ enum SpecializedScorer {
|
||||
Other(Box<dyn Scorer>),
|
||||
}
|
||||
|
||||
fn scorer_disjunction<TScoreCombiner>(
|
||||
scorers: Vec<Box<dyn Scorer>>,
|
||||
score_combiner: TScoreCombiner,
|
||||
minimum_match_required: usize,
|
||||
) -> Box<dyn Scorer>
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
{
|
||||
debug_assert!(!scorers.is_empty());
|
||||
debug_assert!(minimum_match_required > 1);
|
||||
if scorers.len() == 1 {
|
||||
return scorers.into_iter().next().unwrap(); // Safe unwrap.
|
||||
}
|
||||
Box::new(Disjunction::new(
|
||||
scorers,
|
||||
score_combiner,
|
||||
minimum_match_required,
|
||||
))
|
||||
}
|
||||
|
||||
fn scorer_union<TScoreCombiner>(
|
||||
scorers: Vec<Box<dyn Scorer>>,
|
||||
score_combiner_fn: impl Fn() -> TScoreCombiner,
|
||||
@@ -70,6 +91,7 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
/// Weight associated to the `BoolQuery`.
|
||||
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
|
||||
weights: Vec<(Occur, Box<dyn Weight>)>,
|
||||
minimum_number_should_match: usize,
|
||||
scoring_enabled: bool,
|
||||
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
|
||||
}
|
||||
@@ -85,6 +107,22 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
weights,
|
||||
scoring_enabled,
|
||||
score_combiner_fn,
|
||||
minimum_number_should_match: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new boolean weight with minimum number of required should clauses specified.
|
||||
pub fn with_minimum_number_should_match(
|
||||
weights: Vec<(Occur, Box<dyn Weight>)>,
|
||||
minimum_number_should_match: usize,
|
||||
scoring_enabled: bool,
|
||||
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
|
||||
) -> BooleanWeight<TScoreCombiner> {
|
||||
BooleanWeight {
|
||||
weights,
|
||||
minimum_number_should_match,
|
||||
scoring_enabled,
|
||||
score_combiner_fn,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,43 +149,89 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
|
||||
) -> crate::Result<SpecializedScorer> {
|
||||
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
|
||||
|
||||
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
|
||||
.remove(&Occur::Should)
|
||||
.map(|scorers| scorer_union(scorers, &score_combiner_fn));
|
||||
// Indicate how should clauses are combined with other clauses.
|
||||
enum CombinationMethod {
|
||||
Ignored,
|
||||
// Only contributes to final score.
|
||||
Optional(SpecializedScorer),
|
||||
// Must be satisfied for a document to match.
|
||||
Required(Box<dyn Scorer>),
|
||||
}
|
||||
let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
|
||||
let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
|
||||
{
|
||||
let num_of_should_scorers = should_scorers.len();
|
||||
if self.minimum_number_should_match > num_of_should_scorers {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
}
|
||||
match self.minimum_number_should_match {
|
||||
0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
|
||||
1 => CombinationMethod::Required(into_box_scorer(
|
||||
scorer_union(should_scorers, &score_combiner_fn),
|
||||
&score_combiner_fn,
|
||||
)),
|
||||
n @ _ if num_of_should_scorers == n => {
|
||||
// When minimum_number_should_match equals the number of should clauses,
|
||||
// they are no different from must clauses.
|
||||
must_scorers = match must_scorers.take() {
|
||||
Some(mut must_scorers) => {
|
||||
must_scorers.append(&mut should_scorers);
|
||||
Some(must_scorers)
|
||||
}
|
||||
None => Some(should_scorers),
|
||||
};
|
||||
CombinationMethod::Ignored
|
||||
}
|
||||
_ => CombinationMethod::Required(scorer_disjunction(
|
||||
should_scorers,
|
||||
score_combiner_fn(),
|
||||
self.minimum_number_should_match,
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
// None of should clauses are provided.
|
||||
if self.minimum_number_should_match > 0 {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
} else {
|
||||
CombinationMethod::Ignored
|
||||
}
|
||||
};
|
||||
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::MustNot)
|
||||
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
|
||||
.map(|specialized_scorer| {
|
||||
.map(|specialized_scorer: SpecializedScorer| {
|
||||
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
|
||||
});
|
||||
|
||||
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::Must)
|
||||
.map(intersect_scorers);
|
||||
|
||||
let positive_scorer: SpecializedScorer = match (should_scorer_opt, must_scorer_opt) {
|
||||
(Some(should_scorer), Some(must_scorer)) => {
|
||||
let positive_scorer = match (should_opt, must_scorers) {
|
||||
(CombinationMethod::Ignored, Some(must_scorers)) => {
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers))
|
||||
}
|
||||
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
|
||||
let must_scorer = intersect_scorers(must_scorers);
|
||||
if self.scoring_enabled {
|
||||
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
|
||||
Box<dyn Scorer>,
|
||||
Box<dyn Scorer>,
|
||||
TComplexScoreCombiner,
|
||||
>::new(
|
||||
must_scorer,
|
||||
into_box_scorer(should_scorer, &score_combiner_fn),
|
||||
)))
|
||||
SpecializedScorer::Other(Box::new(
|
||||
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
|
||||
must_scorer,
|
||||
into_box_scorer(should_scorer, &score_combiner_fn),
|
||||
),
|
||||
))
|
||||
} else {
|
||||
SpecializedScorer::Other(must_scorer)
|
||||
}
|
||||
}
|
||||
(None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer),
|
||||
(Some(should_scorer), None) => should_scorer,
|
||||
(None, None) => {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
|
||||
must_scorers.push(should_scorer);
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers))
|
||||
}
|
||||
(CombinationMethod::Ignored, None) => {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
|
||||
}
|
||||
(CombinationMethod::Required(should_scorer), None) => {
|
||||
SpecializedScorer::Other(should_scorer)
|
||||
}
|
||||
// Optional should clauses are promoted to required when no must scorers exist.
|
||||
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
|
||||
};
|
||||
|
||||
if let Some(exclude_scorer) = exclude_scorer_opt {
|
||||
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
|
||||
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
|
||||
|
||||
327
src/query/disjunction.rs
Normal file
327
src/query/disjunction.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
|
||||
use crate::query::score_combiner::DoNothingCombiner;
|
||||
use crate::query::{ScoreCombiner, Scorer};
|
||||
use crate::{DocId, DocSet, Score, TERMINATED};
|
||||
|
||||
/// `Disjunction` is responsible for merging `DocSet` from multiple
|
||||
/// source. Specifically, It takes the union of two or more `DocSet`s
|
||||
/// then filtering out elements that appear fewer times than a
|
||||
/// specified threshold.
|
||||
pub struct Disjunction<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
chains: BinaryHeap<ScorerWrapper<TScorer>>,
|
||||
minimum_matches_required: usize,
|
||||
score_combiner: TScoreCombiner,
|
||||
|
||||
current_doc: DocId,
|
||||
current_score: Score,
|
||||
}
|
||||
|
||||
/// A wrapper around a `Scorer` that caches the current `doc_id` and implements the `DocSet` trait.
|
||||
/// Also, the `Ord` trait and it's family are implemented reversely. So that we can combine
|
||||
/// `std::BinaryHeap<ScorerWrapper<T>>` to gain a min-heap with current doc id as key.
|
||||
struct ScorerWrapper<T> {
|
||||
scorer: T,
|
||||
current_doc: DocId,
|
||||
}
|
||||
|
||||
impl<T: Scorer> ScorerWrapper<T> {
    /// Wraps `scorer`, caching the doc id it is currently positioned on.
    fn new(scorer: T) -> Self {
        Self {
            current_doc: scorer.doc(),
            scorer,
        }
    }
}
|
||||
|
||||
impl<T: Scorer> PartialEq for ScorerWrapper<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.doc() == other.doc()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scorer> Eq for ScorerWrapper<T> {}
|
||||
|
||||
impl<T: Scorer> PartialOrd for ScorerWrapper<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scorer> Ord for ScorerWrapper<T> {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.doc().cmp(&other.doc()).reverse()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scorer> DocSet for ScorerWrapper<T> {
|
||||
fn advance(&mut self) -> DocId {
|
||||
let doc_id = self.scorer.advance();
|
||||
self.current_doc = doc_id;
|
||||
doc_id
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.current_doc
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.scorer.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
|
||||
pub fn new<T: IntoIterator<Item = TScorer>>(
|
||||
docsets: T,
|
||||
score_combiner: TScoreCombiner,
|
||||
minimum_matches_required: usize,
|
||||
) -> Self {
|
||||
debug_assert!(
|
||||
minimum_matches_required > 1,
|
||||
"union scorer works better if just one matches required"
|
||||
);
|
||||
let chains = docsets
|
||||
.into_iter()
|
||||
.map(|doc| ScorerWrapper::new(doc))
|
||||
.collect();
|
||||
let mut disjunction = Self {
|
||||
chains,
|
||||
score_combiner,
|
||||
current_doc: TERMINATED,
|
||||
minimum_matches_required,
|
||||
current_score: 0.0,
|
||||
};
|
||||
if minimum_matches_required > disjunction.chains.len() {
|
||||
return disjunction;
|
||||
}
|
||||
disjunction.advance();
|
||||
disjunction
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
|
||||
for Disjunction<TScorer, TScoreCombiner>
|
||||
{
|
||||
fn advance(&mut self) -> DocId {
|
||||
let mut current_num_matches = 0;
|
||||
while let Some(mut candidate) = self.chains.pop() {
|
||||
let next = candidate.doc();
|
||||
if next != TERMINATED {
|
||||
// Peek next doc.
|
||||
if self.current_doc != next {
|
||||
if current_num_matches >= self.minimum_matches_required {
|
||||
self.chains.push(candidate);
|
||||
self.current_score = self.score_combiner.score();
|
||||
return self.current_doc;
|
||||
}
|
||||
// Reset current_num_matches and scores.
|
||||
current_num_matches = 0;
|
||||
self.current_doc = next;
|
||||
self.score_combiner.clear();
|
||||
}
|
||||
current_num_matches += 1;
|
||||
self.score_combiner.update(&mut candidate.scorer);
|
||||
candidate.advance();
|
||||
self.chains.push(candidate);
|
||||
}
|
||||
}
|
||||
if current_num_matches < self.minimum_matches_required {
|
||||
self.current_doc = TERMINATED;
|
||||
}
|
||||
self.current_score = self.score_combiner.score();
|
||||
return self.current_doc;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn doc(&self) -> DocId {
|
||||
self.current_doc
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.chains
|
||||
.iter()
|
||||
.map(|docset| docset.size_hint())
|
||||
.max()
|
||||
.unwrap_or(0u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
|
||||
for Disjunction<TScorer, TScoreCombiner>
|
||||
{
|
||||
fn score(&mut self) -> Score {
|
||||
self.current_score
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use super::Disjunction;
|
||||
use crate::query::score_combiner::DoNothingCombiner;
|
||||
use crate::query::{ConstScorer, Scorer, SumCombiner, VecDocSet};
|
||||
use crate::{DocId, DocSet, Score, TERMINATED};
|
||||
|
||||
/// Reference implementation used to check `Disjunction`.
///
/// Counts how many times each element occurs across all of `arrays` and returns, in
/// ascending order, the elements that occur at least `pass_line` times.
fn conjunct<T: Ord + Copy>(arrays: &[Vec<T>], pass_line: usize) -> Vec<T> {
    // A `BTreeMap` keeps the elements sorted, so the output is ordered for free.
    let mut counts: BTreeMap<T, usize> = BTreeMap::new();
    for &element in arrays.iter().flatten() {
        *counts.entry(element).or_default() += 1;
    }
    counts
        .into_iter()
        .filter(|&(_, count)| count >= pass_line)
        .map(|(element, _)| element)
        .collect()
}
|
||||
|
||||
fn aux_test_conjunction(vals: Vec<Vec<u32>>, min_match: usize) {
|
||||
let mut union_expected = VecDocSet::from(conjunct(&vals, min_match));
|
||||
let make_scorer = || {
|
||||
Disjunction::new(
|
||||
vals.iter()
|
||||
.cloned()
|
||||
.map(VecDocSet::from)
|
||||
.map(|d| ConstScorer::new(d, 1.0)),
|
||||
DoNothingCombiner::default(),
|
||||
min_match,
|
||||
)
|
||||
};
|
||||
let mut scorer: Disjunction<_, DoNothingCombiner> = make_scorer();
|
||||
let mut count = 0;
|
||||
while scorer.doc() != TERMINATED {
|
||||
assert_eq!(union_expected.doc(), scorer.doc());
|
||||
assert_eq!(union_expected.advance(), scorer.advance());
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(union_expected.advance(), TERMINATED);
|
||||
assert_eq!(count, make_scorer().count_including_deleted());
|
||||
}
|
||||
|
||||
// A minimum of zero is rejected by the `debug_assert!` in `Disjunction::new`.
// That assertion is the only panic source visible for this input, and it is compiled
// out when debug assertions are disabled, so the test is only built when they are on.
#[cfg(debug_assertions)]
#[should_panic]
#[test]
fn test_arg_check1() {
    aux_test_conjunction(vec![], 0);
}
|
||||
|
||||
// A minimum of one is rejected by the `debug_assert!` in `Disjunction::new`.
// That assertion is the only panic source visible for this input, and it is compiled
// out when debug assertions are disabled, so the test is only built when they are on.
#[cfg(debug_assertions)]
#[should_panic]
#[test]
fn test_arg_check2() {
    aux_test_conjunction(vec![], 1);
}
|
||||
|
||||
#[test]
fn test_corner_case() {
    // No chains at all.
    aux_test_conjunction(Vec::new(), 2);
    // Many chains, all of them empty.
    aux_test_conjunction(vec![Vec::new(); 1000], 2);
    aux_test_conjunction(vec![Vec::new(); 100], usize::MAX);
    // A threshold that can never be reached.
    aux_test_conjunction(vec![vec![0xC0FFEE]; 10000], usize::MAX);
    // Many chains that never share a document.
    let singletons: Vec<Vec<u32>> = (1..10000u32).map(|doc| vec![doc]).collect();
    aux_test_conjunction(singletons, 2);
}
|
||||
|
||||
#[test]
fn test_conjunction() {
    // Three chains sharing documents 1 and 100000000; two of them also share document 2.
    let overlapping = vec![
        vec![1, 3333, 100000000u32],
        vec![1, 2, 100000000u32],
        vec![1, 2, 100000000u32],
    ];
    aux_test_conjunction(overlapping.clone(), 2);
    // Chains without any document in common.
    aux_test_conjunction(
        vec![vec![8], vec![3, 4, 0xC0FFEEu32], vec![1, 2, 100000000u32]],
        2,
    );
    aux_test_conjunction(overlapping, 3);
}
|
||||
|
||||
// This dummy scorer does nothing but yield doc id increasingly.
|
||||
// with constant score 1.0
|
||||
#[derive(Clone)]
|
||||
struct DummyScorer {
|
||||
cursor: usize,
|
||||
foo: Vec<(DocId, f32)>,
|
||||
}
|
||||
|
||||
impl DummyScorer {
|
||||
fn new(doc_score: Vec<(DocId, f32)>) -> Self {
|
||||
Self {
|
||||
cursor: 0,
|
||||
foo: doc_score,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for DummyScorer {
|
||||
fn advance(&mut self) -> DocId {
|
||||
self.cursor += 1;
|
||||
self.doc()
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.foo.get(self.cursor).map(|x| x.0).unwrap_or(TERMINATED)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.foo.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
impl Scorer for DummyScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
fn test_score_calculate() {
    // Every chain holds document 1 plus one more document, all with score 1.0.
    let chain = |second_doc: DocId| DummyScorer::new(vec![(1, 1f32), (second_doc, 1f32)]);
    let chains = vec![chain(2), chain(3), chain(4), chain(2), chain(2)];
    let mut scorer = Disjunction::new(chains, SumCombiner::default(), 3);
    // Document 1 is present in all five chains.
    assert_eq!(scorer.score(), 5.0);
    // Document 2 is present in three chains, which just reaches the threshold.
    assert_eq!(scorer.advance(), 2);
    assert_eq!(scorer.score(), 3.0);
}
|
||||
|
||||
#[test]
fn test_score_calculate_corner_case() {
    // Every chain holds document 1 plus one more document, all with score 1.0.
    let chain = |second_doc: DocId| DummyScorer::new(vec![(1, 1f32), (second_doc, 1f32)]);
    let chains = vec![chain(2), chain(3), chain(3)];
    let mut scorer = Disjunction::new(chains, SumCombiner::default(), 2);
    assert_eq!(scorer.doc(), 1);
    assert_eq!(scorer.score(), 3.0);
    // Document 2 sits in a single chain and is skipped; document 3 sits in two.
    assert_eq!(scorer.advance(), 3);
    assert_eq!(scorer.score(), 2.0);
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ mod bm25;
|
||||
mod boolean_query;
|
||||
mod boost_query;
|
||||
mod const_score_query;
|
||||
mod disjunction;
|
||||
mod disjunction_max_query;
|
||||
mod empty_query;
|
||||
mod exclude;
|
||||
|
||||
@@ -1815,7 +1815,8 @@ mod test {
|
||||
\"bad\"))], prefix: (2, Term(field=0, type=Str, \"wo\")), max_expansions: 50 }), \
|
||||
(Should, PhrasePrefixQuery { field: Field(1), phrase_terms: [(0, Term(field=1, \
|
||||
type=Str, \"big\")), (1, Term(field=1, type=Str, \"bad\"))], prefix: (2, \
|
||||
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })] }"
|
||||
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })], \
|
||||
minimum_number_should_match: 1 }"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1880,7 +1881,8 @@ mod test {
|
||||
format!("{query:?}"),
|
||||
"BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, \
|
||||
type=Str, \"abc\"), distance: 1, transposition_cost_one: true, prefix: false }), \
|
||||
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))] }"
|
||||
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))], \
|
||||
minimum_number_should_match: 1 }"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1897,7 +1899,8 @@ mod test {
|
||||
format!("{query:?}"),
|
||||
"BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, \
|
||||
\"abc\"))), (Should, FuzzyTermQuery { term: Term(field=1, type=Str, \"abc\"), \
|
||||
distance: 2, transposition_cost_one: false, prefix: true })] }"
|
||||
distance: 2, transposition_cost_one: false, prefix: true })], \
|
||||
minimum_number_should_match: 1 }"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user