Make BM25 scoring more flexible (#1855)

* Introduce Bm25StatisticsProvider to inject statistics

* fix formatting I accidentally changed
Author: Alex Cole
Date: 2023-02-16 02:14:12 -08:00
Committed by: GitHub
Parent: 71f43ace1d
Commit: f2f38c43ce
13 changed files with 133 additions and 42 deletions
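
For orientation (not part of the diff): a minimal usage sketch, assuming a caller that already has corpus-wide statistics from somewhere outside the local index. The `GlobalStats` struct and its values are hypothetical; only `Bm25StatisticsProvider` and `search_with_statistics_provider` come from this change.

```rust
use tantivy::collector::TopDocs;
use tantivy::query::{Bm25StatisticsProvider, Query};
use tantivy::schema::Field;
use tantivy::{DocAddress, Score, Searcher, Term};

// Hypothetical provider that injects externally computed statistics
// instead of the local searcher's own counts.
struct GlobalStats {
    total_num_tokens: u64,
    total_num_docs: u64,
    doc_freq: u64,
}

impl Bm25StatisticsProvider for GlobalStats {
    fn total_num_tokens(&self, _field: Field) -> tantivy::Result<u64> {
        Ok(self.total_num_tokens)
    }

    fn total_num_docs(&self) -> tantivy::Result<u64> {
        Ok(self.total_num_docs)
    }

    fn doc_freq(&self, _term: &Term) -> tantivy::Result<u64> {
        Ok(self.doc_freq)
    }
}

// BM25 scores are then computed from the injected statistics rather than
// from the statistics of the index behind `searcher`.
fn top_ten_with_global_stats(
    searcher: &Searcher,
    query: &dyn Query,
    stats: &GlobalStats,
) -> tantivy::Result<Vec<(Score, DocAddress)>> {
    searcher.search_with_statistics_provider(query, &TopDocs::with_limit(10), stats)
}
```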

View File

@@ -4,7 +4,7 @@ use std::{fmt, io};
use crate::collector::Collector;
use crate::core::{Executor, SegmentReader};
use crate::query::{EnableScoring, Query};
use crate::query::{Bm25StatisticsProvider, EnableScoring, Query};
use crate::schema::{Document, Schema, Term};
use crate::space_usage::SearcherSpaceUsage;
use crate::store::{CacheStats, StoreReader};
@@ -176,8 +176,27 @@ impl Searcher {
query: &dyn Query,
collector: &C,
) -> crate::Result<C::Fruit> {
self.search_with_statistics_provider(query, collector, self)
}
/// Same as [`search(...)`](Searcher::search) but allows specifying
/// a [Bm25StatisticsProvider].
///
/// This can be used to adjust the statistics used in computing BM25
/// scores.
pub fn search_with_statistics_provider<C: Collector>(
&self,
query: &dyn Query,
collector: &C,
statistics_provider: &dyn Bm25StatisticsProvider,
) -> crate::Result<C::Fruit> {
let enabled_scoring = if collector.requires_scoring() {
EnableScoring::enabled_from_statistics_provider(statistics_provider, self)
} else {
EnableScoring::disabled_from_searcher(self)
};
let executor = self.inner.index.search_executor();
self.search_with_executor(query, collector, executor)
self.search_with_executor(query, collector, executor, enabled_scoring)
}
/// Same as [`search(...)`](Searcher::search) but multithreaded.
@@ -197,12 +216,8 @@ impl Searcher {
query: &dyn Query,
collector: &C,
executor: &Executor,
enabled_scoring: EnableScoring,
) -> crate::Result<C::Fruit> {
let enabled_scoring = if collector.requires_scoring() {
EnableScoring::enabled_from_searcher(self)
} else {
EnableScoring::disabled_from_searcher(self)
};
let weight = query.weight(enabled_scoring)?;
let segment_readers = self.segment_readers();
let fruits = executor.map(
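
The important detail in the hunk above is that the default path is unchanged: `search` now forwards to the new method with the searcher itself as the provider, since `Searcher` implements `Bm25StatisticsProvider`. A small sketch of that equivalence (the `Count` collector is just a convenient stand-in):

```rust
use tantivy::collector::Count;
use tantivy::query::Query;
use tantivy::Searcher;

// Both calls produce the same result: the second simply makes the
// default statistics provider (the searcher itself) explicit.
fn count_both_ways(searcher: &Searcher, query: &dyn Query) -> tantivy::Result<(usize, usize)> {
    let via_search = searcher.search(query, &Count)?;
    let via_provider = searcher.search_with_statistics_provider(query, &Count, searcher)?;
    Ok((via_search, via_provider))
}
```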

View File

@@ -112,7 +112,7 @@ mod tests {
Term::from_field_text(text, "hello"),
IndexRecordOption::WithFreqs,
);
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?;
assert_eq!(scorer.doc(), 0);
assert!((scorer.score() - 0.22920431).abs() < 0.001f32);
@@ -141,7 +141,7 @@ mod tests {
Term::from_field_text(text, "hello"),
IndexRecordOption::WithFreqs,
);
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?;
assert_eq!(scorer.doc(), 0);
assert!((scorer.score() - 0.22920431).abs() < 0.001f32);

View File

@@ -1628,7 +1628,7 @@ mod tests {
let reader = index.reader()?;
let searcher = reader.searcher();
let mut term_scorer = term_query
.specialized_weight(EnableScoring::Enabled(&searcher))?
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(searcher.segment_reader(0u32), 1.0)?;
assert_eq!(term_scorer.doc(), 0);
assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855);
@@ -1643,7 +1643,7 @@ mod tests {
assert_eq!(searcher.segment_readers().len(), 2);
for segment_reader in searcher.segment_readers() {
let mut term_scorer = term_query
.specialized_weight(EnableScoring::Enabled(&searcher))?
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(segment_reader, 1.0)?;
// the difference compared to before is intrinsic to the bm25 formula. no worries
// there.
@@ -1668,7 +1668,7 @@ mod tests {
let segment_reader = searcher.segment_reader(0u32);
let mut term_scorer = term_query
.specialized_weight(EnableScoring::Enabled(&searcher))?
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.specialized_scorer(segment_reader, 1.0)?;
// the difference compared to before is intrinsic to the bm25 formula. no worries there.
for doc in segment_reader.doc_ids_alive() {

View File

@@ -2,11 +2,53 @@ use serde::{Deserialize, Serialize};
use crate::fieldnorm::FieldNormReader;
use crate::query::Explanation;
use crate::schema::Field;
use crate::{Score, Searcher, Term};
const K1: Score = 1.2;
const B: Score = 0.75;
/// An interface to compute the statistics needed in BM25 scoring.
///
/// The standard implementation is a [Searcher] but you can also
/// create your own to adjust the statistics.
pub trait Bm25StatisticsProvider {
/// The total number of tokens in a given field across all documents in
/// the index.
fn total_num_tokens(&self, field: Field) -> crate::Result<u64>;
/// The total number of documents in the index.
fn total_num_docs(&self) -> crate::Result<u64>;
/// The number of documents containing the given term.
fn doc_freq(&self, term: &Term) -> crate::Result<u64>;
}
impl Bm25StatisticsProvider for Searcher {
fn total_num_tokens(&self, field: Field) -> crate::Result<u64> {
let mut total_num_tokens = 0u64;
for segment_reader in self.segment_readers() {
let inverted_index = segment_reader.inverted_index(field)?;
total_num_tokens += inverted_index.total_num_tokens();
}
Ok(total_num_tokens)
}
fn total_num_docs(&self) -> crate::Result<u64> {
let mut total_num_docs = 0u64;
for segment_reader in self.segment_readers() {
total_num_docs += u64::from(segment_reader.max_doc());
}
Ok(total_num_docs)
}
fn doc_freq(&self, term: &Term) -> crate::Result<u64> {
self.doc_freq(term)
}
}
pub(crate) fn idf(doc_freq: u64, doc_count: u64) -> Score {
assert!(doc_count >= doc_freq, "{} >= {}", doc_count, doc_freq);
let x = ((doc_count - doc_freq) as Score + 0.5) / (doc_freq as Score + 0.5);
@@ -32,6 +74,7 @@ pub struct Bm25Params {
pub avg_fieldnorm: Score,
}
/// A struct used for computing BM25 scores.
#[derive(Clone)]
pub struct Bm25Weight {
idf_explain: Explanation,
@@ -41,6 +84,7 @@ pub struct Bm25Weight {
}
impl Bm25Weight {
/// Increase the weight by a multiplicative factor.
pub fn boost_by(&self, boost: Score) -> Bm25Weight {
Bm25Weight {
idf_explain: self.idf_explain.clone(),
@@ -50,7 +94,11 @@ impl Bm25Weight {
}
}
pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> crate::Result<Bm25Weight> {
/// Construct a [Bm25Weight] for a phrase of terms.
pub fn for_terms(
statistics: &dyn Bm25StatisticsProvider,
terms: &[Term],
) -> crate::Result<Bm25Weight> {
assert!(!terms.is_empty(), "Bm25 requires at least one term");
let field = terms[0].field();
for term in &terms[1..] {
@@ -61,17 +109,12 @@ impl Bm25Weight {
);
}
let mut total_num_tokens = 0u64;
let mut total_num_docs = 0u64;
for segment_reader in searcher.segment_readers() {
let inverted_index = segment_reader.inverted_index(field)?;
total_num_tokens += inverted_index.total_num_tokens();
total_num_docs += u64::from(segment_reader.max_doc());
}
let total_num_tokens = statistics.total_num_tokens(field)?;
let total_num_docs = statistics.total_num_docs()?;
let average_fieldnorm = total_num_tokens as Score / total_num_docs as Score;
if terms.len() == 1 {
let term_doc_freq = searcher.doc_freq(&terms[0])?;
let term_doc_freq = statistics.doc_freq(&terms[0])?;
Ok(Bm25Weight::for_one_term(
term_doc_freq,
total_num_docs,
@@ -80,7 +123,7 @@ impl Bm25Weight {
} else {
let mut idf_sum: Score = 0.0;
for term in terms {
let term_doc_freq = searcher.doc_freq(term)?;
let term_doc_freq = statistics.doc_freq(term)?;
idf_sum += idf(term_doc_freq, total_num_docs);
}
let idf_explain = Explanation::new("idf", idf_sum);
@@ -88,6 +131,7 @@ impl Bm25Weight {
}
}
/// Construct a [Bm25Weight] for a single term.
pub fn for_one_term(
term_doc_freq: u64,
total_num_docs: u64,
@@ -114,11 +158,13 @@ impl Bm25Weight {
}
}
/// Compute the BM25 score of a single document.
#[inline]
pub fn score(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
self.weight * self.tf_factor(fieldnorm_id, term_freq)
}
/// Compute the maximum possible BM25 score given this weight.
pub fn max_score(&self) -> Score {
self.score(255u8, 2_013_265_944)
}
@@ -130,6 +176,7 @@ impl Bm25Weight {
term_freq / (term_freq + norm)
}
/// Produce an [Explanation] of a BM25 score.
pub fn explain(&self, fieldnorm_id: u8, term_freq: u32) -> Explanation {
// The explain format is directly copied from Lucene's.
// (So, Kudos to Lucene)
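
For completeness, a sketch of the lower-level path these hunks change: `Bm25Weight::for_terms` now pulls its statistics from any `Bm25StatisticsProvider` (a `Searcher` in the default case), and the resulting weight scores documents as before. The function names below are placeholders; `for_terms`, `score`, and `max_score` are the APIs shown above.

```rust
use tantivy::query::{Bm25StatisticsProvider, Bm25Weight};
use tantivy::{Score, Term};

// Build a BM25 weight for a single term from whichever statistics source
// the caller provides; a `Searcher` works here too, since it implements the trait.
fn single_term_weight(
    statistics: &dyn Bm25StatisticsProvider,
    term: &Term,
) -> tantivy::Result<Bm25Weight> {
    Bm25Weight::for_terms(statistics, &[term.clone()])
}

// The weight can then score a posting (fieldnorm id and term frequency) or
// report the maximum score it can ever produce.
fn score_posting(weight: &Bm25Weight, fieldnorm_id: u8, term_freq: u32) -> (Score, Score) {
    (weight.score(fieldnorm_id, term_freq), weight.max_score())
}
```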

View File

@@ -55,7 +55,7 @@ mod tests {
let query_parser = QueryParser::for_index(&index, vec![text_field]);
let query = query_parser.parse_query("+a")?;
let searcher = index.reader()?.searcher();
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<TermScorer>());
Ok(())
@@ -68,13 +68,13 @@ mod tests {
let searcher = index.reader()?.searcher();
{
let query = query_parser.parse_query("+a +b +c")?;
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<Intersection<TermScorer>>());
}
{
let query = query_parser.parse_query("+a +(b c)")?;
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<Intersection<Box<dyn Scorer>>>());
}
@@ -88,7 +88,7 @@ mod tests {
let searcher = index.reader()?.searcher();
{
let query = query_parser.parse_query("+a b")?;
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
assert!(scorer.is::<RequiredOptionalScorer<
Box<dyn Scorer>,
@@ -243,7 +243,7 @@ mod tests {
let boolean_query =
BooleanQuery::new(vec![(Occur::Should, term_a), (Occur::Should, term_b)]);
let boolean_weight = boolean_query
.weight(EnableScoring::Enabled(&searcher))
.weight(EnableScoring::enabled_from_searcher(&searcher))
.unwrap();
{
let mut boolean_scorer = boolean_weight.scorer(searcher.segment_reader(0u32), 1.0)?;

View File

@@ -33,7 +33,7 @@ pub use tantivy_query_grammar::Occur;
pub use self::all_query::{AllQuery, AllScorer, AllWeight};
pub use self::automaton_weight::AutomatonWeight;
pub use self::bitset::BitSetDocSet;
pub(crate) use self::bm25::Bm25Weight;
pub use self::bm25::{Bm25StatisticsProvider, Bm25Weight};
pub use self::boolean_query::BooleanQuery;
pub(crate) use self::boolean_query::BooleanWeight;
pub use self::boost_query::BoostQuery;

View File

@@ -44,7 +44,7 @@ impl MoreLikeThisQuery {
impl Query for MoreLikeThisQuery {
fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
let searcher = match enable_scoring {
EnableScoring::Enabled(searcher) => searcher,
EnableScoring::Enabled { searcher, .. } => searcher,
EnableScoring::Disabled { .. } => {
let err = "MoreLikeThisQuery requires to enable scoring.".to_string();
return Err(crate::TantivyError::InvalidArgument(err));

View File

@@ -108,7 +108,10 @@ impl PhraseQuery {
}
let terms = self.phrase_terms();
let bm25_weight_opt = match enable_scoring {
EnableScoring::Enabled(searcher) => Some(Bm25Weight::for_terms(searcher, &terms)?),
EnableScoring::Enabled {
statistics_provider,
..
} => Some(Bm25Weight::for_terms(statistics_provider, &terms)?),
EnableScoring::Disabled { .. } => None,
};
let mut weight = PhraseWeight::new(self.phrase_terms.clone(), bm25_weight_opt);
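
The same pattern recurs in every query that builds a BM25 weight: destructure `EnableScoring::Enabled` to get the provider, or skip the weight entirely when scoring is disabled. A sketch of the shape a custom query's `weight()` might use (the helper name is hypothetical; the match mirrors the PhraseQuery hunk above):

```rust
use tantivy::query::{Bm25Weight, EnableScoring};
use tantivy::Term;

// Only build a Bm25Weight when scoring is enabled, taking the statistics
// from whichever provider the caller supplied.
fn bm25_weight_if_scoring(
    enable_scoring: EnableScoring<'_>,
    terms: &[Term],
) -> tantivy::Result<Option<Bm25Weight>> {
    Ok(match enable_scoring {
        EnableScoring::Enabled {
            statistics_provider,
            ..
        } => Some(Bm25Weight::for_terms(statistics_provider, terms)?),
        EnableScoring::Disabled { .. } => None,
    })
}
```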

View File

@@ -132,7 +132,7 @@ mod tests {
Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"),
]);
let enable_scoring = EnableScoring::Enabled(&searcher);
let enable_scoring = EnableScoring::enabled_from_searcher(&searcher);
let phrase_weight = phrase_query.phrase_weight(enable_scoring).unwrap();
let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32), 1.0)?

View File

@@ -2,6 +2,7 @@ use std::fmt;
use downcast_rs::impl_downcast;
use super::bm25::Bm25StatisticsProvider;
use super::Weight;
use crate::core::searcher::Searcher;
use crate::query::Explanation;
@@ -12,7 +13,16 @@ use crate::{DocAddress, Term};
#[derive(Copy, Clone)]
pub enum EnableScoring<'a> {
/// Pass this to enable scoring.
Enabled(&'a Searcher),
Enabled {
/// The searcher to use during scoring.
searcher: &'a Searcher,
/// A [Bm25StatisticsProvider] used to compute BM25 scores.
///
/// Normally this should be the [Searcher], but you can specify a custom
/// one to adjust the statistics.
statistics_provider: &'a dyn Bm25StatisticsProvider,
},
/// Pass this to disable scoring.
/// This can improve performance.
Disabled {
@@ -26,7 +36,21 @@ pub enum EnableScoring<'a> {
impl<'a> EnableScoring<'a> {
/// Create using [Searcher] with scoring enabled.
pub fn enabled_from_searcher(searcher: &'a Searcher) -> EnableScoring<'a> {
EnableScoring::Enabled(searcher)
EnableScoring::Enabled {
searcher,
statistics_provider: searcher,
}
}
/// Create using a custom [Bm25StatisticsProvider] with scoring enabled.
pub fn enabled_from_statistics_provider(
statistics_provider: &'a dyn Bm25StatisticsProvider,
searcher: &'a Searcher,
) -> EnableScoring<'a> {
EnableScoring::Enabled {
statistics_provider,
searcher,
}
}
/// Create using [Searcher] with scoring disabled.
@@ -48,7 +72,7 @@ impl<'a> EnableScoring<'a> {
/// Returns the searcher if available.
pub fn searcher(&self) -> Option<&Searcher> {
match self {
EnableScoring::Enabled(searcher) => Some(searcher),
EnableScoring::Enabled { searcher, .. } => Some(*searcher),
EnableScoring::Disabled { searcher_opt, .. } => searcher_opt.to_owned(),
}
}
@@ -56,14 +80,14 @@ impl<'a> EnableScoring<'a> {
/// Returns the schema.
pub fn schema(&self) -> &Schema {
match self {
EnableScoring::Enabled(searcher) => searcher.schema(),
EnableScoring::Enabled { searcher, .. } => searcher.schema(),
EnableScoring::Disabled { schema, .. } => schema,
}
}
/// Returns true if the scoring is enabled.
pub fn is_scoring_enabled(&self) -> bool {
matches!(self, EnableScoring::Enabled(..))
matches!(self, EnableScoring::Enabled { .. })
}
}
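
Outside of `Searcher::search_with_statistics_provider`, the new constructor can also be used directly when building a weight by hand. A minimal sketch, assuming a custom provider such as the hypothetical `GlobalStats` near the top:

```rust
use tantivy::query::{Bm25StatisticsProvider, EnableScoring, Query, Weight};
use tantivy::Searcher;

// Build a scoring-enabled weight whose BM25 statistics come from a custom
// provider; the searcher is still needed for everything that is not a
// BM25 statistic (segments, schema, fieldnorms).
fn weight_with_custom_stats(
    query: &dyn Query,
    searcher: &Searcher,
    stats: &dyn Bm25StatisticsProvider,
) -> tantivy::Result<Box<dyn Weight>> {
    let scoring = EnableScoring::enabled_from_statistics_provider(stats, searcher);
    query.weight(scoring)
}
```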

View File

@@ -34,7 +34,7 @@ mod tests {
Term::from_field_text(text_field, "a"),
IndexRecordOption::Basic,
);
let term_weight = term_query.weight(EnableScoring::Enabled(&searcher))?;
let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let segment_reader = searcher.segment_reader(0);
let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?;
assert_eq!(term_scorer.doc(), 0);
@@ -62,7 +62,7 @@ mod tests {
Term::from_field_text(text_field, "a"),
IndexRecordOption::Basic,
);
let term_weight = term_query.weight(EnableScoring::Enabled(&searcher))?;
let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
let segment_reader = searcher.segment_reader(0);
let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?;
for i in 0u32..COMPRESSION_BLOCK_SIZE as u32 {

View File

@@ -96,9 +96,10 @@ impl TermQuery {
return Err(crate::TantivyError::SchemaError(error_msg));
}
let bm25_weight = match enable_scoring {
EnableScoring::Enabled(searcher) => {
Bm25Weight::for_terms(searcher, &[self.term.clone()])?
}
EnableScoring::Enabled {
statistics_provider,
..
} => Bm25Weight::for_terms(statistics_provider, &[self.term.clone()])?,
EnableScoring::Disabled { .. } => {
Bm25Weight::new(Explanation::new("<no score>".to_string(), 1.0f32), 1.0f32)
}

View File

@@ -250,7 +250,8 @@ mod tests {
}
fn test_block_wand_aux(term_query: &TermQuery, searcher: &Searcher) -> crate::Result<()> {
let term_weight = term_query.specialized_weight(EnableScoring::Enabled(searcher))?;
let term_weight =
term_query.specialized_weight(EnableScoring::enabled_from_searcher(searcher))?;
for reader in searcher.segment_readers() {
let mut block_max_scores = vec![];
let mut block_max_scores_b = vec![];