diff --git a/src/aggregation/bucket/filter.rs b/src/aggregation/bucket/filter.rs index d4461bf1f..cd5f0d6cc 100644 --- a/src/aggregation/bucket/filter.rs +++ b/src/aggregation/bucket/filter.rs @@ -727,12 +727,13 @@ mod tests { let schema = schema_builder.build(); let index = Index::create_in_ram(schema); - let mut writer: IndexWriter = index.writer(50_000_000)?; + let mut writer: IndexWriter = index.writer_for_tests()?; writer.add_document(doc!( category => "electronics", brand => "apple", price => 999u64, rating => 4.5f64, in_stock => true ))?; + writer.commit()?; writer.add_document(doc!( category => "electronics", brand => "samsung", price => 799u64, rating => 4.2f64, in_stock => true @@ -936,7 +937,7 @@ mod tests { let index = create_standard_test_index()?; let reader = index.reader()?; let searcher = reader.searcher(); - + assert_eq!(searcher.segment_readers().len(), 2); let agg = json!({ "premium_electronics": { "filter": "category:electronics AND price:[800 TO *]", diff --git a/src/indexer/merger.rs b/src/indexer/merger.rs index 1af64607b..47ac5a55b 100644 --- a/src/indexer/merger.rs +++ b/src/indexer/merger.rs @@ -1518,7 +1518,8 @@ mod tests { let searcher = reader.searcher(); let mut term_scorer = term_query .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))? - .specialized_scorer(searcher.segment_reader(0u32), 1.0)?; + .term_scorer_for_test(searcher.segment_reader(0u32), 1.0)? + .unwrap(); assert_eq!(term_scorer.doc(), 0); assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855); assert_nearly_equals!(term_scorer.score(), 0.0079681855); @@ -1533,7 +1534,8 @@ mod tests { for segment_reader in searcher.segment_readers() { let mut term_scorer = term_query .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))? - .specialized_scorer(segment_reader, 1.0)?; + .term_scorer_for_test(segment_reader, 1.0)? + .unwrap(); // the difference compared to before is intrinsic to the bm25 formula. no worries // there. 
for doc in segment_reader.doc_ids_alive() { @@ -1558,7 +1560,8 @@ mod tests { let segment_reader = searcher.segment_reader(0u32); let mut term_scorer = term_query .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))? - .specialized_scorer(segment_reader, 1.0)?; + .term_scorer_for_test(segment_reader, 1.0)? + .unwrap(); // the difference compared to before is intrinsic to the bm25 formula. no worries there. for doc in segment_reader.doc_ids_alive() { assert_eq!(term_scorer.doc(), doc); diff --git a/src/query/bm25.rs b/src/query/bm25.rs index 34ad974d4..d662f0eb5 100644 --- a/src/query/bm25.rs +++ b/src/query/bm25.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::fieldnorm::FieldNormReader; use crate::query::Explanation; use crate::schema::Field; @@ -57,13 +59,13 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score { K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm) } -fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] { +fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> { let mut cache: [Score; 256] = [0.0; 256]; for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() { let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8); *cache_mut = cached_tf_component(fieldnorm, average_fieldnorm); } - cache + Arc::new(cache) } /// A struct used for computing BM25 scores. @@ -71,17 +73,20 @@ fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] { pub struct Bm25Weight { idf_explain: Option, weight: Score, - cache: [Score; 256], + cache: Arc<[Score; 256]>, average_fieldnorm: Score, } impl Bm25Weight { /// Increase the weight by a multiplicative factor. 
pub fn boost_by(&self, boost: Score) -> Bm25Weight { + if boost == 1.0f32 { + return self.clone(); + } Bm25Weight { idf_explain: self.idf_explain.clone(), weight: self.weight * boost, - cache: self.cache, + cache: self.cache.clone(), average_fieldnorm: self.average_fieldnorm, } } diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 5dbd5ea44..9e8cedf2e 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -9,7 +9,7 @@ use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::term_query::TermScorer; use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer}; use crate::query::{ - intersect_scorers, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur, + intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur, RequiredOptionalScorer, Scorer, Weight, }; use crate::{DocId, Score}; @@ -97,6 +97,15 @@ fn into_box_scorer( } } +enum ShouldScorersCombinationMethod { + // Should scorers are irrelevant. + Ignored, + // Only contributes to final score. + Optional(SpecializedScorer), + // Regardless of score, the should scorers may impact whether a document is matching or not. + Required(SpecializedScorer), +} + /// Weight associated to the `BoolQuery`. pub struct BooleanWeight { weights: Vec<(Occur, Box)>, @@ -159,27 +168,50 @@ impl BooleanWeight { ) -> crate::Result { let num_docs = reader.num_docs(); let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?; - // Indicate how should clauses are combined with other clauses. - enum CombinationMethod { - Ignored, - // Only contributes to final score. - Optional(SpecializedScorer), - Required(SpecializedScorer), + + // Indicate how should clauses are combined with must clauses. 
+ let mut must_scorers: Vec> = + per_occur_scorers.remove(&Occur::Must).unwrap_or_default(); + let must_special_scorer_counts = remove_and_count_all_and_empty_scorers(&mut must_scorers); + + if must_special_scorer_counts.num_empty_scorers > 0 { + return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); } - let mut must_scorers = per_occur_scorers.remove(&Occur::Must); - let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should) - { + + let mut should_scorers = per_occur_scorers.remove(&Occur::Should).unwrap_or_default(); + let should_special_scorer_counts = + remove_and_count_all_and_empty_scorers(&mut should_scorers); + + let mut exclude_scorers: Vec> = per_occur_scorers + .remove(&Occur::MustNot) + .unwrap_or_default(); + let exclude_special_scorer_counts = + remove_and_count_all_and_empty_scorers(&mut exclude_scorers); + + if exclude_special_scorer_counts.num_all_scorers > 0 { + // We exclude all documents at one point. + return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); + } + + let minimum_number_should_match = self + .minimum_number_should_match + .saturating_sub(should_special_scorer_counts.num_all_scorers); + + let should_scorers: ShouldScorersCombinationMethod = { let num_of_should_scorers = should_scorers.len(); - if self.minimum_number_should_match > num_of_should_scorers { + if minimum_number_should_match > num_of_should_scorers { + // We don't have enough scorers to satisfy the minimum number of should matches. + // The request will match no documents. 
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); } - match self.minimum_number_should_match { - 0 => CombinationMethod::Optional(scorer_union( + match minimum_number_should_match { + 0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored, + 0 => ShouldScorersCombinationMethod::Optional(scorer_union( should_scorers, &score_combiner_fn, num_docs, )), - 1 => CombinationMethod::Required(scorer_union( + 1 => ShouldScorersCombinationMethod::Required(scorer_union( should_scorers, &score_combiner_fn, num_docs, @@ -187,76 +219,120 @@ impl BooleanWeight { n if num_of_should_scorers == n => { // When num_of_should_scorers equals the number of should clauses, // they are no different from must clauses. - must_scorers = match must_scorers.take() { - Some(mut must_scorers) => { - must_scorers.append(&mut should_scorers); - Some(must_scorers) - } - None => Some(should_scorers), - }; - CombinationMethod::Ignored + must_scorers.append(&mut should_scorers); + ShouldScorersCombinationMethod::Ignored } - _ => CombinationMethod::Required(SpecializedScorer::Other(scorer_disjunction( - should_scorers, - score_combiner_fn(), - self.minimum_number_should_match, - ))), - } - } else { - // None of should clauses are provided. 
- if self.minimum_number_should_match > 0 { - return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); - } else { - CombinationMethod::Ignored + _ => ShouldScorersCombinationMethod::Required(SpecializedScorer::Other( + scorer_disjunction( + should_scorers, + score_combiner_fn(), + minimum_number_should_match, + ), + )), } }; - let exclude_scorer_opt: Option> = per_occur_scorers - .remove(&Occur::MustNot) - .map(|scorers| scorer_union(scorers, DoNothingCombiner::default, num_docs)) - .map(|specialized_scorer: SpecializedScorer| { - into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs) - }); - let positive_scorer = match (should_opt, must_scorers) { - (CombinationMethod::Ignored, Some(must_scorers)) => { - SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs)) + + let exclude_scorer_opt: Option> = if exclude_scorers.is_empty() { + None + } else { + let exclude_specialized_scorer: SpecializedScorer = + scorer_union(exclude_scorers, DoNothingCombiner::default, num_docs); + Some(into_box_scorer( + exclude_specialized_scorer, + DoNothingCombiner::default, + num_docs, + )) + }; + + let include_scorer = match (should_scorers, must_scorers) { + (ShouldScorersCombinationMethod::Ignored, must_scorers) => { + let boxed_scorer: Box = if must_scorers.is_empty() { + // We do not have any should scorers, nor must scorers. + // There are still two cases here. + // + // If this follows the removal of some AllScorers in the should/must clauses, + // then we match all documents. + // + // Otherwise, it is really just an EmptyScorer. 
+ if must_special_scorer_counts.num_all_scorers + + should_special_scorer_counts.num_all_scorers + > 0 + { + Box::new(AllScorer::new(reader.max_doc())) + } else { + Box::new(EmptyScorer) + } + } else { + intersect_scorers(must_scorers, num_docs) + }; + SpecializedScorer::Other(boxed_scorer) } - (CombinationMethod::Optional(should_scorer), Some(must_scorers)) => { - let must_scorer = intersect_scorers(must_scorers, num_docs); - if self.scoring_enabled { - SpecializedScorer::Other(Box::new( - RequiredOptionalScorer::<_, _, TScoreCombiner>::new( + (ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => { + if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 { + // Optional options are promoted to required if no must scorers exists. + should_scorer + } else { + let must_scorer = intersect_scorers(must_scorers, num_docs); + if self.scoring_enabled { + SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< + _, + _, + TScoreCombiner, + >::new( must_scorer, into_box_scorer(should_scorer, &score_combiner_fn, num_docs), - ), - )) - } else { - SpecializedScorer::Other(must_scorer) + ))) + } else { + SpecializedScorer::Other(must_scorer) + } } } - (CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => { - must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs)); - SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs)) + (ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => { + if must_scorers.is_empty() { + should_scorer + } else { + must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs)); + SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs)) + } } - (CombinationMethod::Ignored, None) => { - return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))) - } - (CombinationMethod::Required(should_scorer), None) => should_scorer, - // Optional options are promoted to required if no must scorers exists. 
- (CombinationMethod::Optional(should_scorer), None) => should_scorer, }; if let Some(exclude_scorer) = exclude_scorer_opt { - let positive_scorer_boxed = - into_box_scorer(positive_scorer, &score_combiner_fn, num_docs); + let include_scorer_boxed = + into_box_scorer(include_scorer, &score_combiner_fn, num_docs); Ok(SpecializedScorer::Other(Box::new(Exclude::new( - positive_scorer_boxed, + include_scorer_boxed, exclude_scorer, )))) } else { - Ok(positive_scorer) + Ok(include_scorer) } } } +#[derive(Default, Copy, Clone, Debug)] +struct AllAndEmptyScorerCounts { + num_all_scorers: usize, + num_empty_scorers: usize, +} + +fn remove_and_count_all_and_empty_scorers( + scorers: &mut Vec>, +) -> AllAndEmptyScorerCounts { + let mut counts = AllAndEmptyScorerCounts::default(); + scorers.retain(|scorer| { + if scorer.is::() { + counts.num_all_scorers += 1; + false + } else if scorer.is::() { + counts.num_empty_scorers += 1; + false + } else { + true + } + }); + counts +} + impl Weight for BooleanWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { let num_docs = reader.num_docs(); @@ -293,7 +369,7 @@ impl Weight for BooleanWeight Weight for BooleanWeight bool { +fn is_include_occur(occur: Occur) -> bool { match occur { Occur::Must | Occur::Should => true, Occur::MustNot => false, diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index b384c275b..cacd30e57 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -14,8 +14,8 @@ mod tests { use crate::collector::TopDocs; use crate::query::term_query::TermScorer; use crate::query::{ - EnableScoring, Intersection, Occur, Query, QueryParser, RequiredOptionalScorer, Scorer, - SumCombiner, TermQuery, + AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, + RequiredOptionalScorer, Scorer, SumCombiner, TermQuery, }; use crate::schema::*; use crate::{assert_nearly_equals, DocAddress, DocId, Index, IndexWriter, Score}; @@ 
-311,4 +311,67 @@ mod tests { assert_nearly_equals!(explanation.value(), std::f32::consts::LN_2); Ok(()) } + + #[test] + pub fn test_boolean_weight_optimization() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer: IndexWriter = index.writer_for_tests()?; + index_writer.add_document(doc!(text_field=>"hello"))?; + index_writer.add_document(doc!(text_field=>"hello happy"))?; + index_writer.commit()?; + let searcher = index.reader()?.searcher(); + let term_match_all: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "hello"), + IndexRecordOption::Basic, + )); + let term_match_some: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "happy"), + IndexRecordOption::Basic, + )); + let term_match_none: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "tax"), + IndexRecordOption::Basic, + )); + { + let query = BooleanQuery::from(vec![ + (Occur::Must, term_match_all.box_clone()), + (Occur::Must, term_match_some.box_clone()), + ]); + let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; + let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?; + assert!(scorer.is::()); + } + { + let query = BooleanQuery::from(vec![ + (Occur::Must, term_match_all.box_clone()), + (Occur::Must, term_match_some.box_clone()), + (Occur::Must, term_match_none.box_clone()), + ]); + let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; + let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?; + assert!(scorer.is::()); + } + { + let query = BooleanQuery::from(vec![ + (Occur::Should, term_match_all.box_clone()), + (Occur::Should, term_match_none.box_clone()), + ]); + let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; + let scorer = 
weight.scorer(searcher.segment_reader(0u32), 1.0f32)?; + assert!(scorer.is::()); + } + { + let query = BooleanQuery::from(vec![ + (Occur::Should, term_match_some.box_clone()), + (Occur::Should, term_match_none.box_clone()), + ]); + let weight = query.weight(EnableScoring::disabled_from_searcher(&searcher))?; + let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0f32)?; + assert!(scorer.is::()); + } + Ok(()) + } } diff --git a/src/query/exist_query.rs b/src/query/exist_query.rs index f97e9e5c7..7eb09722c 100644 --- a/src/query/exist_query.rs +++ b/src/query/exist_query.rs @@ -127,7 +127,11 @@ impl Weight for ExistsWeight { .any(|col| matches!(col.column_index(), ColumnIndex::Full)) { let all_scorer = AllScorer::new(max_doc); - return Ok(Box::new(BoostScorer::new(all_scorer, boost))); + if boost != 1.0f32 { + return Ok(Box::new(BoostScorer::new(all_scorer, boost))); + } else { + return Ok(Box::new(all_scorer)); + } } // If we have a single dynamic column, use ExistsDocSet diff --git a/src/query/range_query/range_query.rs b/src/query/range_query/range_query.rs index 5035c43f1..7e9e691ec 100644 --- a/src/query/range_query/range_query.rs +++ b/src/query/range_query/range_query.rs @@ -266,8 +266,11 @@ mod tests { use super::RangeQuery; use crate::collector::{Count, TopDocs}; use crate::indexer::NoMergePolicy; + use crate::query::range_query::fast_field_range_doc_set::RangeDocSet; use crate::query::range_query::range_query::InvertedIndexRangeQuery; - use crate::query::QueryParser; + use crate::query::{ + AllScorer, BitSetDocSet, ConstScorer, EmptyScorer, EnableScoring, Query, QueryParser, + }; use crate::schema::{ Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT, }; @@ -660,4 +663,46 @@ mod tests { 0 ); } + + #[test] + fn test_range_query_simplified() { + // This test checks that if the targeted column values are entirely + // within the range, and the column is full, we end up with an AllScorer. 
+ let mut schema_builder = Schema::builder(); + let u64_field = schema_builder.add_u64_field("u64_field", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index.writer_for_tests().unwrap(); + index_writer.add_document(doc!(u64_field=> 2u64)).unwrap(); + index_writer.add_document(doc!(u64_field=> 4u64)).unwrap(); + index_writer.commit().unwrap(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + assert_eq!(searcher.segment_readers().len(), 1); + let make_term = |value: u64| Term::from_field_u64(u64_field, value); + let make_scorer = move |lower_bound: Bound, upper_bound: Bound| { + let lower_bound_term = lower_bound.map(make_term); + let upper_bound_term = upper_bound.map(make_term); + let range_query = RangeQuery::new(lower_bound_term, upper_bound_term); + let range_weight = range_query + .weight(EnableScoring::disabled_from_schema(&schema)) + .unwrap(); + let range_scorer = range_weight + .scorer(&searcher.segment_readers()[0], 1.0f32) + .unwrap(); + range_scorer + }; + let range_scorer = make_scorer(Bound::Included(1), Bound::Included(4)); + assert!(range_scorer.is::()); + let range_scorer = make_scorer(Bound::Included(0), Bound::Included(2)); + assert!(range_scorer.is::>>()); + let range_scorer = make_scorer(Bound::Included(3), Bound::Included(10)); + assert!(range_scorer.is::>>()); + let range_scorer = make_scorer(Bound::Included(10), Bound::Included(12)); + assert!(range_scorer.is::>>()); + let range_scorer = make_scorer(Bound::Included(0), Bound::Included(1)); + assert!(range_scorer.is::()); + let range_scorer = make_scorer(Bound::Included(0), Bound::Excluded(2)); + assert!(range_scorer.is::()); + } } diff --git a/src/query/range_query/range_query_fastfield.rs b/src/query/range_query/range_query_fastfield.rs index b17694cfa..0246ee526 100644 --- a/src/query/range_query/range_query_fastfield.rs +++ b/src/query/range_query/range_query_fastfield.rs @@ -6,8 +6,8 
@@ use std::net::Ipv6Addr; use std::ops::{Bound, RangeInclusive}; use columnar::{ - Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, NumericalType, - StrColumn, + Cardinality, Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, + NumericalType, StrColumn, }; use common::bounds::{BoundsRange, TransformBound}; @@ -397,6 +397,8 @@ fn search_on_u64_ff( boost: Score, bounds: BoundsRange, ) -> crate::Result> { + let col_min_value = column.min_value(); + let col_max_value = column.max_value(); #[expect(clippy::reversed_empty_ranges)] let value_range = bound_to_value_range( &bounds.lower_bound, @@ -408,6 +410,22 @@ fn search_on_u64_ff( if value_range.is_empty() { return Ok(Box::new(EmptyScorer)); } + if col_min_value >= *value_range.start() && col_max_value <= *value_range.end() { + // all values in the column are within the range. + if column.index.get_cardinality() == Cardinality::Full { + if boost != 1.0f32 { + return Ok(Box::new(ConstScorer::new( + AllScorer::new(column.num_docs()), + boost, + ))); + } else { + return Ok(Box::new(AllScorer::new(column.num_docs()))); + } + } else { + // TODO Make it a field presence request for that specific column + } + } + let docset = RangeDocSet::new(value_range, column); Ok(Box::new(ConstScorer::new(docset, boost))) } diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index 408fac233..1c1fa8389 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -10,7 +10,10 @@ mod tests { use crate::collector::TopDocs; use crate::docset::DocSet; use crate::postings::compression::COMPRESSION_BLOCK_SIZE; - use crate::query::{EnableScoring, Query, QueryParser, Scorer, TermQuery}; + use crate::query::term_query::TermScorer; + use crate::query::{ + AllScorer, EmptyScorer, EnableScoring, Query, QueryParser, Scorer, TermQuery, + }; use crate::schema::{Field, IndexRecordOption, Schema, FAST, STRING, TEXT}; use crate::{assert_nearly_equals, DocAddress, Index, 
IndexWriter, Term, TERMINATED}; @@ -440,4 +443,82 @@ mod tests { Ok(()) } + + #[test] + fn test_term_weight_all_query_optimization() { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", crate::schema::TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index.writer_for_tests().unwrap(); + index_writer + .add_document(doc!(text_field=>"hello")) + .unwrap(); + index_writer + .add_document(doc!(text_field=>"hello happy")) + .unwrap(); + index_writer.commit().unwrap(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let get_scorer_for_term = |term: &str| { + let term_query = TermQuery::new( + Term::from_field_text(text_field, term), + IndexRecordOption::Basic, + ); + let term_weight = term_query + .weight(EnableScoring::disabled_from_schema(&schema)) + .unwrap(); + term_weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap() + }; + // Should be an allscorer + let match_all_scorer = get_scorer_for_term("hello"); + // Should be a term scorer + let match_some_scorer = get_scorer_for_term("happy"); + // Should be an empty scorer + let empty_scorer = get_scorer_for_term("tax"); + assert!(match_all_scorer.is::()); + assert!(match_some_scorer.is::()); + assert!(empty_scorer.is::()); + } + + #[test] + fn test_term_weight_all_query_optimization_disable_when_scoring_enabled() { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", crate::schema::TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index.writer_for_tests().unwrap(); + index_writer + .add_document(doc!(text_field=>"hello")) + .unwrap(); + index_writer + .add_document(doc!(text_field=>"hello happy")) + .unwrap(); + index_writer.commit().unwrap(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let 
get_scorer_for_term = |term: &str| { + let term_query = TermQuery::new( + Term::from_field_text(text_field, term), + IndexRecordOption::Basic, + ); + let term_weight = term_query + .weight(EnableScoring::enabled_from_searcher(&searcher)) + .unwrap(); + term_weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap() + }; + // Scoring is enabled, so this stays a term scorer even though it matches all docs + let match_all_scorer = get_scorer_for_term("hello"); + // Should be a term scorer + let one_scorer = get_scorer_for_term("happy"); + // Should be an empty scorer + let empty_scorer = get_scorer_for_term("tax"); + assert!(match_all_scorer.is::()); + assert!(one_scorer.is::()); + assert!(empty_scorer.is::()); + } } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 09f7502b9..5c020febd 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -259,7 +259,7 @@ mod tests { let mut block_max_scores_b = vec![]; let mut docs = vec![]; { - let mut term_scorer = term_weight.specialized_scorer(reader, 1.0)?; + let mut term_scorer = term_weight.term_scorer_for_test(reader, 1.0)?.unwrap(); while term_scorer.doc() != TERMINATED { let mut score = term_scorer.score(); docs.push(term_scorer.doc()); @@ -273,7 +273,7 @@ mod tests { } } { - let mut term_scorer = term_weight.specialized_scorer(reader, 1.0)?; + let mut term_scorer = term_weight.term_scorer_for_test(reader, 1.0)?.unwrap(); for d in docs { term_scorer.seek_block(d); block_max_scores_b.push(term_scorer.block_max_score()); diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index a70c8ce8f..39f59d147 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -6,9 +6,9 @@ use crate::postings::SegmentPostings; use crate::query::bm25::Bm25Weight; use crate::query::explanation::does_not_match; use crate::query::weight::{for_each_docset_buffered, for_each_scorer}; -use crate::query::{Explanation, Scorer, Weight}; +use 
crate::query::{AllScorer, AllWeight, EmptyScorer, Explanation, Scorer, Weight}; use crate::schema::IndexRecordOption; -use crate::{DocId, Score, Term}; +use crate::{DocId, Score, TantivyError, Term}; pub struct TermWeight { term: Term, @@ -17,20 +17,42 @@ pub struct TermWeight { scoring_enabled: bool, } +enum TermOrEmptyOrAllScorer { + TermScorer(TermScorer), + Empty, + AllMatch(AllScorer), +} + +impl TermOrEmptyOrAllScorer { + pub fn into_boxed_scorer(self) -> Box { + match self { + TermOrEmptyOrAllScorer::TermScorer(scorer) => Box::new(scorer), + TermOrEmptyOrAllScorer::Empty => Box::new(EmptyScorer), + TermOrEmptyOrAllScorer::AllMatch(scorer) => Box::new(scorer), + } + } +} + impl Weight for TermWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { - let term_scorer = self.specialized_scorer(reader, boost)?; - Ok(Box::new(term_scorer)) + Ok(self.specialized_scorer(reader, boost)?.into_boxed_scorer()) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { - let mut scorer = self.specialized_scorer(reader, 1.0)?; - if scorer.doc() > doc || scorer.seek(doc) != doc { - return Err(does_not_match(doc)); + match self.specialized_scorer(reader, 1.0)? 
{ + TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => { + if term_scorer.doc() > doc || term_scorer.seek(doc) != doc { + return Err(does_not_match(doc)); + } + let mut explanation = term_scorer.explain(); + explanation.add_context(format!("Term={:?}", self.term,)); + Ok(explanation) + } + TermOrEmptyOrAllScorer::Empty => { + return Err(does_not_match(doc)); + } + TermOrEmptyOrAllScorer::AllMatch(_) => AllWeight.explain(reader, doc), } - let mut explanation = scorer.explain(); - explanation.add_context(format!("Term={:?}", self.term,)); - Ok(explanation) } fn count(&self, reader: &SegmentReader) -> crate::Result { @@ -51,8 +73,15 @@ impl Weight for TermWeight { reader: &SegmentReader, callback: &mut dyn FnMut(DocId, Score), ) -> crate::Result<()> { - let mut scorer = self.specialized_scorer(reader, 1.0)?; - for_each_scorer(&mut scorer, callback); + match self.specialized_scorer(reader, 1.0)? { + TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => { + for_each_scorer(&mut term_scorer, callback); + } + TermOrEmptyOrAllScorer::Empty => {} + TermOrEmptyOrAllScorer::AllMatch(mut all_scorer) => { + for_each_scorer(&mut all_scorer, callback); + } + } Ok(()) } @@ -63,9 +92,18 @@ impl Weight for TermWeight { reader: &SegmentReader, callback: &mut dyn FnMut(&[DocId]), ) -> crate::Result<()> { - let mut scorer = self.specialized_scorer(reader, 1.0)?; - let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; - for_each_docset_buffered(&mut scorer, &mut buffer, callback); + match self.specialized_scorer(reader, 1.0)? 
{ + TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => { + let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; + for_each_docset_buffered(&mut term_scorer, &mut buffer, callback); + } + TermOrEmptyOrAllScorer::Empty => {} + TermOrEmptyOrAllScorer::AllMatch(mut all_scorer) => { + let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; + for_each_docset_buffered(&mut all_scorer, &mut buffer, callback); + } + }; + Ok(()) } @@ -85,8 +123,22 @@ impl Weight for TermWeight { reader: &SegmentReader, callback: &mut dyn FnMut(DocId, Score) -> Score, ) -> crate::Result<()> { - let scorer = self.specialized_scorer(reader, 1.0)?; - crate::query::boolean_query::block_wand_single_scorer(scorer, threshold, callback); + let specialized_scorer = self.specialized_scorer(reader, 1.0)?; + match specialized_scorer { + TermOrEmptyOrAllScorer::TermScorer(term_scorer) => { + crate::query::boolean_query::block_wand_single_scorer( + term_scorer, + threshold, + callback, + ); + } + TermOrEmptyOrAllScorer::Empty => {} + TermOrEmptyOrAllScorer::AllMatch(_) => { + return Err(TantivyError::InvalidArgument( + "for each pruning should only be called if scoring is enabled".to_string(), + )); + } + } Ok(()) } } @@ -110,35 +162,63 @@ impl TermWeight { &self.term } - pub(crate) fn specialized_scorer( + /// We need a method to access the actual `TermScorer` implementation + /// for `white box` test, checking in particular that the block max + /// is correct. 
+ #[cfg(test)] + pub(crate) fn term_scorer_for_test( &self, reader: &SegmentReader, boost: Score, - ) -> crate::Result { + ) -> crate::Result> { + let scorer = self.specialized_scorer(reader, boost)?; + Ok(match scorer { + TermOrEmptyOrAllScorer::TermScorer(scorer) => Some(scorer), + _ => None, + }) + } + + fn specialized_scorer( + &self, + reader: &SegmentReader, + boost: Score, + ) -> crate::Result { let field = self.term.field(); let inverted_index = reader.inverted_index(field)?; - let fieldnorm_reader_opt = if self.scoring_enabled { - reader.fieldnorms_readers().get_field(field)? - } else { - None + let Some(term_info) = inverted_index.get_term_info(&self.term)? else { + // The term was not found. + return Ok(TermOrEmptyOrAllScorer::Empty); }; - let fieldnorm_reader = - fieldnorm_reader_opt.unwrap_or_else(|| FieldNormReader::constant(reader.max_doc(), 1)); - let similarity_weight = self.similarity_weight.boost_by(boost); - let postings_opt: Option = - inverted_index.read_postings(&self.term, self.index_record_option)?; - if let Some(segment_postings) = postings_opt { - Ok(TermScorer::new( - segment_postings, - fieldnorm_reader, - similarity_weight, - )) - } else { - Ok(TermScorer::new( - SegmentPostings::empty(), - fieldnorm_reader, - similarity_weight, - )) + + // If we don't care about scores, and our posting list matches all docs, we can return the + // AllMatch scorer. 
+ if !self.scoring_enabled && term_info.doc_freq == reader.max_doc() { + return Ok(TermOrEmptyOrAllScorer::AllMatch(AllScorer::new( + reader.max_doc(), + ))); } + + let segment_postings: SegmentPostings = + inverted_index.read_postings_from_terminfo(&term_info, self.index_record_option)?; + + let fieldnorm_reader = self.fieldnorm_reader(reader)?; + let similarity_weight = self.similarity_weight.boost_by(boost); + Ok(TermOrEmptyOrAllScorer::TermScorer(TermScorer::new( + segment_postings, + fieldnorm_reader, + similarity_weight, + ))) + } + + fn fieldnorm_reader(&self, segment_reader: &SegmentReader) -> crate::Result { + if self.scoring_enabled { + if let Some(field_norm_reader) = segment_reader + .fieldnorms_readers() + .get_field(self.term.field())? + { + return Ok(field_norm_reader); + } + } + Ok(FieldNormReader::constant(segment_reader.max_doc(), 1)) } }