diff --git a/src/collector/mod.rs b/src/collector/mod.rs
index 066803182..f905e52c2 100644
--- a/src/collector/mod.rs
+++ b/src/collector/mod.rs
@@ -105,6 +105,7 @@ pub mod tests {
         offset: DocId,
         segment_max_doc: DocId,
         docs: Vec<DocId>,
+        scores: Vec<Score>,
     }
 
     impl TestCollector {
@@ -112,14 +113,19 @@ pub mod tests {
         pub fn docs(self) -> Vec<DocId> {
             self.docs
         }
+
+        pub fn scores(self) -> Vec<Score> {
+            self.scores
+        }
     }
 
     impl Default for TestCollector {
         fn default() -> TestCollector {
             TestCollector {
-                docs: Vec::new(),
                 offset: 0,
                 segment_max_doc: 0,
+                docs: Vec::new(),
+                scores: Vec::new(),
             }
         }
     }
@@ -131,12 +137,13 @@ pub mod tests {
             Ok(())
         }
 
-        fn collect(&mut self, doc: DocId, _score: Score) {
+        fn collect(&mut self, doc: DocId, score: Score) {
             self.docs.push(doc + self.offset);
+            self.scores.push(score);
         }
 
         fn requires_scoring(&self) -> bool {
-            false
+            true
         }
     }
diff --git a/src/query/bm25.rs b/src/query/bm25.rs
new file mode 100644
index 000000000..8f7aca577
--- /dev/null
+++ b/src/query/bm25.rs
@@ -0,0 +1,94 @@
+use fieldnorm::FieldNormReader;
+use Term;
+use Searcher;
+use Score;
+
+const K1: f32 = 1.2;
+const B: f32 = 0.75;
+
+fn idf(doc_freq: u64, doc_count: u64) -> f32 {
+    let x = ((doc_count - doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5);
+    (1f32 + x).ln()
+}
+
+
+fn cached_tf_component(fieldnorm: u32, average_fieldnorm: f32) -> f32 {
+    K1 * (1f32 - B + B * fieldnorm as f32 / average_fieldnorm)
+}
+
+fn compute_tf_cache(average_fieldnorm: f32) -> [f32; 256] {
+    let mut cache = [0f32; 256];
+    for fieldnorm_id in 0..256 {
+        let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
+        cache[fieldnorm_id] = cached_tf_component(fieldnorm, average_fieldnorm);
+    }
+    cache
+}
+
+#[derive(Clone)]
+pub struct BM25Weight {
+    weight: f32,
+    cache: [f32; 256],
+}
+
+impl BM25Weight {
+
+    pub fn null() -> BM25Weight {
+        BM25Weight {
+            weight: 0f32,
+            cache: [1f32; 256]
+        }
+    }
+
+    pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> BM25Weight {
+        assert!(!terms.is_empty(), "BM25 requires at least one term");
+        let field = terms[0].field();
+        for term in &terms[1..] {
+            assert_eq!(term.field(), field, "All terms must belong to the same field.");
+        }
+
+        let mut total_num_tokens = 0u64;
+        let mut total_num_docs = 0u64;
+        for segment_reader in searcher.segment_readers() {
+            let inverted_index = segment_reader.inverted_index(field);
+            total_num_tokens += inverted_index.total_num_tokens();
+            total_num_docs += segment_reader.max_doc() as u64;
+        }
+        let average_fieldnorm = total_num_tokens as f32 / total_num_docs as f32;
+
+        let idf = terms.iter()
+            .map(|term| {
+                let term_doc_freq = searcher.doc_freq(term);
+                idf(term_doc_freq, total_num_docs)
+            })
+            .sum::<f32>();
+        BM25Weight::new(idf, average_fieldnorm)
+    }
+
+    fn new(idf: f32, average_fieldnorm: f32) -> BM25Weight {
+        BM25Weight {
+            weight: idf * (1f32 + K1),
+            cache: compute_tf_cache(average_fieldnorm),
+        }
+    }
+
+    #[inline(always)]
+    pub fn score(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
+        let norm = self.cache[fieldnorm_id as usize];
+        let term_freq = term_freq as f32;
+        self.weight * term_freq / (term_freq + norm)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use tests::assert_nearly_equals;
+    use super::idf;
+
+    #[test]
+    fn test_idf() {
+        assert_nearly_equals(idf(1, 2), 0.6931472);
+    }
+
+}
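A quick sanity check of the new module's arithmetic. The sketch below is standalone Rust rather than tantivy code: it re-derives the `test_idf` value and illustrates a convenient identity of this formulation — for `term_freq = 1` in a document of exactly average length, the score collapses to the idf itself.

```rust
// Standalone sketch mirroring the formulas in src/query/bm25.rs.
const K1: f32 = 1.2;
const B: f32 = 0.75;

fn idf(doc_freq: u64, doc_count: u64) -> f32 {
    let x = ((doc_count - doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5);
    (1f32 + x).ln()
}

fn bm25(idf: f32, term_freq: f32, fieldnorm: f32, average_fieldnorm: f32) -> f32 {
    let weight = idf * (1f32 + K1);
    let norm = K1 * (1f32 - B + B * fieldnorm / average_fieldnorm);
    weight * term_freq / (term_freq + norm)
}

fn main() {
    // idf(1, 2) = ln(1 + (2 - 1 + 0.5) / (1 + 0.5)) = ln(2) ≈ 0.6931472,
    // the exact value asserted by `test_idf` above.
    println!("{}", idf(1, 2));
    // With term_freq = 1 and fieldnorm == average_fieldnorm, the norm is
    // exactly K1, so the (1 + K1) factors cancel and the score equals idf:
    println!("{}", bm25(idf(1, 2), 1.0, 5.0, 5.0)); // ≈ 0.6931472 again
}
```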
diff --git a/src/query/mod.rs b/src/query/mod.rs
index b3ed49591..e0ada1b17 100644
--- a/src/query/mod.rs
+++ b/src/query/mod.rs
@@ -17,6 +17,7 @@ mod exclude;
 mod union;
 mod intersection;
 mod reqopt_scorer;
+mod bm25;
 
 #[cfg(test)]
 mod vec_docset;
diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs
index b1aaaa6ca..4ef92a3bd 100644
--- a/src/query/phrase_query/mod.rs
+++ b/src/query/phrase_query/mod.rs
@@ -13,43 +13,37 @@ mod tests {
     use core::Index;
     use schema::{SchemaBuilder, Term, TEXT};
     use collector::tests::TestCollector;
+    use tests::assert_nearly_equals;
 
-    #[test]
-    pub fn test_phrase_query() {
+
+    fn create_index(texts: &[&'static str]) -> Index {
         let mut schema_builder = SchemaBuilder::default();
         let text_field = schema_builder.add_text_field("text", TEXT);
         let schema = schema_builder.build();
         let index = Index::create_in_ram(schema);
         {
             let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap();
-            {
-                // 0
-                let doc = doc!(text_field=>"b b b d c g c");
-                index_writer.add_document(doc);
-            }
-            {
-                // 1
-                let doc = doc!(text_field=>"a b b d c g c");
-                index_writer.add_document(doc);
-            }
-            {
-                // 2
-                let doc = doc!(text_field=>"a b a b c");
-                index_writer.add_document(doc);
-            }
-            {
-                // 3
-                let doc = doc!(text_field=>"c a b a d ga a");
-                index_writer.add_document(doc);
-            }
-            {
-                // 4
-                let doc = doc!(text_field=>"a b c");
+            for &text in texts {
+                let doc = doc!(text_field=>text);
                 index_writer.add_document(doc);
             }
             assert!(index_writer.commit().is_ok());
         }
+        index.load_searchers().unwrap();
+        index
+    }
 
+    #[test]
+    pub fn test_phrase_query() {
+        let index = create_index(&[
+            "b b b d c g c",
+            "a b b d c g c",
+            "a b a b c",
+            "c a b a d ga a",
+            "a b c"
+        ]);
+        let schema = index.schema();
+        let text_field = schema.get_field("text").unwrap();
         index.load_searchers().unwrap();
         let searcher = index.searcher();
         let test_query = |texts: Vec<&str>| {
@@ -58,20 +52,43 @@ mod tests {
                 .iter()
                 .map(|text| Term::from_field_text(text_field, text))
                 .collect();
-            let phrase_query = PhraseQuery::from(terms);
+            let phrase_query = PhraseQuery::new(terms);
             searcher
                 .search(&phrase_query, &mut test_collector)
                 .expect("search should succeed");
             test_collector.docs()
         };
-
-        let empty_vec = Vec::<DocId>::new();
         assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]);
         assert_eq!(test_query(vec!["a", "b"]), vec![1, 2, 3, 4]);
         assert_eq!(test_query(vec!["b", "b"]), vec![0, 1]);
-        assert_eq!(test_query(vec!["g", "ewrwer"]), empty_vec);
-        assert_eq!(test_query(vec!["g", "a"]), empty_vec);
+        assert!(test_query(vec!["g", "ewrwer"]).is_empty());
+        assert!(test_query(vec!["g", "a"]).is_empty());
+    }
+
+
+    #[test]
+    pub fn test_phrase_score() {
+        let index = create_index(&["a b c", "a b c a b"]);
+        let schema = index.schema();
+        let text_field = schema.get_field("text").unwrap();
+        index.load_searchers().unwrap();
+        let searcher = index.searcher();
+        let test_query = |texts: Vec<&str>| {
+            let mut test_collector = TestCollector::default();
+            let terms: Vec<Term> = texts
+                .iter()
+                .map(|text| Term::from_field_text(text_field, text))
+                .collect();
+            let phrase_query = PhraseQuery::new(terms);
+            searcher
+                .search(&phrase_query, &mut test_collector)
+                .expect("search should succeed");
+            test_collector.scores()
+        };
+        let scores = test_query(vec!["a", "b"]);
+        assert_nearly_equals(scores[0], 0.40618482);
+        assert_nearly_equals(scores[1], 0.46844664);
+    }
 
     #[test] // motivated by #234
@@ -108,7 +125,7 @@ mod tests {
             .iter()
             .map(|text| Term::from_field_text(text_field, text))
             .collect();
-        let phrase_query = PhraseQuery::from(terms);
+        let phrase_query = PhraseQuery::new(terms);
         searcher
             .search(&phrase_query, &mut test_collector)
             .expect("search should succeed");
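The two asserted scores can be reproduced by hand from the definitions in `bm25.rs`: both `"a"` and `"b"` occur in both documents, so the query weight is `2 × idf(2, 2) × (1 + K1)`, and the average fieldnorm is `(3 + 5) / 2 = 4`. A standalone sketch of that arithmetic (not tantivy code; the fieldnorm quantization is assumed to be exact for values this small):

```rust
// Reproduces the values asserted in `test_phrase_score`.
fn main() {
    let (k1, b) = (1.2f32, 0.75f32);
    let idf = |doc_freq: f32, doc_count: f32| {
        (1f32 + (doc_count - doc_freq + 0.5) / (doc_freq + 0.5)).ln()
    };
    // "a" and "b" each match both of the two documents.
    let weight = 2f32 * idf(2.0, 2.0) * (1f32 + k1);
    let average_fieldnorm = 4.0f32; // (3 + 5) / 2 tokens
    let score = |phrase_freq: f32, fieldnorm: f32| {
        let norm = k1 * (1f32 - b + b * fieldnorm / average_fieldnorm);
        weight * phrase_freq / (phrase_freq + norm)
    };
    println!("{}", score(1.0, 3.0)); // ≈ 0.40618482: "a b c", one phrase hit
    println!("{}", score(2.0, 5.0)); // ≈ 0.46844664: "a b c a b", two hits
}
```

Note that the longer document still wins: its second occurrence of the phrase outweighs the length normalization penalty.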
assert_eq!(test_query(vec!["a", "b"]), vec![1, 2, 3, 4]); assert_eq!(test_query(vec!["b", "b"]), vec![0, 1]); - assert_eq!(test_query(vec!["g", "ewrwer"]), empty_vec); - assert_eq!(test_query(vec!["g", "a"]), empty_vec); + assert!(test_query(vec!["g", "ewrwer"]).is_empty()); + assert!(test_query(vec!["g", "a"]).is_empty()); + } + + + #[test] + pub fn test_phrase_score() { + let index = create_index(&["a b c", "a b c a b"]); + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + index.load_searchers().unwrap(); + let searcher = index.searcher(); + let test_query = |texts: Vec<&str>| { + let mut test_collector = TestCollector::default(); + let terms: Vec = texts + .iter() + .map(|text| Term::from_field_text(text_field, text)) + .collect(); + let phrase_query = PhraseQuery::new(terms); + searcher + .search(&phrase_query, &mut test_collector) + .expect("search should succeed"); + test_collector.scores() + }; + let scores = test_query(vec!["a", "b"]); + assert_nearly_equals(scores[0], 0.40618482); + assert_nearly_equals(scores[1], 0.46844664); + } #[test] // motivated by #234 @@ -108,7 +125,7 @@ mod tests { .iter() .map(|text| Term::from_field_text(text_field, text)) .collect(); - let phrase_query = PhraseQuery::from(terms); + let phrase_query = PhraseQuery::new(terms); searcher .search(&phrase_query, &mut test_collector) .expect("search should succeed"); diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index b17980ab0..74b2800ab 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -4,6 +4,7 @@ use core::searcher::Searcher; use super::PhraseWeight; use query::Weight; use Result; +use query::bm25::BM25Weight; /// `PhraseQuery` matches a specific sequence of words. /// @@ -24,21 +25,37 @@ pub struct PhraseQuery { phrase_terms: Vec, } +impl PhraseQuery { + + /// Creates a new `PhraseQuery` given a list of terms. + /// + /// There must be at least two terms, and all terms + /// must belong to the same field. + pub fn new(terms: Vec) -> PhraseQuery { + assert!(terms.len() > 1, "A phrase query is required to have strictly more than one term."); + assert!(terms[1..].iter().all(|term| term.field() == terms[0].field()), "All terms from a phrase query must belong to the same field"); + PhraseQuery { + phrase_terms: terms + } + } +} + impl Query for PhraseQuery { /// Create the weight associated to a query. /// /// See [`Weight`](./trait.Weight.html). 
diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs
index 6709c57c3..7e453297e 100644
--- a/src/query/phrase_query/phrase_scorer.rs
+++ b/src/query/phrase_query/phrase_scorer.rs
@@ -2,6 +2,8 @@ use DocId;
 use docset::{DocSet, SkipResult};
 use postings::Postings;
 use query::{Intersection, Scorer};
+use query::bm25::BM25Weight;
+use fieldnorm::FieldNormReader;
 
 struct PostingsWithOffset<TPostings> {
     offset: u32,
@@ -43,11 +45,32 @@
 pub struct PhraseScorer<TPostings: Postings> {
     intersection_docset: Intersection<PostingsWithOffset<TPostings>, PostingsWithOffset<TPostings>>,
     num_docsets: usize,
     left: Vec<u32>,
-    right: Vec<u32>
+    right: Vec<u32>,
+    phrase_count: u32,
+    fieldnorm_reader: FieldNormReader,
+    similarity_weight: BM25Weight,
+    score_needed: bool
 }
 
-/// Computes the length of the intersection of two sorted arrays.
+/// Returns true iff the two sorted arrays contain a common element.
+fn intersection_exists(left: &[u32], right: &[u32]) -> bool {
+    let mut left_i = 0;
+    let mut right_i = 0;
+    while left_i < left.len() && right_i < right.len() {
+        let left_val = left[left_i];
+        let right_val = right[right_i];
+        if left_val < right_val {
+            left_i += 1;
+        } else if right_val < left_val {
+            right_i += 1;
+        } else {
+            return true;
+        }
+    }
+    false
+}
+
 fn intersection_count(left: &[u32], right: &[u32]) -> usize {
     let mut left_i = 0;
     let mut right_i = 0;
@@ -98,7 +121,10 @@ fn intersection(left: &mut [u32], right: &[u32]) -> usize {
 
 impl<TPostings: Postings> PhraseScorer<TPostings> {
-    pub fn new(term_postings: Vec<TPostings>) -> PhraseScorer<TPostings> {
+    pub fn new(term_postings: Vec<TPostings>,
+               similarity_weight: BM25Weight,
+               fieldnorm_reader: FieldNormReader,
+               score_needed: bool) -> PhraseScorer<TPostings> {
         let num_docsets = term_postings.len();
         let postings_with_offsets = term_postings
             .into_iter()
@@ -109,12 +135,26 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
             intersection_docset: Intersection::new(postings_with_offsets),
             num_docsets,
             left: Vec::with_capacity(100),
-            right: Vec::with_capacity(100)
+            right: Vec::with_capacity(100),
+            phrase_count: 0u32,
+            similarity_weight,
+            fieldnorm_reader,
+            score_needed,
         }
     }
 
     fn phrase_match(&mut self) -> bool {
-        // TODO early exit when we don't care about the phrase frequency
+        if self.score_needed {
+            let count = self.phrase_count();
+            self.phrase_count = count;
+            count > 0u32
+        } else {
+            self.phrase_exists()
+        }
+    }
+
+
+    fn phrase_exists(&mut self) -> bool {
         {
             self.intersection_docset
                 .docset_mut_specialized(0)
                 .positions(&mut self.left);
         }
@@ -132,8 +172,28 @@
         }
 
         self.intersection_docset.docset_mut_specialized(self.num_docsets - 1).positions(&mut self.right);
-        intersection_len = intersection_count(&mut self.left[..intersection_len], &self.right[..]);
-        intersection_len > 0
+        intersection_exists(&self.left[..intersection_len], &self.right[..])
+    }
+
+    fn phrase_count(&mut self) -> u32 {
+        {
+            self.intersection_docset
+                .docset_mut_specialized(0)
+                .positions(&mut self.left);
+        }
+        let mut intersection_len = self.left.len();
+        for i in 1..self.num_docsets - 1 {
+            {
+                self.intersection_docset.docset_mut_specialized(i).positions(&mut self.right);
+            }
+            intersection_len = intersection(&mut self.left[..intersection_len], &self.right[..]);
+            if intersection_len == 0 {
+                return 0u32;
+            }
+        }
+
+        self.intersection_docset.docset_mut_specialized(self.num_docsets - 1).positions(&mut self.right);
+        intersection_count(&self.left[..intersection_len], &self.right[..]) as u32
+    }
 }
 
@@ -176,7 +236,9 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
 
 impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
     fn score(&mut self) -> f32 {
-        1f32
+        let doc = self.doc();
+        let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
+        self.similarity_weight.score(fieldnorm_id, self.phrase_count)
     }
 }
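How the matcher works: each term's postings are wrapped in a `PostingsWithOffset` so that, once positions are shifted by the term's offset inside the phrase, a phrase occurrence appears as the same value in every list; the phrase frequency is then the size of the running intersection. A standalone sketch (the offset-shifting convention is my reading of `PostingsWithOffset`, which this hunk does not show):

```rust
// Mirrors the two-pointer `intersection_count` helper above.
fn intersection_count(left: &[u32], right: &[u32]) -> usize {
    let (mut left_i, mut right_i, mut count) = (0, 0, 0);
    while left_i < left.len() && right_i < right.len() {
        if left[left_i] < right[right_i] {
            left_i += 1;
        } else if right[right_i] < left[left_i] {
            right_i += 1;
        } else {
            count += 1;
            left_i += 1;
            right_i += 1;
        }
    }
    count
}

fn main() {
    // Document "a b c a b", phrase "a b":
    // positions of "a" are [0, 3]; positions of "b" are [1, 4].
    // Shifting "b" left by its offset in the phrase (1) gives [0, 3];
    // two common values means a phrase frequency of 2.
    let a_positions = [0u32, 3];
    let b_shifted = [0u32, 3];
    assert_eq!(intersection_count(&a_positions, &b_shifted), 2);
}
```

The `score_needed` flag decides which tail runs: `intersection_exists` bails out at the first common position, while `intersection_count` walks both lists to the end to obtain the exact frequency.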
diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs
index 026a9c838..302230127 100644
--- a/src/query/phrase_query/phrase_weight.rs
+++ b/src/query/phrase_query/phrase_weight.rs
@@ -6,26 +6,32 @@ use core::SegmentReader;
 use super::PhraseScorer;
 use query::EmptyScorer;
 use Result;
+use query::bm25::BM25Weight;
 
 pub struct PhraseWeight {
     phrase_terms: Vec<Term>,
+    similarity_weight: BM25Weight,
+    score_needed: bool,
 }
 
 impl PhraseWeight {
     /// Creates a new phrase weight.
-    ///
-    /// Right now `scoring_enabled` is actually ignored.
-    /// In the future, disabling scoring will result in a small performance boost.
-    // TODO use the scoring disable information to avoid compute the
-    // phrase freq in that case, and compute the phrase freq when scoring is enabled.
-    // Right now we never compute it :|
-    pub fn new(phrase_terms: Vec<Term>, _scoring_enabled: bool) -> PhraseWeight {
-        PhraseWeight { phrase_terms }
+    pub fn new(phrase_terms: Vec<Term>,
+               similarity_weight: BM25Weight,
+               score_needed: bool) -> PhraseWeight {
+        PhraseWeight {
+            phrase_terms,
+            similarity_weight,
+            score_needed
+        }
     }
 }
 
 impl Weight for PhraseWeight {
     fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
+        let similarity_weight = self.similarity_weight.clone();
+        let field = self.phrase_terms[0].field();
+        let fieldnorm_reader = reader.get_fieldnorms_reader(field).expect("Failed to find fieldnorm for field");
         if reader.has_deletes() {
             let mut term_postings_list = Vec::new();
             for term in &self.phrase_terms {
@@ -37,7 +43,7 @@ impl Weight for PhraseWeight {
                     return Ok(box EmptyScorer);
                 }
             }
-            Ok(box PhraseScorer::new(term_postings_list))
+            Ok(box PhraseScorer::new(term_postings_list, similarity_weight, fieldnorm_reader, self.score_needed))
         } else {
             let mut term_postings_list = Vec::new();
             for term in &self.phrase_terms {
@@ -49,7 +55,7 @@ impl Weight for PhraseWeight {
                     return Ok(box EmptyScorer);
                 }
             }
-            Ok(box PhraseScorer::new(term_postings_list))
+            Ok(box PhraseScorer::new(term_postings_list, similarity_weight, fieldnorm_reader, self.score_needed))
         }
     }
 }
diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs
index d51225e5f..3a9d67d28 100644
--- a/src/query/query_parser/query_parser.rs
+++ b/src/query/query_parser/query_parser.rs
@@ -305,7 +305,7 @@ fn compose_occur(left: Occur, right: Occur) -> Occur {
 fn convert_literal_to_query(logical_literal: LogicalLiteral) -> Box<Query> {
     match logical_literal {
         LogicalLiteral::Term(term) => box TermQuery::new(term, IndexRecordOption::WithFreqs),
-        LogicalLiteral::Phrase(terms) => box PhraseQuery::from(terms),
+        LogicalLiteral::Phrase(terms) => box PhraseQuery::new(terms),
     }
 }
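The parser change is mechanical, but for completeness this is the path a quoted query takes. A sketch — the names match tantivy's public API of this era, but treat the exact signatures as assumptions:

```rust
extern crate tantivy;
use tantivy::query::QueryParser;
use tantivy::schema::Field;
use tantivy::Index;

// Quotes make the parser emit a LogicalLiteral::Phrase, which now flows
// through PhraseQuery::new instead of PhraseQuery::from.
fn parse_phrase(index: &Index, text_field: Field) {
    let parser = QueryParser::for_index(index, vec![text_field]);
    let _query = parser.parse_query("\"hello world\"").expect("valid query");
}
```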
diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs
index e1b031b58..861b275c6 100644
--- a/src/query/term_query/mod.rs
+++ b/src/query/term_query/mod.rs
@@ -13,6 +13,9 @@ use fastfield::DeleteBitSet;
 pub(crate) type TermScorerWithDeletes = TermScorer<SegmentPostings<DeleteBitSet>>;
 pub(crate) type TermScorerNoDeletes = TermScorer<SegmentPostings<NoDelete>>;
 
+
+
+
 #[cfg(test)]
 mod tests {
diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs
index fe522a03d..466489769 100644
--- a/src/query/term_query/term_query.rs
+++ b/src/query/term_query/term_query.rs
@@ -5,6 +5,7 @@ use query::Query;
 use query::Weight;
 use schema::IndexRecordOption;
 use Searcher;
+use query::bm25::BM25Weight;
 
 /// A Term query matches all of the documents
 /// containing a specific term.
@@ -36,26 +37,17 @@ impl TermQuery {
     /// this method return a specific implementation.
     /// This is useful for optimization purpose.
     pub fn specialized_weight(&self, searcher: &Searcher, scoring_enabled: bool) -> TermWeight {
-        let mut total_num_tokens = 0;
-        let mut total_num_docs = 0;
-        for segment_reader in searcher.segment_readers() {
-            let inverted_index = segment_reader.inverted_index(self.term.field());
-            total_num_tokens += inverted_index.total_num_tokens();
-            total_num_docs += segment_reader.max_doc();
-        }
-        let average_field_norm = total_num_tokens as f32 / total_num_docs as f32;
-
+        let term = self.term.clone();
+        let bm25_weight = BM25Weight::for_terms(searcher, &[term]);
         let index_record_option = if scoring_enabled {
             self.index_record_option
         } else {
            IndexRecordOption::Basic
         };
         TermWeight::new(
-            searcher.doc_freq(&self.term),
-            searcher.num_docs(),
-            average_field_norm,
             self.term.clone(),
-            index_record_option
+            index_record_option,
+            bm25_weight
         )
     }
 }
diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs
index de7481662..a85c1476e 100644
--- a/src/query/term_query/term_scorer.rs
+++ b/src/query/term_query/term_scorer.rs
@@ -5,12 +5,12 @@ use query::Scorer;
 use postings::Postings;
 use fieldnorm::FieldNormReader;
+use query::bm25::BM25Weight;
 
 pub struct TermScorer<TPostings: Postings> {
-    pub fieldnorm_reader_opt: Option<FieldNormReader>,
+    pub fieldnorm_reader: FieldNormReader,
     pub postings: TPostings,
-    pub weight: f32,
-    pub cache: [f32; 256],
+    pub similarity_weight: BM25Weight,
 }
 
 impl<TPostings: Postings> DocSet for TermScorer<TPostings> {
@@ -34,15 +34,8 @@ impl<TPostings: Postings> DocSet for TermScorer<TPostings> {
 impl<TPostings: Postings> Scorer for TermScorer<TPostings> {
     fn score(&mut self) -> Score {
         let doc = self.doc();
-        let fieldnorm_id = self.fieldnorm_reader_opt
-            .as_ref()
-            .map(|fieldnorm_reader| {
-                fieldnorm_reader.fieldnorm_id(doc)
-            })
-            .unwrap_or(0u8);
-        let norm = self.cache[fieldnorm_id as usize];
-        let term_freq = self.postings.term_freq() as f32;
-        self.weight * term_freq / (term_freq + norm)
+        let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
+        self.similarity_weight.score(fieldnorm_id, self.postings.term_freq())
     }
 }
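Why a 256-entry cache suffices: fieldnorms are quantized to a single byte, so the document-dependent part of the BM25 denominator can be precomputed per fieldnorm id, leaving one array lookup per scored document. A standalone sketch, where `id_to_fieldnorm` is a stand-in for `FieldNormReader::id_to_fieldnorm` and is assumed to be the identity for small values (the real mapping saturates for large ones):

```rust
const K1: f32 = 1.2;
const B: f32 = 0.75;

// Stand-in for FieldNormReader::id_to_fieldnorm (assumption: identity).
fn id_to_fieldnorm(id: u8) -> u32 {
    u32::from(id)
}

fn compute_tf_cache(average_fieldnorm: f32) -> [f32; 256] {
    let mut cache = [0f32; 256];
    for id in 0..256usize {
        let fieldnorm = id_to_fieldnorm(id as u8);
        cache[id] = K1 * (1f32 - B + B * fieldnorm as f32 / average_fieldnorm);
    }
    cache
}

fn main() {
    let cache = compute_tf_cache(4.0);
    // Per-document scoring is now: weight * tf / (tf + cache[fieldnorm_id]).
    let (weight, term_freq, fieldnorm_id) = (0.8022149f32, 2f32, 5usize);
    println!("{}", weight * term_freq / (term_freq + cache[fieldnorm_id]));
}
```

Dropping the `Option<FieldNormReader>` also removes a branch from the hot scoring loop; a segment is now expected to always provide fieldnorms for an indexed text field.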
diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs
index 5f46d4298..6d50cba6a 100644
--- a/src/query/term_query/term_weight.rs
+++ b/src/query/term_query/term_weight.rs
@@ -9,14 +9,12 @@ use super::term_scorer::TermScorer;
 use fastfield::DeleteBitSet;
 use postings::NoDelete;
 use Result;
-use fieldnorm::FieldNormReader;
-use std::f32;
+use query::bm25::BM25Weight;
 
 pub struct TermWeight {
     term: Term,
     index_record_option: IndexRecordOption,
-    weight: f32,
-    cache: [f32; 256],
+    similarity_weight: BM25Weight,
 }
 
 impl Weight for TermWeight {
@@ -24,7 +22,7 @@ impl Weight for TermWeight {
     fn scorer(&self, reader: &SegmentReader) -> Result<Box<Scorer>> {
         let field = self.term.field();
         let inverted_index = reader.inverted_index(field);
-        let fieldnorm_reader_opt = reader.get_fieldnorms_reader(field);
+        let fieldnorm_reader = reader.get_fieldnorms_reader(field).expect("Failed to find fieldnorm reader for field.");
         let scorer: Box<Scorer>;
         if reader.has_deletes() {
             let postings_opt: Option<SegmentPostings<DeleteBitSet>> =
@@ -32,17 +30,15 @@ impl Weight for TermWeight {
 
             scorer = if let Some(segment_postings) = postings_opt {
                 box TermScorer {
-                    fieldnorm_reader_opt,
+                    fieldnorm_reader,
                     postings: segment_postings,
-                    weight: self.weight,
-                    cache: self.cache
+                    similarity_weight: self.similarity_weight.clone()
                 }
             } else {
                 box TermScorer {
-                    fieldnorm_reader_opt: None,
+                    fieldnorm_reader,
                     postings: SegmentPostings::<DeleteBitSet>::empty(),
-                    weight: self.weight,
-                    cache: self.cache
+                    similarity_weight: self.similarity_weight.clone()
                 }
             };
         } else {
@@ -51,17 +47,15 @@ impl Weight for TermWeight {
 
             scorer = if let Some(segment_postings) = postings_opt {
                 box TermScorer {
-                    fieldnorm_reader_opt,
+                    fieldnorm_reader,
                     postings: segment_postings,
-                    weight: self.weight,
-                    cache: self.cache
+                    similarity_weight: self.similarity_weight.clone()
                 }
             } else {
                 box TermScorer {
-                    fieldnorm_reader_opt: None,
+                    fieldnorm_reader,
                     postings: SegmentPostings::<NoDelete>::empty(),
-                    weight: self.weight,
-                    cache: self.cache
+                    similarity_weight: self.similarity_weight.clone()
                 }
             };
         }
@@ -82,55 +76,17 @@ impl Weight for TermWeight {
     }
 }
 
-const K1: f32 = 1.2;
-const B: f32 = 0.75;
-
-fn cached_tf_component(fieldnorm: u32, average_fieldnorm: f32) -> f32 {
-    K1 * (1f32 - B + B * fieldnorm as f32 / average_fieldnorm)
-}
-
-fn compute_tf_cache(average_fieldnorm: f32) -> [f32; 256] {
-    let mut cache = [0f32; 256];
-    for fieldnorm_id in 0..256 {
-        let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
-        cache[fieldnorm_id] = cached_tf_component(fieldnorm, average_fieldnorm);
-    }
-    cache
-}
-
-
-fn idf(doc_freq: u64, doc_count: u64) -> f32 {
-    let x = ((doc_count - doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5);
-    (1f32 + x).ln()
-}
-
 impl TermWeight {
-    pub fn new(doc_freq: u64,
-               doc_count: u64,
-               average_fieldnorm: f32,
-               term: Term,
-               index_record_option: IndexRecordOption) -> TermWeight {
-        let idf = idf(doc_freq, doc_count);
+    pub fn new(term: Term,
+               index_record_option: IndexRecordOption,
+               similarity_weight: BM25Weight) -> TermWeight {
         TermWeight {
             term,
             index_record_option,
-            weight: idf * (1f32 + K1),
-            cache: compute_tf_cache(average_fieldnorm),
+            similarity_weight,
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-
-    use tests::assert_nearly_equals;
-    use super::idf;
-
-    #[test]
-    fn test_idf() {
-        assert_nearly_equals(idf(1, 2), 0.6931472);
-    }
-
-}
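Finally, an end-to-end sketch of what the change buys: phrase queries ranked through a collector now produce real BM25 scores instead of a constant `1.0`. The `TopCollector` API shape below is illustrative of tantivy at the time of this diff; treat it as an assumption.

```rust
#[macro_use]
extern crate tantivy;
use tantivy::collector::TopCollector;
use tantivy::query::PhraseQuery;
use tantivy::schema::{SchemaBuilder, Term, TEXT};
use tantivy::Index;

fn main() -> tantivy::Result<()> {
    let mut schema_builder = SchemaBuilder::default();
    let text = schema_builder.add_text_field("text", TEXT);
    let index = Index::create_in_ram(schema_builder.build());
    {
        let mut writer = index.writer_with_num_threads(1, 40_000_000)?;
        writer.add_document(doc!(text => "a b c"));
        writer.add_document(doc!(text => "a b c a b"));
        writer.commit()?;
    }
    index.load_searchers()?;
    let searcher = index.searcher();
    let query = PhraseQuery::new(vec![
        Term::from_field_text(text, "a"),
        Term::from_field_text(text, "b"),
    ]);
    let mut top = TopCollector::with_limit(2);
    searcher.search(&query, &mut top)?;
    // The document with two occurrences of "a b" now outranks the
    // shorter one, per the arithmetic worked out above.
    for (score, doc_address) in top.score_docs() {
        println!("{:?}: {}", doc_address, score);
    }
    Ok(())
}
```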