feat(query): Make BooleanQuery supports minimum_number_should_match (#2405)

* feat(query): Make `BooleanQuery` supports `minimum_number_should_match`. see issue #2398

In this commit, a novel scorer named DisjunctionScorer is introduced, which performs the union of inverted chains with the minimal required elements. BTW, it's implemented via a min-heap. Necessary modifications on `BooleanQuery` and `BooleanWeight` are performed as well.

* fixup! fix test

* fixup!: refactor code.

1. More meaningful names.
2. Add Cache for `Disjunction`'s scorers, and fix bug.
3. Optimize `BooleanWeight::complex_scorer`

Thanks
 Paul Masurel <paul@quickwit.io>

* squash!: come up with better variable naming.

* squash!: fix naming issues.

* squash!: fix typo.

* squash!: Remove CombinationMethod::FullIntersection
This commit is contained in:
落叶乌龟
2024-07-01 15:39:41 +08:00
committed by GitHub
parent d9db5302d9
commit f9ae295507
5 changed files with 590 additions and 42 deletions

View File

@@ -66,6 +66,10 @@ use crate::schema::{IndexRecordOption, Term};
/// Term::from_field_text(title, "diary"),
/// IndexRecordOption::Basic,
/// ));
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "cow"),
/// IndexRecordOption::Basic
/// ));
/// // A TermQuery with "found" in the body
/// let body_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(body, "found"),
@@ -74,7 +78,7 @@ use crate::schema::{IndexRecordOption, Term};
/// // TermQuery "diary" must and "girl" must not be present
/// let queries_with_occurs1 = vec![
/// (Occur::Must, diary_term_query.box_clone()),
/// (Occur::MustNot, girl_term_query),
/// (Occur::MustNot, girl_term_query.box_clone()),
/// ];
/// // Make a BooleanQuery equivalent to
/// // title:+diary title:-girl
@@ -82,15 +86,10 @@ use crate::schema::{IndexRecordOption, Term};
/// let count1 = searcher.search(&diary_must_and_girl_mustnot, &Count)?;
/// assert_eq!(count1, 1);
///
/// // TermQuery for "cow" in the title
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "cow"),
/// IndexRecordOption::Basic,
/// ));
/// // "title:diary OR title:cow"
/// let title_diary_or_cow = BooleanQuery::new(vec![
/// (Occur::Should, diary_term_query.box_clone()),
/// (Occur::Should, cow_term_query),
/// (Occur::Should, cow_term_query.box_clone()),
/// ]);
/// let count2 = searcher.search(&title_diary_or_cow, &Count)?;
/// assert_eq!(count2, 4);
@@ -118,21 +117,39 @@ use crate::schema::{IndexRecordOption, Term};
/// ]);
/// let count4 = searcher.search(&nested_query, &Count)?;
/// assert_eq!(count4, 1);
///
/// // You may call `with_minimum_required_clauses` to
/// // specify the number of should clauses the returned documents must match.
/// let minimum_required_query = BooleanQuery::with_minimum_required_clauses(vec![
/// (Occur::Should, cow_term_query.box_clone()),
/// (Occur::Should, girl_term_query.box_clone()),
/// (Occur::Should, diary_term_query.box_clone()),
/// ], 2);
/// // Return documents contains "Diary Cow", "Diary Girl" or "Cow Girl"
/// // Notice: "Diary" isn't "Dairy". ;-)
/// let count5 = searcher.search(&minimum_required_query, &Count)?;
/// assert_eq!(count5, 1);
/// Ok(())
/// }
/// ```
#[derive(Debug)]
pub struct BooleanQuery {
subqueries: Vec<(Occur, Box<dyn Query>)>,
minimum_number_should_match: usize,
}
impl Clone for BooleanQuery {
fn clone(&self) -> Self {
self.subqueries
let subqueries = self
.subqueries
.iter()
.map(|(occur, subquery)| (*occur, subquery.box_clone()))
.collect::<Vec<_>>()
.into()
.into();
Self {
subqueries,
minimum_number_should_match: self.minimum_number_should_match,
}
}
}
@@ -149,8 +166,9 @@ impl Query for BooleanQuery {
.iter()
.map(|(occur, subquery)| Ok((*occur, subquery.weight(enable_scoring)?)))
.collect::<crate::Result<_>>()?;
Ok(Box::new(BooleanWeight::new(
Ok(Box::new(BooleanWeight::with_minimum_number_should_match(
sub_weights,
self.minimum_number_should_match,
enable_scoring.is_scoring_enabled(),
Box::new(SumWithCoordsCombiner::default),
)))
@@ -166,7 +184,41 @@ impl Query for BooleanQuery {
impl BooleanQuery {
/// Creates a new boolean query.
pub fn new(subqueries: Vec<(Occur, Box<dyn Query>)>) -> BooleanQuery {
BooleanQuery { subqueries }
// If the bool query includes at least one should clause
// and no Must or MustNot clauses, the default value is 1. Otherwise, the default value is
// 0. Keep pace with Elasticsearch.
let mut minimum_required = 0;
for (occur, _) in &subqueries {
match occur {
Occur::Should => minimum_required = 1,
Occur::Must | Occur::MustNot => {
minimum_required = 0;
break;
}
}
}
Self::with_minimum_required_clauses(subqueries, minimum_required)
}
/// Create a new boolean query with minimum number of required should clauses specified.
pub fn with_minimum_required_clauses(
subqueries: Vec<(Occur, Box<dyn Query>)>,
minimum_number_should_match: usize,
) -> BooleanQuery {
BooleanQuery {
subqueries,
minimum_number_should_match,
}
}
/// Getter for `minimum_number_should_match`
pub fn get_minimum_number_should_match(&self) -> usize {
self.minimum_number_should_match
}
/// Setter for `minimum_number_should_match`
pub fn set_minimum_number_should_match(&mut self, minimum_number_should_match: usize) {
self.minimum_number_should_match = minimum_number_should_match;
}
/// Returns the intersection of the queries.
@@ -181,6 +233,18 @@ impl BooleanQuery {
BooleanQuery::new(subqueries)
}
/// Returns the union of the queries with minimum required clause.
pub fn union_with_minimum_required_clauses(
queries: Vec<Box<dyn Query>>,
minimum_required_clauses: usize,
) -> BooleanQuery {
let subqueries = queries
.into_iter()
.map(|sub_query| (Occur::Should, sub_query))
.collect();
BooleanQuery::with_minimum_required_clauses(subqueries, minimum_required_clauses)
}
/// Helper method to create a boolean query matching a given list of terms.
/// The resulting query is a disjunction of the terms.
pub fn new_multiterms_query(terms: Vec<Term>) -> BooleanQuery {
@@ -203,11 +267,13 @@ impl BooleanQuery {
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use super::BooleanQuery;
use crate::collector::{Count, DocSetCollector};
use crate::query::{QueryClone, QueryParser, TermQuery};
use crate::schema::{IndexRecordOption, Schema, TEXT};
use crate::{DocAddress, Index, Term};
use crate::query::{Query, QueryClone, QueryParser, TermQuery};
use crate::schema::{Field, IndexRecordOption, Schema, TEXT};
use crate::{DocAddress, DocId, Index, Term};
fn create_test_index() -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
@@ -223,6 +289,73 @@ mod tests {
Ok(index)
}
#[test]
fn test_minimum_required() -> crate::Result<()> {
fn create_test_index_with<T: IntoIterator<Item = &'static str>>(
docs: T,
) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let text = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests()?;
for doc in docs {
writer.add_document(doc!(text => doc))?;
}
writer.commit()?;
Ok(index)
}
fn create_boolean_query_with_mr<T: IntoIterator<Item = &'static str>>(
queries: T,
field: Field,
mr: usize,
) -> BooleanQuery {
let terms = queries
.into_iter()
.map(|t| Term::from_field_text(field, t))
.map(|t| TermQuery::new(t, IndexRecordOption::Basic))
.map(|q| -> Box<dyn Query> { Box::new(q) })
.collect();
BooleanQuery::union_with_minimum_required_clauses(terms, mr)
}
fn check_doc_id<T: IntoIterator<Item = DocId>>(
expected: T,
actually: HashSet<DocAddress>,
seg: u32,
) {
assert_eq!(
actually,
expected
.into_iter()
.map(|id| DocAddress::new(seg, id))
.collect()
);
}
let index = create_test_index_with(["a b c", "a c e", "d f g", "z z z", "c i b"])?;
let searcher = index.reader()?.searcher();
let text = index.schema().get_field("text").unwrap();
// Documents contains 'a c' 'a z' 'a i' 'c z' 'c i' or 'z i' shall be return.
let q1 = create_boolean_query_with_mr(["a", "c", "z", "i"], text, 2);
let docs = searcher.search(&q1, &DocSetCollector)?;
check_doc_id([0, 1, 4], docs, 0);
// Documents contains 'a b c', 'a b e', 'a c e' or 'b c e' shall be return.
let q2 = create_boolean_query_with_mr(["a", "b", "c", "e"], text, 3);
let docs = searcher.search(&q2, &DocSetCollector)?;
check_doc_id([0, 1], docs, 0);
// Nothing queried since minimum_required is too large.
let q3 = create_boolean_query_with_mr(["a", "b"], text, 3);
let docs = searcher.search(&q3, &DocSetCollector)?;
assert!(docs.is_empty());
// When mr is set to zero or one, there are no difference with `Boolean::Union`.
let q4 = create_boolean_query_with_mr(["a", "z"], text, 1);
let docs = searcher.search(&q4, &DocSetCollector)?;
check_doc_id([0, 1, 3], docs, 0);
let q5 = create_boolean_query_with_mr(["a", "b"], text, 0);
let docs = searcher.search(&q5, &DocSetCollector)?;
check_doc_id([0, 1, 4], docs, 0);
Ok(())
}
#[test]
fn test_union() -> crate::Result<()> {
let index = create_test_index()?;

View File

@@ -3,6 +3,7 @@ use std::collections::HashMap;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::index::SegmentReader;
use crate::postings::FreqReadingOption;
use crate::query::disjunction::Disjunction;
use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
@@ -18,6 +19,26 @@ enum SpecializedScorer {
Other(Box<dyn Scorer>),
}
fn scorer_disjunction<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner: TScoreCombiner,
minimum_match_required: usize,
) -> Box<dyn Scorer>
where
TScoreCombiner: ScoreCombiner,
{
debug_assert!(!scorers.is_empty());
debug_assert!(minimum_match_required > 1);
if scorers.len() == 1 {
return scorers.into_iter().next().unwrap(); // Safe unwrap.
}
Box::new(Disjunction::new(
scorers,
score_combiner,
minimum_match_required,
))
}
fn scorer_union<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner_fn: impl Fn() -> TScoreCombiner,
@@ -70,6 +91,7 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
/// Weight associated to the `BoolQuery`.
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
weights: Vec<(Occur, Box<dyn Weight>)>,
minimum_number_should_match: usize,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
}
@@ -85,6 +107,22 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
weights,
scoring_enabled,
score_combiner_fn,
minimum_number_should_match: 1,
}
}
/// Create a new boolean weight with minimum number of required should clauses specified.
pub fn with_minimum_number_should_match(
weights: Vec<(Occur, Box<dyn Weight>)>,
minimum_number_should_match: usize,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
) -> BooleanWeight<TScoreCombiner> {
BooleanWeight {
weights,
minimum_number_should_match,
scoring_enabled,
score_combiner_fn,
}
}
@@ -111,43 +149,89 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<SpecializedScorer> {
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
.remove(&Occur::Should)
.map(|scorers| scorer_union(scorers, &score_combiner_fn));
// Indicate how should clauses are combined with other clauses.
enum CombinationMethod {
Ignored,
// Only contributes to final score.
Optional(SpecializedScorer),
// Must be fitted.
Required(Box<dyn Scorer>),
}
let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
{
let num_of_should_scorers = should_scorers.len();
if self.minimum_number_should_match > num_of_should_scorers {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match self.minimum_number_should_match {
0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
1 => CombinationMethod::Required(into_box_scorer(
scorer_union(should_scorers, &score_combiner_fn),
&score_combiner_fn,
)),
n @ _ if num_of_should_scorers == n => {
// When num_of_should_scorers equals the number of should clauses,
// they are no different from must clauses.
must_scorers = match must_scorers.take() {
Some(mut must_scorers) => {
must_scorers.append(&mut should_scorers);
Some(must_scorers)
}
None => Some(should_scorers),
};
CombinationMethod::Ignored
}
_ => CombinationMethod::Required(scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
)),
}
} else {
// None of should clauses are provided.
if self.minimum_number_should_match > 0 {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
} else {
CombinationMethod::Ignored
}
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
.map(|specialized_scorer| {
.map(|specialized_scorer: SpecializedScorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
});
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::Must)
.map(intersect_scorers);
let positive_scorer: SpecializedScorer = match (should_scorer_opt, must_scorer_opt) {
(Some(should_scorer), Some(must_scorer)) => {
let positive_scorer = match (should_opt, must_scorers) {
(CombinationMethod::Ignored, Some(must_scorers)) => {
SpecializedScorer::Other(intersect_scorers(must_scorers))
}
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
let must_scorer = intersect_scorers(must_scorers);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
Box<dyn Scorer>,
Box<dyn Scorer>,
TComplexScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn),
)))
SpecializedScorer::Other(Box::new(
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn),
),
))
} else {
SpecializedScorer::Other(must_scorer)
}
}
(None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer),
(Some(should_scorer), None) => should_scorer,
(None, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
must_scorers.push(should_scorer);
SpecializedScorer::Other(intersect_scorers(must_scorers))
}
(CombinationMethod::Ignored, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
}
(CombinationMethod::Required(should_scorer), None) => {
SpecializedScorer::Other(should_scorer)
}
// Optional options are promoted to required if no must scorers exists.
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
};
if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(

327
src/query/disjunction.rs Normal file
View File

@@ -0,0 +1,327 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{ScoreCombiner, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
/// `Disjunction` is responsible for merging `DocSet` from multiple
/// source. Specifically, It takes the union of two or more `DocSet`s
/// then filtering out elements that appear fewer times than a
/// specified threshold.
pub struct Disjunction<TScorer, TScoreCombiner = DoNothingCombiner> {
chains: BinaryHeap<ScorerWrapper<TScorer>>,
minimum_matches_required: usize,
score_combiner: TScoreCombiner,
current_doc: DocId,
current_score: Score,
}
/// A wrapper around a `Scorer` that caches the current `doc_id` and implements the `DocSet` trait.
/// Also, the `Ord` trait and it's family are implemented reversely. So that we can combine
/// `std::BinaryHeap<ScorerWrapper<T>>` to gain a min-heap with current doc id as key.
struct ScorerWrapper<T> {
scorer: T,
current_doc: DocId,
}
impl<T: Scorer> ScorerWrapper<T> {
fn new(scorer: T) -> Self {
let current_doc = scorer.doc();
Self {
scorer,
current_doc,
}
}
}
impl<T: Scorer> PartialEq for ScorerWrapper<T> {
fn eq(&self, other: &Self) -> bool {
self.doc() == other.doc()
}
}
impl<T: Scorer> Eq for ScorerWrapper<T> {}
impl<T: Scorer> PartialOrd for ScorerWrapper<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: Scorer> Ord for ScorerWrapper<T> {
fn cmp(&self, other: &Self) -> Ordering {
self.doc().cmp(&other.doc()).reverse()
}
}
impl<T: Scorer> DocSet for ScorerWrapper<T> {
fn advance(&mut self) -> DocId {
let doc_id = self.scorer.advance();
self.current_doc = doc_id;
doc_id
}
fn doc(&self) -> DocId {
self.current_doc
}
fn size_hint(&self) -> u32 {
self.scorer.size_hint()
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
pub fn new<T: IntoIterator<Item = TScorer>>(
docsets: T,
score_combiner: TScoreCombiner,
minimum_matches_required: usize,
) -> Self {
debug_assert!(
minimum_matches_required > 1,
"union scorer works better if just one matches required"
);
let chains = docsets
.into_iter()
.map(|doc| ScorerWrapper::new(doc))
.collect();
let mut disjunction = Self {
chains,
score_combiner,
current_doc: TERMINATED,
minimum_matches_required,
current_score: 0.0,
};
if minimum_matches_required > disjunction.chains.len() {
return disjunction;
}
disjunction.advance();
disjunction
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
for Disjunction<TScorer, TScoreCombiner>
{
fn advance(&mut self) -> DocId {
let mut current_num_matches = 0;
while let Some(mut candidate) = self.chains.pop() {
let next = candidate.doc();
if next != TERMINATED {
// Peek next doc.
if self.current_doc != next {
if current_num_matches >= self.minimum_matches_required {
self.chains.push(candidate);
self.current_score = self.score_combiner.score();
return self.current_doc;
}
// Reset current_num_matches and scores.
current_num_matches = 0;
self.current_doc = next;
self.score_combiner.clear();
}
current_num_matches += 1;
self.score_combiner.update(&mut candidate.scorer);
candidate.advance();
self.chains.push(candidate);
}
}
if current_num_matches < self.minimum_matches_required {
self.current_doc = TERMINATED;
}
self.current_score = self.score_combiner.score();
return self.current_doc;
}
#[inline]
fn doc(&self) -> DocId {
self.current_doc
}
fn size_hint(&self) -> u32 {
self.chains
.iter()
.map(|docset| docset.size_hint())
.max()
.unwrap_or(0u32)
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
for Disjunction<TScorer, TScoreCombiner>
{
fn score(&mut self) -> Score {
self.current_score
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use super::Disjunction;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{ConstScorer, Scorer, SumCombiner, VecDocSet};
use crate::{DocId, DocSet, Score, TERMINATED};
fn conjunct<T: Ord + Copy>(arrays: &[Vec<T>], pass_line: usize) -> Vec<T> {
let mut counts = BTreeMap::new();
for array in arrays {
for &element in array {
*counts.entry(element).or_insert(0) += 1;
}
}
counts
.iter()
.filter_map(|(&element, &count)| {
if count >= pass_line {
Some(element)
} else {
None
}
})
.collect()
}
fn aux_test_conjunction(vals: Vec<Vec<u32>>, min_match: usize) {
let mut union_expected = VecDocSet::from(conjunct(&vals, min_match));
let make_scorer = || {
Disjunction::new(
vals.iter()
.cloned()
.map(VecDocSet::from)
.map(|d| ConstScorer::new(d, 1.0)),
DoNothingCombiner::default(),
min_match,
)
};
let mut scorer: Disjunction<_, DoNothingCombiner> = make_scorer();
let mut count = 0;
while scorer.doc() != TERMINATED {
assert_eq!(union_expected.doc(), scorer.doc());
assert_eq!(union_expected.advance(), scorer.advance());
count += 1;
}
assert_eq!(union_expected.advance(), TERMINATED);
assert_eq!(count, make_scorer().count_including_deleted());
}
#[should_panic]
#[test]
fn test_arg_check1() {
aux_test_conjunction(vec![], 0);
}
#[should_panic]
#[test]
fn test_arg_check2() {
aux_test_conjunction(vec![], 1);
}
#[test]
fn test_corner_case() {
aux_test_conjunction(vec![], 2);
aux_test_conjunction(vec![vec![]; 1000], 2);
aux_test_conjunction(vec![vec![]; 100], usize::MAX);
aux_test_conjunction(vec![vec![0xC0FFEE]; 10000], usize::MAX);
aux_test_conjunction((1..10000u32).map(|i| vec![i]).collect::<Vec<_>>(), 2);
}
#[test]
fn test_conjunction() {
aux_test_conjunction(
vec![
vec![1, 3333, 100000000u32],
vec![1, 2, 100000000u32],
vec![1, 2, 100000000u32],
],
2,
);
aux_test_conjunction(
vec![vec![8], vec![3, 4, 0xC0FFEEu32], vec![1, 2, 100000000u32]],
2,
);
aux_test_conjunction(
vec![
vec![1, 3333, 100000000u32],
vec![1, 2, 100000000u32],
vec![1, 2, 100000000u32],
],
3,
)
}
// This dummy scorer does nothing but yield doc id increasingly.
// with constant score 1.0
#[derive(Clone)]
struct DummyScorer {
cursor: usize,
foo: Vec<(DocId, f32)>,
}
impl DummyScorer {
fn new(doc_score: Vec<(DocId, f32)>) -> Self {
Self {
cursor: 0,
foo: doc_score,
}
}
}
impl DocSet for DummyScorer {
fn advance(&mut self) -> DocId {
self.cursor += 1;
self.doc()
}
fn doc(&self) -> DocId {
self.foo.get(self.cursor).map(|x| x.0).unwrap_or(TERMINATED)
}
fn size_hint(&self) -> u32 {
self.foo.len() as u32
}
}
impl Scorer for DummyScorer {
fn score(&mut self) -> Score {
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
}
}
#[test]
fn test_score_calculate() {
let mut scorer = Disjunction::new(
vec![
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (4, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
],
SumCombiner::default(),
3,
);
assert_eq!(scorer.score(), 5.0);
assert_eq!(scorer.advance(), 2);
assert_eq!(scorer.score(), 3.0);
}
#[test]
fn test_score_calculate_corner_case() {
let mut scorer = Disjunction::new(
vec![
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
],
SumCombiner::default(),
2,
);
assert_eq!(scorer.doc(), 1);
assert_eq!(scorer.score(), 3.0);
assert_eq!(scorer.advance(), 3);
assert_eq!(scorer.score(), 2.0);
}
}

View File

@@ -5,6 +5,7 @@ mod bm25;
mod boolean_query;
mod boost_query;
mod const_score_query;
mod disjunction;
mod disjunction_max_query;
mod empty_query;
mod exclude;

View File

@@ -1815,7 +1815,8 @@ mod test {
\"bad\"))], prefix: (2, Term(field=0, type=Str, \"wo\")), max_expansions: 50 }), \
(Should, PhrasePrefixQuery { field: Field(1), phrase_terms: [(0, Term(field=1, \
type=Str, \"big\")), (1, Term(field=1, type=Str, \"bad\"))], prefix: (2, \
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })] }"
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })], \
minimum_number_should_match: 1 }"
);
}
@@ -1880,7 +1881,8 @@ mod test {
format!("{query:?}"),
"BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, \
type=Str, \"abc\"), distance: 1, transposition_cost_one: true, prefix: false }), \
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))] }"
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))], \
minimum_number_should_match: 1 }"
);
}
@@ -1897,7 +1899,8 @@ mod test {
format!("{query:?}"),
"BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, \
\"abc\"))), (Should, FuzzyTermQuery { term: Term(field=1, type=Str, \"abc\"), \
distance: 2, transposition_cost_one: false, prefix: true })] }"
distance: 2, transposition_cost_one: false, prefix: true })], \
minimum_number_should_match: 1 }"
);
}
}