diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 497e6767a..f0a43268b 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -4,7 +4,7 @@ use std::{fmt, io}; use crate::collector::Collector; use crate::core::{Executor, SegmentReader}; -use crate::query::{EnableScoring, Query}; +use crate::query::{Bm25StatisticsProvider, EnableScoring, Query}; use crate::schema::{Document, Schema, Term}; use crate::space_usage::SearcherSpaceUsage; use crate::store::{CacheStats, StoreReader}; @@ -176,8 +176,27 @@ impl Searcher { query: &dyn Query, collector: &C, ) -> crate::Result { + self.search_with_statistics_provider(query, collector, self) + } + + /// Same as [`search(...)`](Searcher::search) but allows specifying + /// a [Bm25StatisticsProvider]. + /// + /// This can be used to adjust the statistics used in computing BM25 + /// scores. + pub fn search_with_statistics_provider( + &self, + query: &dyn Query, + collector: &C, + statistics_provider: &dyn Bm25StatisticsProvider, + ) -> crate::Result { + let enabled_scoring = if collector.requires_scoring() { + EnableScoring::enabled_from_statistics_provider(statistics_provider, self) + } else { + EnableScoring::disabled_from_searcher(self) + }; let executor = self.inner.index.search_executor(); - self.search_with_executor(query, collector, executor) + self.search_with_executor(query, collector, executor, enabled_scoring) } /// Same as [`search(...)`](Searcher::search) but multithreaded. @@ -197,12 +216,8 @@ impl Searcher { query: &dyn Query, collector: &C, executor: &Executor, + enabled_scoring: EnableScoring, ) -> crate::Result { - let enabled_scoring = if collector.requires_scoring() { - EnableScoring::enabled_from_searcher(self) - } else { - EnableScoring::disabled_from_searcher(self) - }; let weight = query.weight(enabled_scoring)?; let segment_readers = self.segment_readers(); let fruits = executor.map( diff --git a/src/fieldnorm/mod.rs b/src/fieldnorm/mod.rs index e362a9af6..75fcff1e3 100644 --- a/src/fieldnorm/mod.rs +++ b/src/fieldnorm/mod.rs @@ -112,7 +112,7 @@ mod tests { Term::from_field_text(text, "hello"), IndexRecordOption::WithFreqs, ); - let weight = query.weight(EnableScoring::Enabled(&searcher))?; + let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?; assert_eq!(scorer.doc(), 0); assert!((scorer.score() - 0.22920431).abs() < 0.001f32); @@ -141,7 +141,7 @@ mod tests { Term::from_field_text(text, "hello"), IndexRecordOption::WithFreqs, ); - let weight = query.weight(EnableScoring::Enabled(&searcher))?; + let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?; assert_eq!(scorer.doc(), 0); assert!((scorer.score() - 0.22920431).abs() < 0.001f32); diff --git a/src/indexer/merger.rs b/src/indexer/merger.rs index b91af06d6..25cc6edd9 100644 --- a/src/indexer/merger.rs +++ b/src/indexer/merger.rs @@ -1628,7 +1628,7 @@ mod tests { let reader = index.reader()?; let searcher = reader.searcher(); let mut term_scorer = term_query - .specialized_weight(EnableScoring::Enabled(&searcher))? + .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))? .specialized_scorer(searcher.segment_reader(0u32), 1.0)?; assert_eq!(term_scorer.doc(), 0); assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855); @@ -1643,7 +1643,7 @@ mod tests { assert_eq!(searcher.segment_readers().len(), 2); for segment_reader in searcher.segment_readers() { let mut term_scorer = term_query - .specialized_weight(EnableScoring::Enabled(&searcher))? + .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))? .specialized_scorer(segment_reader, 1.0)?; // the difference compared to before is intrinsic to the bm25 formula. no worries // there. @@ -1668,7 +1668,7 @@ mod tests { let segment_reader = searcher.segment_reader(0u32); let mut term_scorer = term_query - .specialized_weight(EnableScoring::Enabled(&searcher))? + .specialized_weight(EnableScoring::enabled_from_searcher(&searcher))? .specialized_scorer(segment_reader, 1.0)?; // the difference compared to before is intrinsic to the bm25 formula. no worries there. for doc in segment_reader.doc_ids_alive() { diff --git a/src/query/bm25.rs b/src/query/bm25.rs index 9d47ba929..b071920e6 100644 --- a/src/query/bm25.rs +++ b/src/query/bm25.rs @@ -2,11 +2,53 @@ use serde::{Deserialize, Serialize}; use crate::fieldnorm::FieldNormReader; use crate::query::Explanation; +use crate::schema::Field; use crate::{Score, Searcher, Term}; const K1: Score = 1.2; const B: Score = 0.75; +/// An interface to compute the statistics needed in BM25 scoring. +/// +/// The standard implementation is a [Searcher] but you can also +/// create your own to adjust the statistics. +pub trait Bm25StatisticsProvider { + /// The total number of tokens in a given field across all documents in + /// the index. + fn total_num_tokens(&self, field: Field) -> crate::Result; + + /// The total number of documents in the index. + fn total_num_docs(&self) -> crate::Result; + + /// The number of documents containing the given term. + fn doc_freq(&self, term: &Term) -> crate::Result; +} + +impl Bm25StatisticsProvider for Searcher { + fn total_num_tokens(&self, field: Field) -> crate::Result { + let mut total_num_tokens = 0u64; + + for segment_reader in self.segment_readers() { + let inverted_index = segment_reader.inverted_index(field)?; + total_num_tokens += inverted_index.total_num_tokens(); + } + Ok(total_num_tokens) + } + + fn total_num_docs(&self) -> crate::Result { + let mut total_num_docs = 0u64; + + for segment_reader in self.segment_readers() { + total_num_docs += u64::from(segment_reader.max_doc()); + } + Ok(total_num_docs) + } + + fn doc_freq(&self, term: &Term) -> crate::Result { + self.doc_freq(term) + } +} + pub(crate) fn idf(doc_freq: u64, doc_count: u64) -> Score { assert!(doc_count >= doc_freq, "{} >= {}", doc_count, doc_freq); let x = ((doc_count - doc_freq) as Score + 0.5) / (doc_freq as Score + 0.5); @@ -32,6 +74,7 @@ pub struct Bm25Params { pub avg_fieldnorm: Score, } +/// A struct used for computing BM25 scores. #[derive(Clone)] pub struct Bm25Weight { idf_explain: Explanation, @@ -41,6 +84,7 @@ pub struct Bm25Weight { } impl Bm25Weight { + /// Increase the weight by a multiplicative factor. pub fn boost_by(&self, boost: Score) -> Bm25Weight { Bm25Weight { idf_explain: self.idf_explain.clone(), @@ -50,7 +94,11 @@ impl Bm25Weight { } } - pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> crate::Result { + /// Construct a [Bm25Weight] for a phrase of terms. + pub fn for_terms( + statistics: &dyn Bm25StatisticsProvider, + terms: &[Term], + ) -> crate::Result { assert!(!terms.is_empty(), "Bm25 requires at least one term"); let field = terms[0].field(); for term in &terms[1..] { @@ -61,17 +109,12 @@ impl Bm25Weight { ); } - let mut total_num_tokens = 0u64; - let mut total_num_docs = 0u64; - for segment_reader in searcher.segment_readers() { - let inverted_index = segment_reader.inverted_index(field)?; - total_num_tokens += inverted_index.total_num_tokens(); - total_num_docs += u64::from(segment_reader.max_doc()); - } + let total_num_tokens = statistics.total_num_tokens(field)?; + let total_num_docs = statistics.total_num_docs()?; let average_fieldnorm = total_num_tokens as Score / total_num_docs as Score; if terms.len() == 1 { - let term_doc_freq = searcher.doc_freq(&terms[0])?; + let term_doc_freq = statistics.doc_freq(&terms[0])?; Ok(Bm25Weight::for_one_term( term_doc_freq, total_num_docs, @@ -80,7 +123,7 @@ impl Bm25Weight { } else { let mut idf_sum: Score = 0.0; for term in terms { - let term_doc_freq = searcher.doc_freq(term)?; + let term_doc_freq = statistics.doc_freq(term)?; idf_sum += idf(term_doc_freq, total_num_docs); } let idf_explain = Explanation::new("idf", idf_sum); @@ -88,6 +131,7 @@ impl Bm25Weight { } } + /// Construct a [Bm25Weight] for a single term. pub fn for_one_term( term_doc_freq: u64, total_num_docs: u64, @@ -114,11 +158,13 @@ impl Bm25Weight { } } + /// Compute the BM25 score of a single document. #[inline] pub fn score(&self, fieldnorm_id: u8, term_freq: u32) -> Score { self.weight * self.tf_factor(fieldnorm_id, term_freq) } + /// Compute the maximum possible BM25 score given this weight. pub fn max_score(&self) -> Score { self.score(255u8, 2_013_265_944) } @@ -130,6 +176,7 @@ impl Bm25Weight { term_freq / (term_freq + norm) } + /// Produce an [Explanation] of a BM25 score. pub fn explain(&self, fieldnorm_id: u8, term_freq: u32) -> Explanation { // The explain format is directly copied from Lucene's. // (So, Kudos to Lucene) diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 219f6b725..b5bf519df 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -55,7 +55,7 @@ mod tests { let query_parser = QueryParser::for_index(&index, vec![text_field]); let query = query_parser.parse_query("+a")?; let searcher = index.reader()?.searcher(); - let weight = query.weight(EnableScoring::Enabled(&searcher))?; + let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?; assert!(scorer.is::()); Ok(()) @@ -68,13 +68,13 @@ mod tests { let searcher = index.reader()?.searcher(); { let query = query_parser.parse_query("+a +b +c")?; - let weight = query.weight(EnableScoring::Enabled(&searcher))?; + let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?; assert!(scorer.is::>()); } { let query = query_parser.parse_query("+a +(b c)")?; - let weight = query.weight(EnableScoring::Enabled(&searcher))?; + let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?; assert!(scorer.is::>>()); } @@ -88,7 +88,7 @@ mod tests { let searcher = index.reader()?.searcher(); { let query = query_parser.parse_query("+a b")?; - let weight = query.weight(EnableScoring::Enabled(&searcher))?; + let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?; assert!(scorer.is::, @@ -243,7 +243,7 @@ mod tests { let boolean_query = BooleanQuery::new(vec![(Occur::Should, term_a), (Occur::Should, term_b)]); let boolean_weight = boolean_query - .weight(EnableScoring::Enabled(&searcher)) + .weight(EnableScoring::enabled_from_searcher(&searcher)) .unwrap(); { let mut boolean_scorer = boolean_weight.scorer(searcher.segment_reader(0u32), 1.0)?; diff --git a/src/query/mod.rs b/src/query/mod.rs index ed0672070..dd606693f 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -33,7 +33,7 @@ pub use tantivy_query_grammar::Occur; pub use self::all_query::{AllQuery, AllScorer, AllWeight}; pub use self::automaton_weight::AutomatonWeight; pub use self::bitset::BitSetDocSet; -pub(crate) use self::bm25::Bm25Weight; +pub use self::bm25::{Bm25StatisticsProvider, Bm25Weight}; pub use self::boolean_query::BooleanQuery; pub(crate) use self::boolean_query::BooleanWeight; pub use self::boost_query::BoostQuery; diff --git a/src/query/more_like_this/query.rs b/src/query/more_like_this/query.rs index 9e6c36424..940b79b2b 100644 --- a/src/query/more_like_this/query.rs +++ b/src/query/more_like_this/query.rs @@ -44,7 +44,7 @@ impl MoreLikeThisQuery { impl Query for MoreLikeThisQuery { fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result> { let searcher = match enable_scoring { - EnableScoring::Enabled(searcher) => searcher, + EnableScoring::Enabled { searcher, .. } => searcher, EnableScoring::Disabled { .. } => { let err = "MoreLikeThisQuery requires to enable scoring.".to_string(); return Err(crate::TantivyError::InvalidArgument(err)); diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index efc31a044..86bcd8ef9 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -108,7 +108,10 @@ impl PhraseQuery { } let terms = self.phrase_terms(); let bm25_weight_opt = match enable_scoring { - EnableScoring::Enabled(searcher) => Some(Bm25Weight::for_terms(searcher, &terms)?), + EnableScoring::Enabled { + statistics_provider, + .. + } => Some(Bm25Weight::for_terms(statistics_provider, &terms)?), EnableScoring::Disabled { .. } => None, }; let mut weight = PhraseWeight::new(self.phrase_terms.clone(), bm25_weight_opt); diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index cb7e7ba5d..5b61eafb8 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -132,7 +132,7 @@ mod tests { Term::from_field_text(text_field, "a"), Term::from_field_text(text_field, "b"), ]); - let enable_scoring = EnableScoring::Enabled(&searcher); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); let phrase_weight = phrase_query.phrase_weight(enable_scoring).unwrap(); let mut phrase_scorer = phrase_weight .phrase_scorer(searcher.segment_reader(0u32), 1.0)? diff --git a/src/query/query.rs b/src/query/query.rs index 3899a81d5..4808849ba 100644 --- a/src/query/query.rs +++ b/src/query/query.rs @@ -2,6 +2,7 @@ use std::fmt; use downcast_rs::impl_downcast; +use super::bm25::Bm25StatisticsProvider; use super::Weight; use crate::core::searcher::Searcher; use crate::query::Explanation; @@ -12,7 +13,16 @@ use crate::{DocAddress, Term}; #[derive(Copy, Clone)] pub enum EnableScoring<'a> { /// Pass this to enable scoring. - Enabled(&'a Searcher), + Enabled { + /// The searcher to use during scoring. + searcher: &'a Searcher, + + /// A [Bm25StatisticsProvider] used to compute BM25 scores. + /// + /// Normally this should be the [Searcher], but you can specify a custom + /// one to adjust the statistics. + statistics_provider: &'a dyn Bm25StatisticsProvider, + }, /// Pass this to disable scoring. /// This can improve performance. Disabled { @@ -26,7 +36,21 @@ pub enum EnableScoring<'a> { impl<'a> EnableScoring<'a> { /// Create using [Searcher] with scoring enabled. pub fn enabled_from_searcher(searcher: &'a Searcher) -> EnableScoring<'a> { - EnableScoring::Enabled(searcher) + EnableScoring::Enabled { + searcher, + statistics_provider: searcher, + } + } + + /// Create using a custom [Bm25StatisticsProvider] with scoring enabled. + pub fn enabled_from_statistics_provider( + statistics_provider: &'a dyn Bm25StatisticsProvider, + searcher: &'a Searcher, + ) -> EnableScoring<'a> { + EnableScoring::Enabled { + statistics_provider, + searcher, + } } /// Create using [Searcher] with scoring disabled. @@ -48,7 +72,7 @@ impl<'a> EnableScoring<'a> { /// Returns the searcher if available. pub fn searcher(&self) -> Option<&Searcher> { match self { - EnableScoring::Enabled(searcher) => Some(searcher), + EnableScoring::Enabled { searcher, .. } => Some(*searcher), EnableScoring::Disabled { searcher_opt, .. } => searcher_opt.to_owned(), } } @@ -56,14 +80,14 @@ impl<'a> EnableScoring<'a> { /// Returns the schema. pub fn schema(&self) -> &Schema { match self { - EnableScoring::Enabled(searcher) => searcher.schema(), + EnableScoring::Enabled { searcher, .. } => searcher.schema(), EnableScoring::Disabled { schema, .. } => schema, } } /// Returns true if the scoring is enabled. pub fn is_scoring_enabled(&self) -> bool { - matches!(self, EnableScoring::Enabled(..)) + matches!(self, EnableScoring::Enabled { .. }) } } diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index e07d313c6..8ea4b1ce2 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -34,7 +34,7 @@ mod tests { Term::from_field_text(text_field, "a"), IndexRecordOption::Basic, ); - let term_weight = term_query.weight(EnableScoring::Enabled(&searcher))?; + let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let segment_reader = searcher.segment_reader(0); let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?; assert_eq!(term_scorer.doc(), 0); @@ -62,7 +62,7 @@ mod tests { Term::from_field_text(text_field, "a"), IndexRecordOption::Basic, ); - let term_weight = term_query.weight(EnableScoring::Enabled(&searcher))?; + let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let segment_reader = searcher.segment_reader(0); let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?; for i in 0u32..COMPRESSION_BLOCK_SIZE as u32 { diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index 2e17de7fa..e382b988a 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -96,9 +96,10 @@ impl TermQuery { return Err(crate::TantivyError::SchemaError(error_msg)); } let bm25_weight = match enable_scoring { - EnableScoring::Enabled(searcher) => { - Bm25Weight::for_terms(searcher, &[self.term.clone()])? - } + EnableScoring::Enabled { + statistics_provider, + .. + } => Bm25Weight::for_terms(statistics_provider, &[self.term.clone()])?, EnableScoring::Disabled { .. } => { Bm25Weight::new(Explanation::new("".to_string(), 1.0f32), 1.0f32) } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 2e7aeeaa4..007d27c03 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -250,7 +250,8 @@ mod tests { } fn test_block_wand_aux(term_query: &TermQuery, searcher: &Searcher) -> crate::Result<()> { - let term_weight = term_query.specialized_weight(EnableScoring::Enabled(searcher))?; + let term_weight = + term_query.specialized_weight(EnableScoring::enabled_from_searcher(searcher))?; for reader in searcher.segment_readers() { let mut block_max_scores = vec![]; let mut block_max_scores_b = vec![];