mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-04 08:12:54 +00:00
Make BM25 scoring more flexible (#1855)
* Introduce Bm25StatisticsProvider to inject statistics
* fix formatting I accidentally changed
This commit is contained in:
@@ -4,7 +4,7 @@ use std::{fmt, io};
|
||||
|
||||
use crate::collector::Collector;
|
||||
use crate::core::{Executor, SegmentReader};
|
||||
use crate::query::{EnableScoring, Query};
|
||||
use crate::query::{Bm25StatisticsProvider, EnableScoring, Query};
|
||||
use crate::schema::{Document, Schema, Term};
|
||||
use crate::space_usage::SearcherSpaceUsage;
|
||||
use crate::store::{CacheStats, StoreReader};
|
||||
@@ -176,8 +176,27 @@ impl Searcher {
|
||||
query: &dyn Query,
|
||||
collector: &C,
|
||||
) -> crate::Result<C::Fruit> {
|
||||
self.search_with_statistics_provider(query, collector, self)
|
||||
}
|
||||
|
||||
/// Same as [`search(...)`](Searcher::search) but allows specifying
|
||||
/// a [Bm25StatisticsProvider].
|
||||
///
|
||||
/// This can be used to adjust the statistics used in computing BM25
|
||||
/// scores.
|
||||
pub fn search_with_statistics_provider<C: Collector>(
|
||||
&self,
|
||||
query: &dyn Query,
|
||||
collector: &C,
|
||||
statistics_provider: &dyn Bm25StatisticsProvider,
|
||||
) -> crate::Result<C::Fruit> {
|
||||
let enabled_scoring = if collector.requires_scoring() {
|
||||
EnableScoring::enabled_from_statistics_provider(statistics_provider, self)
|
||||
} else {
|
||||
EnableScoring::disabled_from_searcher(self)
|
||||
};
|
||||
let executor = self.inner.index.search_executor();
|
||||
self.search_with_executor(query, collector, executor)
|
||||
self.search_with_executor(query, collector, executor, enabled_scoring)
|
||||
}
|
||||
|
||||
/// Same as [`search(...)`](Searcher::search) but multithreaded.
|
||||
@@ -197,12 +216,8 @@ impl Searcher {
|
||||
query: &dyn Query,
|
||||
collector: &C,
|
||||
executor: &Executor,
|
||||
enabled_scoring: EnableScoring,
|
||||
) -> crate::Result<C::Fruit> {
|
||||
let enabled_scoring = if collector.requires_scoring() {
|
||||
EnableScoring::enabled_from_searcher(self)
|
||||
} else {
|
||||
EnableScoring::disabled_from_searcher(self)
|
||||
};
|
||||
let weight = query.weight(enabled_scoring)?;
|
||||
let segment_readers = self.segment_readers();
|
||||
let fruits = executor.map(
|
||||
|
||||
@@ -112,7 +112,7 @@ mod tests {
|
||||
Term::from_field_text(text, "hello"),
|
||||
IndexRecordOption::WithFreqs,
|
||||
);
|
||||
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?;
|
||||
assert_eq!(scorer.doc(), 0);
|
||||
assert!((scorer.score() - 0.22920431).abs() < 0.001f32);
|
||||
@@ -141,7 +141,7 @@ mod tests {
|
||||
Term::from_field_text(text, "hello"),
|
||||
IndexRecordOption::WithFreqs,
|
||||
);
|
||||
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let mut scorer = weight.scorer(searcher.segment_reader(0), 1.0f32)?;
|
||||
assert_eq!(scorer.doc(), 0);
|
||||
assert!((scorer.score() - 0.22920431).abs() < 0.001f32);
|
||||
|
||||
@@ -1628,7 +1628,7 @@ mod tests {
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
let mut term_scorer = term_query
|
||||
.specialized_weight(EnableScoring::Enabled(&searcher))?
|
||||
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
|
||||
.specialized_scorer(searcher.segment_reader(0u32), 1.0)?;
|
||||
assert_eq!(term_scorer.doc(), 0);
|
||||
assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855);
|
||||
@@ -1643,7 +1643,7 @@ mod tests {
|
||||
assert_eq!(searcher.segment_readers().len(), 2);
|
||||
for segment_reader in searcher.segment_readers() {
|
||||
let mut term_scorer = term_query
|
||||
.specialized_weight(EnableScoring::Enabled(&searcher))?
|
||||
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
|
||||
.specialized_scorer(segment_reader, 1.0)?;
|
||||
// the difference compared to before is intrinsic to the bm25 formula. no worries
|
||||
// there.
|
||||
@@ -1668,7 +1668,7 @@ mod tests {
|
||||
|
||||
let segment_reader = searcher.segment_reader(0u32);
|
||||
let mut term_scorer = term_query
|
||||
.specialized_weight(EnableScoring::Enabled(&searcher))?
|
||||
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
|
||||
.specialized_scorer(segment_reader, 1.0)?;
|
||||
// the difference compared to before is intrinsic to the bm25 formula. no worries there.
|
||||
for doc in segment_reader.doc_ids_alive() {
|
||||
|
||||
@@ -2,11 +2,53 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::fieldnorm::FieldNormReader;
|
||||
use crate::query::Explanation;
|
||||
use crate::schema::Field;
|
||||
use crate::{Score, Searcher, Term};
|
||||
|
||||
const K1: Score = 1.2;
|
||||
const B: Score = 0.75;
|
||||
|
||||
/// An interface to compute the statistics needed in BM25 scoring.
|
||||
///
|
||||
/// The standard implementation is a [Searcher] but you can also
|
||||
/// create your own to adjust the statistics.
///
/// # Contract for custom implementations
///
/// `Bm25Weight::for_terms` reads all three figures. It divides
/// `total_num_tokens` by `total_num_docs` to obtain `average_fieldnorm`, and
/// on its multi-term path it hands each `doc_freq` together with
/// `total_num_docs` to the idf computation, which asserts
/// `total_num_docs >= doc_freq`. A provider must therefore never report a
/// `doc_freq` larger than its `total_num_docs`, or scoring will panic.
///
/// NOTE(review): a provider returning `0` for `total_num_docs` turns the
/// `average_fieldnorm` computation into a floating-point division by zero;
/// confirm how an empty index is expected to be handled.
|
||||
pub trait Bm25StatisticsProvider {
|
||||
/// The total number of tokens in a given field across all documents in
|
||||
/// the index.
///
/// Used as the numerator of `average_fieldnorm` in `Bm25Weight::for_terms`.
|
||||
fn total_num_tokens(&self, field: Field) -> crate::Result<u64>;
|
||||
|
||||
/// The total number of documents in the index.
///
/// Used as the divisor of `average_fieldnorm` and as the document count
/// of the idf computation in `Bm25Weight::for_terms`.
|
||||
fn total_num_docs(&self) -> crate::Result<u64>;
|
||||
|
||||
/// The number of documents containing the given term.
///
/// Must not exceed the value returned by `total_num_docs`; the idf
/// computation asserts this.
|
||||
fn doc_freq(&self, term: &Term) -> crate::Result<u64>;
|
||||
}
|
||||
|
||||
impl Bm25StatisticsProvider for Searcher {
|
||||
fn total_num_tokens(&self, field: Field) -> crate::Result<u64> {
|
||||
let mut total_num_tokens = 0u64;
|
||||
|
||||
for segment_reader in self.segment_readers() {
|
||||
let inverted_index = segment_reader.inverted_index(field)?;
|
||||
total_num_tokens += inverted_index.total_num_tokens();
|
||||
}
|
||||
Ok(total_num_tokens)
|
||||
}
|
||||
|
||||
fn total_num_docs(&self) -> crate::Result<u64> {
|
||||
let mut total_num_docs = 0u64;
|
||||
|
||||
for segment_reader in self.segment_readers() {
|
||||
total_num_docs += u64::from(segment_reader.max_doc());
|
||||
}
|
||||
Ok(total_num_docs)
|
||||
}
|
||||
|
||||
fn doc_freq(&self, term: &Term) -> crate::Result<u64> {
|
||||
self.doc_freq(term)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn idf(doc_freq: u64, doc_count: u64) -> Score {
|
||||
assert!(doc_count >= doc_freq, "{} >= {}", doc_count, doc_freq);
|
||||
let x = ((doc_count - doc_freq) as Score + 0.5) / (doc_freq as Score + 0.5);
|
||||
@@ -32,6 +74,7 @@ pub struct Bm25Params {
|
||||
pub avg_fieldnorm: Score,
|
||||
}
|
||||
|
||||
/// A struct used for computing BM25 scores.
|
||||
#[derive(Clone)]
|
||||
pub struct Bm25Weight {
|
||||
idf_explain: Explanation,
|
||||
@@ -41,6 +84,7 @@ pub struct Bm25Weight {
|
||||
}
|
||||
|
||||
impl Bm25Weight {
|
||||
/// Increase the weight by a multiplicative factor.
|
||||
pub fn boost_by(&self, boost: Score) -> Bm25Weight {
|
||||
Bm25Weight {
|
||||
idf_explain: self.idf_explain.clone(),
|
||||
@@ -50,7 +94,11 @@ impl Bm25Weight {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> crate::Result<Bm25Weight> {
|
||||
/// Construct a [Bm25Weight] for a phrase of terms.
|
||||
pub fn for_terms(
|
||||
statistics: &dyn Bm25StatisticsProvider,
|
||||
terms: &[Term],
|
||||
) -> crate::Result<Bm25Weight> {
|
||||
assert!(!terms.is_empty(), "Bm25 requires at least one term");
|
||||
let field = terms[0].field();
|
||||
for term in &terms[1..] {
|
||||
@@ -61,17 +109,12 @@ impl Bm25Weight {
|
||||
);
|
||||
}
|
||||
|
||||
let mut total_num_tokens = 0u64;
|
||||
let mut total_num_docs = 0u64;
|
||||
for segment_reader in searcher.segment_readers() {
|
||||
let inverted_index = segment_reader.inverted_index(field)?;
|
||||
total_num_tokens += inverted_index.total_num_tokens();
|
||||
total_num_docs += u64::from(segment_reader.max_doc());
|
||||
}
|
||||
let total_num_tokens = statistics.total_num_tokens(field)?;
|
||||
let total_num_docs = statistics.total_num_docs()?;
|
||||
let average_fieldnorm = total_num_tokens as Score / total_num_docs as Score;
|
||||
|
||||
if terms.len() == 1 {
|
||||
let term_doc_freq = searcher.doc_freq(&terms[0])?;
|
||||
let term_doc_freq = statistics.doc_freq(&terms[0])?;
|
||||
Ok(Bm25Weight::for_one_term(
|
||||
term_doc_freq,
|
||||
total_num_docs,
|
||||
@@ -80,7 +123,7 @@ impl Bm25Weight {
|
||||
} else {
|
||||
let mut idf_sum: Score = 0.0;
|
||||
for term in terms {
|
||||
let term_doc_freq = searcher.doc_freq(term)?;
|
||||
let term_doc_freq = statistics.doc_freq(term)?;
|
||||
idf_sum += idf(term_doc_freq, total_num_docs);
|
||||
}
|
||||
let idf_explain = Explanation::new("idf", idf_sum);
|
||||
@@ -88,6 +131,7 @@ impl Bm25Weight {
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a [Bm25Weight] for a single term.
|
||||
pub fn for_one_term(
|
||||
term_doc_freq: u64,
|
||||
total_num_docs: u64,
|
||||
@@ -114,11 +158,13 @@ impl Bm25Weight {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the BM25 score of a single document.
|
||||
#[inline]
|
||||
pub fn score(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
|
||||
self.weight * self.tf_factor(fieldnorm_id, term_freq)
|
||||
}
|
||||
|
||||
/// Compute the maximum possible BM25 score given this weight.
|
||||
pub fn max_score(&self) -> Score {
|
||||
self.score(255u8, 2_013_265_944)
|
||||
}
|
||||
@@ -130,6 +176,7 @@ impl Bm25Weight {
|
||||
term_freq / (term_freq + norm)
|
||||
}
|
||||
|
||||
/// Produce an [Explanation] of a BM25 score.
|
||||
pub fn explain(&self, fieldnorm_id: u8, term_freq: u32) -> Explanation {
|
||||
// The explain format is directly copied from Lucene's.
|
||||
// (So, Kudos to Lucene)
|
||||
|
||||
@@ -55,7 +55,7 @@ mod tests {
|
||||
let query_parser = QueryParser::for_index(&index, vec![text_field]);
|
||||
let query = query_parser.parse_query("+a")?;
|
||||
let searcher = index.reader()?.searcher();
|
||||
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
|
||||
assert!(scorer.is::<TermScorer>());
|
||||
Ok(())
|
||||
@@ -68,13 +68,13 @@ mod tests {
|
||||
let searcher = index.reader()?.searcher();
|
||||
{
|
||||
let query = query_parser.parse_query("+a +b +c")?;
|
||||
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
|
||||
assert!(scorer.is::<Intersection<TermScorer>>());
|
||||
}
|
||||
{
|
||||
let query = query_parser.parse_query("+a +(b c)")?;
|
||||
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
|
||||
assert!(scorer.is::<Intersection<Box<dyn Scorer>>>());
|
||||
}
|
||||
@@ -88,7 +88,7 @@ mod tests {
|
||||
let searcher = index.reader()?.searcher();
|
||||
{
|
||||
let query = query_parser.parse_query("+a b")?;
|
||||
let weight = query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?;
|
||||
assert!(scorer.is::<RequiredOptionalScorer<
|
||||
Box<dyn Scorer>,
|
||||
@@ -243,7 +243,7 @@ mod tests {
|
||||
let boolean_query =
|
||||
BooleanQuery::new(vec![(Occur::Should, term_a), (Occur::Should, term_b)]);
|
||||
let boolean_weight = boolean_query
|
||||
.weight(EnableScoring::Enabled(&searcher))
|
||||
.weight(EnableScoring::enabled_from_searcher(&searcher))
|
||||
.unwrap();
|
||||
{
|
||||
let mut boolean_scorer = boolean_weight.scorer(searcher.segment_reader(0u32), 1.0)?;
|
||||
|
||||
@@ -33,7 +33,7 @@ pub use tantivy_query_grammar::Occur;
|
||||
pub use self::all_query::{AllQuery, AllScorer, AllWeight};
|
||||
pub use self::automaton_weight::AutomatonWeight;
|
||||
pub use self::bitset::BitSetDocSet;
|
||||
pub(crate) use self::bm25::Bm25Weight;
|
||||
pub use self::bm25::{Bm25StatisticsProvider, Bm25Weight};
|
||||
pub use self::boolean_query::BooleanQuery;
|
||||
pub(crate) use self::boolean_query::BooleanWeight;
|
||||
pub use self::boost_query::BoostQuery;
|
||||
|
||||
@@ -44,7 +44,7 @@ impl MoreLikeThisQuery {
|
||||
impl Query for MoreLikeThisQuery {
|
||||
fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
|
||||
let searcher = match enable_scoring {
|
||||
EnableScoring::Enabled(searcher) => searcher,
|
||||
EnableScoring::Enabled { searcher, .. } => searcher,
|
||||
EnableScoring::Disabled { .. } => {
|
||||
let err = "MoreLikeThisQuery requires to enable scoring.".to_string();
|
||||
return Err(crate::TantivyError::InvalidArgument(err));
|
||||
|
||||
@@ -108,7 +108,10 @@ impl PhraseQuery {
|
||||
}
|
||||
let terms = self.phrase_terms();
|
||||
let bm25_weight_opt = match enable_scoring {
|
||||
EnableScoring::Enabled(searcher) => Some(Bm25Weight::for_terms(searcher, &terms)?),
|
||||
EnableScoring::Enabled {
|
||||
statistics_provider,
|
||||
..
|
||||
} => Some(Bm25Weight::for_terms(statistics_provider, &terms)?),
|
||||
EnableScoring::Disabled { .. } => None,
|
||||
};
|
||||
let mut weight = PhraseWeight::new(self.phrase_terms.clone(), bm25_weight_opt);
|
||||
|
||||
@@ -132,7 +132,7 @@ mod tests {
|
||||
Term::from_field_text(text_field, "a"),
|
||||
Term::from_field_text(text_field, "b"),
|
||||
]);
|
||||
let enable_scoring = EnableScoring::Enabled(&searcher);
|
||||
let enable_scoring = EnableScoring::enabled_from_searcher(&searcher);
|
||||
let phrase_weight = phrase_query.phrase_weight(enable_scoring).unwrap();
|
||||
let mut phrase_scorer = phrase_weight
|
||||
.phrase_scorer(searcher.segment_reader(0u32), 1.0)?
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::fmt;
|
||||
|
||||
use downcast_rs::impl_downcast;
|
||||
|
||||
use super::bm25::Bm25StatisticsProvider;
|
||||
use super::Weight;
|
||||
use crate::core::searcher::Searcher;
|
||||
use crate::query::Explanation;
|
||||
@@ -12,7 +13,16 @@ use crate::{DocAddress, Term};
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum EnableScoring<'a> {
|
||||
/// Pass this to enable scoring.
|
||||
Enabled(&'a Searcher),
|
||||
Enabled {
|
||||
/// The searcher to use during scoring.
|
||||
searcher: &'a Searcher,
|
||||
|
||||
/// A [Bm25StatisticsProvider] used to compute BM25 scores.
|
||||
///
|
||||
/// Normally this should be the [Searcher], but you can specify a custom
|
||||
/// one to adjust the statistics.
|
||||
statistics_provider: &'a dyn Bm25StatisticsProvider,
|
||||
},
|
||||
/// Pass this to disable scoring.
|
||||
/// This can improve performance.
|
||||
Disabled {
|
||||
@@ -26,7 +36,21 @@ pub enum EnableScoring<'a> {
|
||||
impl<'a> EnableScoring<'a> {
|
||||
/// Create using [Searcher] with scoring enabled.
|
||||
pub fn enabled_from_searcher(searcher: &'a Searcher) -> EnableScoring<'a> {
|
||||
EnableScoring::Enabled(searcher)
|
||||
EnableScoring::Enabled {
|
||||
searcher,
|
||||
statistics_provider: searcher,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create using a custom [Bm25StatisticsProvider] with scoring enabled.
|
||||
pub fn enabled_from_statistics_provider(
|
||||
statistics_provider: &'a dyn Bm25StatisticsProvider,
|
||||
searcher: &'a Searcher,
|
||||
) -> EnableScoring<'a> {
|
||||
EnableScoring::Enabled {
|
||||
statistics_provider,
|
||||
searcher,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create using [Searcher] with scoring disabled.
|
||||
@@ -48,7 +72,7 @@ impl<'a> EnableScoring<'a> {
|
||||
/// Returns the searcher if available.
|
||||
pub fn searcher(&self) -> Option<&Searcher> {
|
||||
match self {
|
||||
EnableScoring::Enabled(searcher) => Some(searcher),
|
||||
EnableScoring::Enabled { searcher, .. } => Some(*searcher),
|
||||
EnableScoring::Disabled { searcher_opt, .. } => searcher_opt.to_owned(),
|
||||
}
|
||||
}
|
||||
@@ -56,14 +80,14 @@ impl<'a> EnableScoring<'a> {
|
||||
/// Returns the schema.
|
||||
pub fn schema(&self) -> &Schema {
|
||||
match self {
|
||||
EnableScoring::Enabled(searcher) => searcher.schema(),
|
||||
EnableScoring::Enabled { searcher, .. } => searcher.schema(),
|
||||
EnableScoring::Disabled { schema, .. } => schema,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the scoring is enabled.
|
||||
pub fn is_scoring_enabled(&self) -> bool {
|
||||
matches!(self, EnableScoring::Enabled(..))
|
||||
matches!(self, EnableScoring::Enabled { .. })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ mod tests {
|
||||
Term::from_field_text(text_field, "a"),
|
||||
IndexRecordOption::Basic,
|
||||
);
|
||||
let term_weight = term_query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?;
|
||||
assert_eq!(term_scorer.doc(), 0);
|
||||
@@ -62,7 +62,7 @@ mod tests {
|
||||
Term::from_field_text(text_field, "a"),
|
||||
IndexRecordOption::Basic,
|
||||
);
|
||||
let term_weight = term_query.weight(EnableScoring::Enabled(&searcher))?;
|
||||
let term_weight = term_query.weight(EnableScoring::enabled_from_searcher(&searcher))?;
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
let mut term_scorer = term_weight.scorer(segment_reader, 1.0)?;
|
||||
for i in 0u32..COMPRESSION_BLOCK_SIZE as u32 {
|
||||
|
||||
@@ -96,9 +96,10 @@ impl TermQuery {
|
||||
return Err(crate::TantivyError::SchemaError(error_msg));
|
||||
}
|
||||
let bm25_weight = match enable_scoring {
|
||||
EnableScoring::Enabled(searcher) => {
|
||||
Bm25Weight::for_terms(searcher, &[self.term.clone()])?
|
||||
}
|
||||
EnableScoring::Enabled {
|
||||
statistics_provider,
|
||||
..
|
||||
} => Bm25Weight::for_terms(statistics_provider, &[self.term.clone()])?,
|
||||
EnableScoring::Disabled { .. } => {
|
||||
Bm25Weight::new(Explanation::new("<no score>".to_string(), 1.0f32), 1.0f32)
|
||||
}
|
||||
|
||||
@@ -250,7 +250,8 @@ mod tests {
|
||||
}
|
||||
|
||||
fn test_block_wand_aux(term_query: &TermQuery, searcher: &Searcher) -> crate::Result<()> {
|
||||
let term_weight = term_query.specialized_weight(EnableScoring::Enabled(searcher))?;
|
||||
let term_weight =
|
||||
term_query.specialized_weight(EnableScoring::enabled_from_searcher(searcher))?;
|
||||
for reader in searcher.segment_readers() {
|
||||
let mut block_max_scores = vec![];
|
||||
let mut block_max_scores_b = vec![];
|
||||
|
||||
Reference in New Issue
Block a user