From aeb8ae3ef0a9ba849674c21c815a903b7c88dc82 Mon Sep 17 00:00:00 2001 From: Audun Halland Date: Tue, 21 Jan 2020 00:11:22 +0100 Subject: [PATCH 1/2] Add TermQuery, PhraseQuery boost_by --- src/query/bm25.rs | 8 +++---- src/query/phrase_query/phrase_query.rs | 9 +++++++- src/query/term_query/mod.rs | 29 ++++++++++++++++++++++++++ src/query/term_query/term_query.rs | 9 +++++++- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/query/bm25.rs b/src/query/bm25.rs index 9a90b95f7..1d153f136 100644 --- a/src/query/bm25.rs +++ b/src/query/bm25.rs @@ -34,7 +34,7 @@ pub struct BM25Weight { } impl BM25Weight { - pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> BM25Weight { + pub fn for_terms(searcher: &Searcher, terms: &[Term], boost: f32) -> BM25Weight { assert!(!terms.is_empty(), "BM25 requires at least one term"); let field = terms[0].field(); for term in &terms[1..] { @@ -75,11 +75,11 @@ impl BM25Weight { .sum::(); idf_explain = Explanation::new("idf", idf); } - BM25Weight::new(idf_explain, average_fieldnorm) + BM25Weight::new(idf_explain, average_fieldnorm, boost) } - fn new(idf_explain: Explanation, average_fieldnorm: f32) -> BM25Weight { - let weight = idf_explain.value() * (1f32 + K1); + fn new(idf_explain: Explanation, average_fieldnorm: f32, boost: f32) -> BM25Weight { + let weight = idf_explain.value() * (1f32 + K1) * boost; BM25Weight { idf_explain, weight, diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index be95b32ee..4d2986e7d 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -27,6 +27,7 @@ use std::collections::BTreeSet; pub struct PhraseQuery { field: Field, phrase_terms: Vec<(usize, Term)>, + boost: f32, } impl PhraseQuery { @@ -57,9 +58,15 @@ impl PhraseQuery { PhraseQuery { field, phrase_terms: terms, + boost: 1.0, } } + /// Boost the query score by the given factor. + pub fn boost_by(self, boost: f32) -> Self { + Self { boost, ..self } + } + /// The `Field` this `PhraseQuery` is targeting. pub fn field(&self) -> Field { self.field @@ -97,7 +104,7 @@ impl PhraseQuery { ))); } let terms = self.phrase_terms(); - let bm25_weight = BM25Weight::for_terms(searcher, &terms); + let bm25_weight = BM25Weight::for_terms(searcher, &terms, self.boost); Ok(PhraseWeight::new( self.phrase_terms.clone(), bm25_weight, diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index d5a29f9fd..3fc8bbff0 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -45,6 +45,35 @@ mod tests { assert_eq!(term_scorer.score(), 0.28768212); } + #[test] + pub fn test_term_query_boost_by() { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", STRING); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + // writing the segment + let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); + { + let doc = doc!(text_field => "a"); + index_writer.add_document(doc); + } + assert!(index_writer.commit().is_ok()); + } + let searcher = index.reader().unwrap().searcher(); + let term_query = TermQuery::new( + Term::from_field_text(text_field, "a"), + IndexRecordOption::Basic, + ) + .boost_by(42.0); + let term_weight = term_query.weight(&searcher, true).unwrap(); + let segment_reader = searcher.segment_reader(0); + let mut term_scorer = term_weight.scorer(segment_reader).unwrap(); + assert!(term_scorer.advance()); + assert_eq!(term_scorer.doc(), 0); + assert_nearly_equals(0.28768212 * 42.0, term_scorer.score()); + } + #[test] pub fn test_term_weight() { let mut schema_builder = Schema::builder(); diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index 402a1d738..c3d80026c 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -61,6 +61,7 @@ use std::fmt; pub struct TermQuery { term: Term, index_record_option: IndexRecordOption, + boost: f32, } impl fmt::Debug for TermQuery { @@ -75,9 +76,15 @@ impl TermQuery { TermQuery { term, index_record_option: segment_postings_options, + boost: 1.0, } } + /// Boost the query score by the given factor. + pub fn boost_by(self, boost: f32) -> Self { + Self { boost, ..self } + } + /// The `Term` this query is built out of. pub fn term(&self) -> &Term { &self.term @@ -90,7 +97,7 @@ impl TermQuery { /// This is useful for optimization purpose. pub fn specialized_weight(&self, searcher: &Searcher, scoring_enabled: bool) -> TermWeight { let term = self.term.clone(); - let bm25_weight = BM25Weight::for_terms(searcher, &[term]); + let bm25_weight = BM25Weight::for_terms(searcher, &[term], self.boost); let index_record_option = if scoring_enabled { self.index_record_option } else { From 9f04f42b649eb0746c5370f85f52272a20f40b33 Mon Sep 17 00:00:00 2001 From: Audun Halland Date: Tue, 21 Jan 2020 00:30:34 +0100 Subject: [PATCH 2/2] Add boost_by to FuzzyQuery, RegexQuery --- src/query/automaton_weight.rs | 9 ++++++++- src/query/fuzzy_query.rs | 10 +++++++++- src/query/regex_query.rs | 9 ++++++++- src/query/scorer.rs | 5 +++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index be012680f..d7ca81ae8 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -15,6 +15,7 @@ use tantivy_fst::Automaton; pub struct AutomatonWeight { field: Field, automaton: Arc, + boost: f32, } impl AutomatonWeight @@ -26,9 +27,15 @@ where AutomatonWeight { field, automaton: automaton.into(), + boost: 1.0, } } + /// Boost the scorer by the given factor. + pub fn boost_by(self, boost: f32) -> Self { + Self { boost, ..self } + } + fn automaton_stream<'a>(&'a self, term_dict: &'a TermDictionary) -> TermStreamer<'a, &'a A> { let automaton: &A = &*self.automaton; let term_stream_builder = term_dict.search(automaton); @@ -58,7 +65,7 @@ where } } let doc_bitset = BitSetDocSet::from(doc_bitset); - Ok(Box::new(ConstScorer::new(doc_bitset))) + Ok(Box::new(ConstScorer::with_score(doc_bitset, self.boost))) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index d50cc8a25..32c8a19db 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -79,6 +79,7 @@ pub struct FuzzyTermQuery { transposition_cost_one: bool, /// prefix: bool, + boost: f32, } impl FuzzyTermQuery { @@ -89,6 +90,7 @@ impl FuzzyTermQuery { distance, transposition_cost_one, prefix: false, + boost: 1.0, } } @@ -99,16 +101,22 @@ impl FuzzyTermQuery { distance, transposition_cost_one, prefix: true, + boost: 1.0, } } + /// Boost the query score by the given factor. + pub fn boost_by(self, boost: f32) -> Self { + Self { boost, ..self } + } + fn specialized_weight(&self) -> Result> { // LEV_BUILDER is a HashMap, whose `get` method returns an Option match LEV_BUILDER.get(&(self.distance, false)) { // Unwrap the option and build the Ok(AutomatonWeight) Some(automaton_builder) => { let automaton = automaton_builder.build_dfa(self.term.text()); - Ok(AutomatonWeight::new(self.term.field(), automaton)) + Ok(AutomatonWeight::new(self.term.field(), automaton).boost_by(self.boost)) } None => Err(InvalidArgument(format!( "Levenshtein distance of {} is not allowed. Choose a value in the {:?} range", diff --git a/src/query/regex_query.rs b/src/query/regex_query.rs index 2280ba67e..4891635c6 100644 --- a/src/query/regex_query.rs +++ b/src/query/regex_query.rs @@ -54,6 +54,7 @@ use tantivy_fst::Regex; pub struct RegexQuery { regex: Arc, field: Field, + boost: f32, } impl RegexQuery { @@ -69,11 +70,17 @@ impl RegexQuery { RegexQuery { regex: regex.into(), field, + boost: 1.0, } } + /// Boost the query score by the given factor. + pub fn boost_by(self, boost: f32) -> Self { + Self { boost, ..self } + } + fn specialized_weight(&self) -> AutomatonWeight { - AutomatonWeight::new(self.field, self.regex.clone()) + AutomatonWeight::new(self.field, self.regex.clone()).boost_by(self.boost) } } diff --git a/src/query/scorer.rs b/src/query/scorer.rs index a67d4c634..9d43f72f3 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -56,6 +56,11 @@ impl ConstScorer { } } + /// Creates a new `ConstScorer` with a custom score value + pub fn with_score(docset: TDocSet, score: f32) -> ConstScorer { + ConstScorer { docset, score } + } + /// Sets the constant score to a different value. pub fn set_score(&mut self, score: Score) { self.score = score;