diff --git a/CHANGELOG.md b/CHANGELOG.md index c05f14814..90a9e4edf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Tantivy 0.12.0 - Fixed a performance issue when searching for the posting lists of a missing term (@audunhalland) - Added a configurable maximum number of docs (10M by default) for a segment to be considered for merge (@hntd187, landed by @halvorboe #713) - Important Bugfix #777, causing tantivy to retain memory mapping. (diagnosed by @poljar) +- Added support for field boosting. (#547, @fulmicoton) ## How to update? diff --git a/query-grammar/src/query_grammar.rs b/query-grammar/src/query_grammar.rs index fc9a18306..a25638e78 100644 --- a/query-grammar/src/query_grammar.rs +++ b/query-grammar/src/query_grammar.rs @@ -21,8 +21,8 @@ parser! { fn word[I]()(I) -> String where [I: Stream] { ( - satisfy(|c: char| !c.is_whitespace() && !['-', '`', ':', '{', '}', '"', '[', ']', '(',')'].contains(&c) ), - many(satisfy(|c: char| !c.is_whitespace() && ![':', '{', '}', '"', '[', ']', '(',')'].contains(&c))) + satisfy(|c: char| !c.is_whitespace() && !['-', '^', '`', ':', '{', '}', '"', '[', ']', '(',')'].contains(&c) ), + many(satisfy(|c: char| !c.is_whitespace() && ![':', '^', '{', '}', '"', '[', ']', '(',')'].contains(&c))) ) .map(|(s1, s2): (char, String)| format!("{}{}", s1, s2)) .and_then(|s: String| @@ -170,6 +170,48 @@ parser! { } } +parser! { + fn positive_float_number[I]()(I) -> f32 + where [I: Stream] { + ( + many1(digit()), + optional( + (char('.'), many1(digit())) + ) + ) + .map(|(int_part, decimal_part_opt): (String, Option<(char, String)>)| { + let mut float_str = int_part; + if let Some((chr, decimal_str)) = decimal_part_opt { + float_str.push(chr); + float_str.push_str(&decimal_str); + } + float_str.parse::().unwrap() + }) + } +} + +parser! { + fn boost[I]()(I) -> f32 + where [I: Stream] { + (char('^'), positive_float_number()) + .map(|(_, boost)| boost) + } +} + +parser! 
{ + fn boosted_leaf[I]()(I) -> UserInputAST + where [I: Stream] { + (leaf(), optional(boost())) + .map(|(leaf, boost_opt)| + match boost_opt { + Some(boost) if (boost - 1.0).abs() > std::f32::EPSILON => + UserInputAST::Boost(Box::new(leaf), boost), + _ => leaf + } + ) + } +} + #[derive(Clone, Copy)] enum BinaryOperand { Or, @@ -214,10 +256,10 @@ parser! { pub fn ast[I]()(I) -> UserInputAST where [I: Stream] { - let operand_leaf = (binary_operand().skip(spaces()), leaf().skip(spaces())); - let boolean_expr = (leaf().skip(spaces().silent()), many1(operand_leaf)).map( + let operand_leaf = (binary_operand().skip(spaces()), boosted_leaf().skip(spaces())); + let boolean_expr = (boosted_leaf().skip(spaces().silent()), many1(operand_leaf)).map( |(left, right)| aggregate_binary_expressions(left,right)); - let whitespace_separated_leaves = many1(leaf().skip(spaces().silent())) + let whitespace_separated_leaves = many1(boosted_leaf().skip(spaces().silent())) .map(|subqueries: Vec| if subqueries.len() == 1 { subqueries.into_iter().next().unwrap() @@ -243,6 +285,37 @@ mod test { use super::*; use combine::parser::Parser; + pub fn nearly_equals(a: f32, b: f32) -> bool { + (a - b).abs() < 0.0005 * (a + b).abs() + } + + fn assert_nearly_equals(expected: f32, val: f32) { + assert!( + nearly_equals(val, expected), + "Got {}, expected {}.", + val, + expected + ); + } + + #[test] + fn test_positive_float_number() { + fn valid_parse(float_str: &str, expected_val: f32, expected_remaining: &str) { + let (val, remaining) = positive_float_number().parse(float_str).unwrap(); + assert_eq!(remaining, expected_remaining); + assert_nearly_equals(val, expected_val); + } + fn error_parse(float_str: &str) { + assert!(positive_float_number().parse(float_str).is_err()); + } + valid_parse("1.0", 1.0f32, ""); + valid_parse("1", 1.0f32, ""); + valid_parse("0.234234 aaa", 0.234234f32, " aaa"); + error_parse(".3332"); + error_parse("1."); + error_parse("-1."); + } + fn 
test_parse_query_to_ast_helper(query: &str, expected: &str) { let query = parse_to_ast().parse(query).unwrap().0; let query_str = format!("{:?}", query); @@ -275,6 +348,15 @@ mod test { test_parse_query_to_ast_helper("NOT a", "-(\"a\")"); } + #[test] + fn test_boosting() { + assert!(parse_to_ast().parse("a^2^3").is_err()); + assert!(parse_to_ast().parse("a^2^").is_err()); + test_parse_query_to_ast_helper("a^3", "(\"a\")^3"); + test_parse_query_to_ast_helper("a^3 b^2", "((\"a\")^3 (\"b\")^2)"); + test_parse_query_to_ast_helper("a^1", "\"a\""); + } + #[test] fn test_parse_query_to_ast_binary_op() { test_parse_query_to_ast_helper("a AND b", "(+(\"a\") +(\"b\"))"); diff --git a/query-grammar/src/user_input_ast.rs b/query-grammar/src/user_input_ast.rs index 30452008b..5ff841869 100644 --- a/query-grammar/src/user_input_ast.rs +++ b/query-grammar/src/user_input_ast.rs @@ -88,6 +88,7 @@ pub enum UserInputAST { Clause(Vec), Unary(Occur, Box), Leaf(Box), + Boost(Box, f32), } impl UserInputAST { @@ -154,6 +155,7 @@ impl fmt::Debug for UserInputAST { write!(formatter, "{}({:?})", occur, subquery) } UserInputAST::Leaf(ref subquery) => write!(formatter, "{:?}", subquery), + UserInputAST::Boost(ref leaf, boost) => write!(formatter, "({:?})^{}", leaf, boost), } } } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 6324c0fb4..522d26778 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -84,7 +84,8 @@ impl CustomScorer for ScorerByField { .u64(self.field) .ok_or_else(|| { crate::TantivyError::SchemaError(format!( - "Field requested is not a i64/u64 fast field." 
+ "Field requested ({:?}) is not a i64/u64 fast field.", + self.field )) })?; Ok(ScorerByFastFieldReader { ff_reader }) @@ -614,7 +615,10 @@ mod tests { let top_collector = TopDocs::with_limit(4).order_by_u64_field(size); let err = top_collector.for_segment(0, segment); if let Err(crate::TantivyError::SchemaError(msg)) = err { - assert_eq!(msg, "Field requested is not a i64/u64 fast field."); + assert_eq!( + msg, + "Field requested (Field(1)) is not a i64/u64 fast field." + ); } else { assert!(false); } diff --git a/src/core/searcher.rs b/src/core/searcher.rs index cf473e2ef..8aa808a8e 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -23,7 +23,7 @@ fn collect_segment( segment_ord: u32, segment_reader: &SegmentReader, ) -> crate::Result { - let mut scorer = weight.scorer(segment_reader)?; + let mut scorer = weight.scorer(segment_reader, 1.0f32)?; let mut segment_collector = collector.for_segment(segment_ord as u32, segment_reader)?; if let Some(delete_bitset) = segment_reader.delete_bitset() { scorer.for_each(&mut |doc, score| { diff --git a/src/query/all_query.rs b/src/query/all_query.rs index 178e6a7b5..fb9380dd8 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -1,6 +1,7 @@ use crate::core::Searcher; use crate::core::SegmentReader; use crate::docset::DocSet; +use crate::query::boost_query::BoostScorer; use crate::query::explanation::does_not_match; use crate::query::{Explanation, Query, Scorer, Weight}; use crate::DocId; @@ -22,12 +23,13 @@ impl Query for AllQuery { pub struct AllWeight; impl Weight for AllWeight { - fn scorer(&self, reader: &SegmentReader) -> crate::Result> { - Ok(Box::new(AllScorer { + fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result> { + let all_scorer = AllScorer { state: State::NotStarted, doc: 0u32, max_doc: reader.max_doc(), - })) + }; + Ok(Box::new(BoostScorer::new(all_scorer, boost))) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { @@ -90,14 +92,12 @@ impl 
Scorer for AllScorer { #[cfg(test)] mod tests { - use super::AllQuery; use crate::query::Query; use crate::schema::{Schema, TEXT}; use crate::Index; - #[test] - fn test_all_query() { + fn create_test_index() -> Index { let mut schema_builder = Schema::builder(); let field = schema_builder.add_text_field("text", TEXT); let schema = schema_builder.build(); @@ -108,13 +108,18 @@ mod tests { index_writer.commit().unwrap(); index_writer.add_document(doc!(field=>"ccc")); index_writer.commit().unwrap(); + index + } + + #[test] + fn test_all_query() { + let index = create_test_index(); let reader = index.reader().unwrap(); - reader.reload().unwrap(); let searcher = reader.searcher(); let weight = AllQuery.weight(&searcher, false).unwrap(); { let reader = searcher.segment_reader(0); - let mut scorer = weight.scorer(reader).unwrap(); + let mut scorer = weight.scorer(reader, 1.0f32).unwrap(); assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); assert!(scorer.advance()); @@ -123,10 +128,31 @@ mod tests { } { let reader = searcher.segment_reader(1); - let mut scorer = weight.scorer(reader).unwrap(); + let mut scorer = weight.scorer(reader, 1.0f32).unwrap(); assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); assert!(!scorer.advance()); } } + + #[test] + fn test_all_query_with_boost() { + let index = create_test_index(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let weight = AllQuery.weight(&searcher, false).unwrap(); + let reader = searcher.segment_reader(0); + { + let mut scorer = weight.scorer(reader, 2.0f32).unwrap(); + assert!(scorer.advance()); + assert_eq!(scorer.doc(), 0u32); + assert_eq!(scorer.score(), 2.0f32); + } + { + let mut scorer = weight.scorer(reader, 1.5f32).unwrap(); + assert!(scorer.advance()); + assert_eq!(scorer.doc(), 0u32); + assert_eq!(scorer.score(), 1.5f32); + } + } } diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index be012680f..5f315190f 100644 --- 
a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -40,7 +40,7 @@ impl Weight for AutomatonWeight where A: Automaton + Send + Sync + 'static, { - fn scorer(&self, reader: &SegmentReader) -> Result> { + fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result> { let max_doc = reader.max_doc(); let mut doc_bitset = BitSet::with_max_value(max_doc); @@ -58,11 +58,12 @@ where } } let doc_bitset = BitSetDocSet::from(doc_bitset); - Ok(Box::new(ConstScorer::new(doc_bitset))) + let const_scorer = ConstScorer::new(doc_bitset, boost); + Ok(Box::new(const_scorer)) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { - let mut scorer = self.scorer(reader)?; + let mut scorer = self.scorer(reader, 1.0f32)?; if scorer.skip_next(doc) == SkipResult::Reached { Ok(Explanation::new("AutomatonScorer", 1.0f32)) } else { @@ -72,3 +73,95 @@ where } } } + +#[cfg(test)] +mod tests { + use super::AutomatonWeight; + use crate::query::Weight; + use crate::schema::{Schema, STRING}; + use crate::Index; + use tantivy_fst::Automaton; + + fn create_index() -> Index { + let mut schema = Schema::builder(); + let title = schema.add_text_field("title", STRING); + let index = Index::create_in_ram(schema.build()); + let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); + index_writer.add_document(doc!(title=>"abc")); + index_writer.add_document(doc!(title=>"bcd")); + index_writer.add_document(doc!(title=>"abcd")); + assert!(index_writer.commit().is_ok()); + index + } + + enum State { + Start, + NotMatching, + AfterA, + } + + struct PrefixedByA; + + impl Automaton for PrefixedByA { + type State = State; + + fn start(&self) -> Self::State { + State::Start + } + + fn is_match(&self, state: &Self::State) -> bool { + match *state { + State::AfterA => true, + _ => false, + } + } + + fn accept(&self, state: &Self::State, byte: u8) -> Self::State { + match *state { + State::Start => { + if byte == b'a' { + State::AfterA + } else { + 
State::NotMatching + } + } + State::AfterA => State::AfterA, + State::NotMatching => State::NotMatching, + } + } + } + + #[test] + fn test_automaton_weight() { + let index = create_index(); + let field = index.schema().get_field("title").unwrap(); + let automaton_weight = AutomatonWeight::new(field, PrefixedByA); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let mut scorer = automaton_weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); + assert!(scorer.advance()); + assert_eq!(scorer.doc(), 0u32); + assert_eq!(scorer.score(), 1.0f32); + assert!(scorer.advance()); + assert_eq!(scorer.doc(), 2u32); + assert_eq!(scorer.score(), 1.0f32); + assert!(!scorer.advance()); + } + + #[test] + fn test_automaton_weight_boost() { + let index = create_index(); + let field = index.schema().get_field("title").unwrap(); + let automaton_weight = AutomatonWeight::new(field, PrefixedByA); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let mut scorer = automaton_weight + .scorer(searcher.segment_reader(0u32), 1.32f32) + .unwrap(); + assert!(scorer.advance()); + assert_eq!(scorer.doc(), 0u32); + assert_eq!(scorer.score(), 1.32f32); + } +} diff --git a/src/query/bm25.rs b/src/query/bm25.rs index 9a90b95f7..48f84fece 100644 --- a/src/query/bm25.rs +++ b/src/query/bm25.rs @@ -25,7 +25,6 @@ fn compute_tf_cache(average_fieldnorm: f32) -> [f32; 256] { cache } -#[derive(Clone)] pub struct BM25Weight { idf_explain: Explanation, weight: f32, @@ -34,6 +33,15 @@ pub struct BM25Weight { } impl BM25Weight { + pub fn boost_by(&self, boost: f32) -> BM25Weight { + BM25Weight { + idf_explain: self.idf_explain.clone(), + weight: self.weight * boost, + cache: self.cache, + average_fieldnorm: self.average_fieldnorm, + } + } + pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> BM25Weight { assert!(!terms.is_empty(), "BM25 requires at least one term"); let field = terms[0].field(); diff --git 
a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 334bbee23..f759db3aa 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -55,10 +55,11 @@ impl BooleanWeight { fn per_occur_scorers( &self, reader: &SegmentReader, + boost: f32, ) -> crate::Result>>> { let mut per_occur_scorers: HashMap>> = HashMap::new(); for &(ref occur, ref subweight) in &self.weights { - let sub_scorer: Box = subweight.scorer(reader)?; + let sub_scorer: Box = subweight.scorer(reader, boost)?; per_occur_scorers .entry(*occur) .or_insert_with(Vec::new) @@ -70,8 +71,9 @@ impl BooleanWeight { fn complex_scorer( &self, reader: &SegmentReader, + boost: f32, ) -> crate::Result> { - let mut per_occur_scorers = self.per_occur_scorers(reader)?; + let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?; let should_scorer_opt: Option> = per_occur_scorers .remove(&Occur::Should) @@ -112,7 +114,7 @@ impl BooleanWeight { } impl Weight for BooleanWeight { - fn scorer(&self, reader: &SegmentReader) -> crate::Result> { + fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result> { if self.weights.is_empty() { Ok(Box::new(EmptyScorer)) } else if self.weights.len() == 1 { @@ -120,17 +122,17 @@ impl Weight for BooleanWeight { if occur == Occur::MustNot { Ok(Box::new(EmptyScorer)) } else { - weight.scorer(reader) + weight.scorer(reader, boost) } } else if self.scoring_enabled { - self.complex_scorer::(reader) + self.complex_scorer::(reader, boost) } else { - self.complex_scorer::(reader) + self.complex_scorer::(reader, boost) } } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { - let mut scorer = self.scorer(reader)?; + let mut scorer = self.scorer(reader, 1.0f32)?; if scorer.skip_next(doc) != SkipResult::Reached { return Err(does_not_match(doc)); } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 4b90729af..f61072b00 100644 --- 
a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -18,6 +18,7 @@ mod tests { use crate::query::Scorer; use crate::query::TermQuery; use crate::schema::*; + use crate::tests::assert_nearly_equals; use crate::Index; use crate::{DocAddress, DocId}; @@ -70,7 +71,9 @@ mod tests { let query = query_parser.parse_query("+a").unwrap(); let searcher = index.reader().unwrap().searcher(); let weight = query.weight(&searcher, true).unwrap(); - let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + let scorer = weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); assert!(scorer.is::()); } @@ -82,13 +85,17 @@ mod tests { { let query = query_parser.parse_query("+a +b +c").unwrap(); let weight = query.weight(&searcher, true).unwrap(); - let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + let scorer = weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); assert!(scorer.is::>()); } { let query = query_parser.parse_query("+a +(b c)").unwrap(); let weight = query.weight(&searcher, true).unwrap(); - let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + let scorer = weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); assert!(scorer.is::>>()); } } @@ -101,7 +108,9 @@ mod tests { { let query = query_parser.parse_query("+a b").unwrap(); let weight = query.weight(&searcher, true).unwrap(); - let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + let scorer = weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); assert!(scorer.is::, Box, @@ -111,7 +120,9 @@ mod tests { { let query = query_parser.parse_query("+a b").unwrap(); let weight = query.weight(&searcher, false).unwrap(); - let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + let scorer = weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); assert!(scorer.is::()); } } @@ -179,6 +190,50 @@ mod tests { } } + #[test] + pub fn test_boolean_query_with_weight() { 
+ let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); + index_writer.add_document(doc!(text_field => "a b c")); + index_writer.add_document(doc!(text_field => "a c")); + index_writer.add_document(doc!(text_field => "b c")); + assert!(index_writer.commit().is_ok()); + } + let term_a: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "a"), + IndexRecordOption::WithFreqs, + )); + let term_b: Box = Box::new(TermQuery::new( + Term::from_field_text(text_field, "b"), + IndexRecordOption::WithFreqs, + )); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let boolean_query = + BooleanQuery::from(vec![(Occur::Should, term_a), (Occur::Should, term_b)]); + let boolean_weight = boolean_query.weight(&searcher, true).unwrap(); + { + let mut boolean_scorer = boolean_weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); + assert!(boolean_scorer.advance()); + assert_eq!(boolean_scorer.doc(), 0u32); + assert_nearly_equals(boolean_scorer.score(), 0.84163445f32); + } + { + let mut boolean_scorer = boolean_weight + .scorer(searcher.segment_reader(0u32), 2.0f32) + .unwrap(); + assert!(boolean_scorer.advance()); + assert_eq!(boolean_scorer.doc(), 0u32); + assert_nearly_equals(boolean_scorer.score(), 1.6832689f32); + } + } + #[test] pub fn test_intersection_score() { let (index, text_field) = aux_test_helper(); @@ -249,7 +304,9 @@ mod tests { let query_parser = QueryParser::for_index(&index, vec![title, text]); let query = query_parser.parse_query("Оксана Лифенко").unwrap(); let weight = query.weight(&searcher, true).unwrap(); - let mut scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap(); + let mut scorer = weight + .scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap(); scorer.advance(); let 
explanation = query.explain(&searcher, DocAddress(0u32, 0u32)).unwrap(); diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs new file mode 100644 index 000000000..eda5431a0 --- /dev/null +++ b/src/query/boost_query.rs @@ -0,0 +1,164 @@ +use crate::common::BitSet; +use crate::fastfield::DeleteBitSet; +use crate::query::explanation::does_not_match; +use crate::query::{Explanation, Query, Scorer, Weight}; +use crate::{DocId, DocSet, Searcher, SegmentReader, SkipResult, Term}; +use std::collections::BTreeSet; +use std::fmt; + +/// `BoostQuery` is a wrapper over a query used to boost its score. +/// +/// The document set matched by the `BoostQuery` is strictly the same as the underlying query. +/// The score of each document, is the score of the underlying query multiplied by the `boost` +/// factor. +pub struct BoostQuery { + query: Box, + boost: f32, +} + +impl BoostQuery { + /// Builds a boost query. + pub fn new(query: Box, boost: f32) -> BoostQuery { + BoostQuery { query, boost } + } +} + +impl Clone for BoostQuery { + fn clone(&self) -> Self { + BoostQuery { + query: self.query.box_clone(), + boost: self.boost, + } + } +} + +impl fmt::Debug for BoostQuery { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Boost(query={:?}, boost={})", self.query, self.boost) + } +} + +impl Query for BoostQuery { + fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result> { + let weight_without_boost = self.query.weight(searcher, scoring_enabled)?; + let boosted_weight = if scoring_enabled { + Box::new(BoostWeight::new(weight_without_boost, self.boost)) + } else { + weight_without_boost + }; + Ok(boosted_weight) + } + + fn query_terms(&self, term_set: &mut BTreeSet) { + self.query.query_terms(term_set) + } +} + +pub(crate) struct BoostWeight { + weight: Box, + boost: f32, +} + +impl BoostWeight { + pub fn new(weight: Box, boost: f32) -> Self { + BoostWeight { weight, boost } + } +} + +impl Weight for BoostWeight { + fn 
scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result> { + self.weight.scorer(reader, boost * self.boost) + } + + fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result { + let mut scorer = self.scorer(reader, 1.0f32)?; + if scorer.skip_next(doc) != SkipResult::Reached { + return Err(does_not_match(doc)); + } + let mut explanation = + Explanation::new(format!("Boost x{} of ...", self.boost), scorer.score()); + let underlying_explanation = self.weight.explain(reader, doc)?; + explanation.add_detail(underlying_explanation); + Ok(explanation) + } + + fn count(&self, reader: &SegmentReader) -> crate::Result { + self.weight.count(reader) + } +} + +pub(crate) struct BoostScorer { + underlying: S, + boost: f32, +} + +impl BoostScorer { + pub fn new(underlying: S, boost: f32) -> BoostScorer { + BoostScorer { underlying, boost } + } +} + +impl DocSet for BoostScorer { + fn advance(&mut self) -> bool { + self.underlying.advance() + } + + fn skip_next(&mut self, target: DocId) -> SkipResult { + self.underlying.skip_next(target) + } + + fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize { + self.underlying.fill_buffer(buffer) + } + + fn doc(&self) -> u32 { + self.underlying.doc() + } + + fn size_hint(&self) -> u32 { + self.underlying.size_hint() + } + + fn append_to_bitset(&mut self, bitset: &mut BitSet) { + self.underlying.append_to_bitset(bitset) + } + + fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 { + self.underlying.count(delete_bitset) + } + + fn count_including_deleted(&mut self) -> u32 { + self.underlying.count_including_deleted() + } +} + +impl Scorer for BoostScorer { + fn score(&mut self) -> f32 { + self.underlying.score() * self.boost + } +} + +#[cfg(test)] +mod tests { + use super::BoostQuery; + use crate::query::{AllQuery, Query}; + use crate::schema::Schema; + use crate::{DocAddress, Document, Index}; + + #[test] + fn test_boost_query_explain() { + let schema = Schema::builder().build(); + let index = 
Index::create_in_ram(schema); + let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); + index_writer.add_document(Document::new()); + assert!(index_writer.commit().is_ok()); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let query = BoostQuery::new(Box::new(AllQuery), 0.2); + let explanation = query.explain(&searcher, DocAddress(0, 0u32)).unwrap(); + assert_eq!( + explanation.to_pretty_json(), + "{\n \"value\": 0.2,\n \"description\": \"Boost x0.2 of ...\",\n \"details\": [\n {\n \"value\": 1.0,\n \"description\": \"AllQuery\"\n }\n ]\n}" + ) + } +} diff --git a/src/query/empty_query.rs b/src/query/empty_query.rs index d5ee37f1f..76932070e 100644 --- a/src/query/empty_query.rs +++ b/src/query/empty_query.rs @@ -33,7 +33,7 @@ impl Query for EmptyQuery { /// It is useful for tests and handling edge cases. pub struct EmptyWeight; impl Weight for EmptyWeight { - fn scorer(&self, _reader: &SegmentReader) -> crate::Result> { + fn scorer(&self, _reader: &SegmentReader, _boost: f32) -> crate::Result> { Ok(Box::new(EmptyScorer)) } diff --git a/src/query/mod.rs b/src/query/mod.rs index 82653cb81..2187e415a 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -7,6 +7,7 @@ mod automaton_weight; mod bitset; mod bm25; mod boolean_query; +mod boost_query; mod empty_query; mod exclude; mod explanation; @@ -37,6 +38,7 @@ pub use self::all_query::{AllQuery, AllScorer, AllWeight}; pub use self::automaton_weight::AutomatonWeight; pub use self::bitset::BitSetDocSet; pub use self::boolean_query::BooleanQuery; +pub use self::boost_query::BoostQuery; pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight}; pub use self::exclude::Exclude; pub use self::explanation::Explanation; diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index d43b65f31..fbe10f597 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -7,7 +7,7 @@ pub use self::phrase_scorer::PhraseScorer; pub use 
self::phrase_weight::PhraseWeight; #[cfg(test)] -mod tests { +pub mod tests { use super::*; use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, TEST_COLLECTOR_WITH_SCORE}; @@ -15,10 +15,10 @@ mod tests { use crate::error::TantivyError; use crate::schema::{Schema, Term, TEXT}; use crate::tests::assert_nearly_equals; + use crate::DocAddress; use crate::DocId; - use crate::{DocAddress, DocSet}; - fn create_index(texts: &[&'static str]) -> Index { + pub fn create_index(texts: &[&'static str]) -> Index { let mut schema_builder = Schema::builder(); let text_field = schema_builder.add_text_field("text", TEXT); let schema = schema_builder.build(); @@ -102,30 +102,6 @@ mod tests { assert!(test_query(vec!["g", "a"]).is_empty()); } - #[test] - pub fn test_phrase_count() { - let index = create_index(&["a c", "a a b d a b c", " a b"]); - let schema = index.schema(); - let text_field = schema.get_field("text").unwrap(); - let searcher = index.reader().unwrap().searcher(); - let phrase_query = PhraseQuery::new(vec![ - Term::from_field_text(text_field, "a"), - Term::from_field_text(text_field, "b"), - ]); - let phrase_weight = phrase_query.phrase_weight(&searcher, true).unwrap(); - let mut phrase_scorer = phrase_weight - .phrase_scorer(searcher.segment_reader(0u32)) - .unwrap() - .unwrap(); - assert!(phrase_scorer.advance()); - assert_eq!(phrase_scorer.doc(), 1); - assert_eq!(phrase_scorer.phrase_count(), 2); - assert!(phrase_scorer.advance()); - assert_eq!(phrase_scorer.doc(), 2); - assert_eq!(phrase_scorer.phrase_count(), 1); - assert!(!phrase_scorer.advance()); - } - #[test] pub fn test_phrase_query_no_positions() { let mut schema_builder = Schema::builder(); diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index 1fea04ec8..f82ca2288 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -37,11 +37,12 @@ impl PhraseWeight { reader.get_fieldnorms_reader(field) } - pub fn 
phrase_scorer( + fn phrase_scorer( &self, reader: &SegmentReader, + boost: f32, ) -> Result>> { - let similarity_weight = self.similarity_weight.clone(); + let similarity_weight = self.similarity_weight.boost_by(boost); let fieldnorm_reader = self.fieldnorm_reader(reader); if reader.has_deletes() { let mut term_postings_list = Vec::new(); @@ -84,8 +85,8 @@ impl PhraseWeight { } impl Weight for PhraseWeight { - fn scorer(&self, reader: &SegmentReader) -> Result> { - if let Some(scorer) = self.phrase_scorer(reader)? { + fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result> { + if let Some(scorer) = self.phrase_scorer(reader, boost)? { Ok(Box::new(scorer)) } else { Ok(Box::new(EmptyScorer)) @@ -93,7 +94,7 @@ impl Weight for PhraseWeight { } fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { - let scorer_opt = self.phrase_scorer(reader)?; + let scorer_opt = self.phrase_scorer(reader, 1.0f32)?; if scorer_opt.is_none() { return Err(does_not_match(doc)); } @@ -109,3 +110,34 @@ impl Weight for PhraseWeight { Ok(explanation) } } + +#[cfg(test)] +mod tests { + use super::super::tests::create_index; + use crate::query::PhraseQuery; + use crate::{DocSet, Term}; + + #[test] + pub fn test_phrase_count() { + let index = create_index(&["a c", "a a b d a b c", " a b"]); + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader().unwrap().searcher(); + let phrase_query = PhraseQuery::new(vec![ + Term::from_field_text(text_field, "a"), + Term::from_field_text(text_field, "b"), + ]); + let phrase_weight = phrase_query.phrase_weight(&searcher, true).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0f32) + .unwrap() + .unwrap(); + assert!(phrase_scorer.advance()); + assert_eq!(phrase_scorer.doc(), 1); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert!(phrase_scorer.advance()); + assert_eq!(phrase_scorer.doc(), 2); + 
assert_eq!(phrase_scorer.phrase_count(), 1); + assert!(!phrase_scorer.advance()); + } +} diff --git a/src/query/query_parser/logical_ast.rs b/src/query/query_parser/logical_ast.rs index 299aa1241..6e2e54c21 100644 --- a/src/query/query_parser/logical_ast.rs +++ b/src/query/query_parser/logical_ast.rs @@ -21,6 +21,17 @@ pub enum LogicalLiteral { pub enum LogicalAST { Clause(Vec<(Occur, LogicalAST)>), Leaf(Box), + Boost(Box, f32), +} + +impl LogicalAST { + pub fn boost(self, boost: f32) -> LogicalAST { + if (boost - 1.0f32).abs() < std::f32::EPSILON { + self + } else { + LogicalAST::Boost(Box::new(self), boost) + } + } } fn occur_letter(occur: Occur) -> &'static str { @@ -47,6 +58,7 @@ impl fmt::Debug for LogicalAST { } Ok(()) } + LogicalAST::Boost(ref ast, boost) => write!(formatter, "{:?}^{}", ast, boost), LogicalAST::Leaf(ref literal) => write!(formatter, "{:?}", literal), } } diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs index bd9dc0869..6055cc2ce 100644 --- a/src/query/query_parser/query_parser.rs +++ b/src/query/query_parser/query_parser.rs @@ -1,6 +1,5 @@ use super::logical_ast::*; use crate::core::Index; -use crate::query::AllQuery; use crate::query::BooleanQuery; use crate::query::EmptyQuery; use crate::query::Occur; @@ -8,11 +7,13 @@ use crate::query::PhraseQuery; use crate::query::Query; use crate::query::RangeQuery; use crate::query::TermQuery; +use crate::query::{AllQuery, BoostQuery}; use crate::schema::{Facet, IndexRecordOption}; use crate::schema::{Field, Schema}; use crate::schema::{FieldType, Term}; use crate::tokenizer::TokenizerManager; use std::borrow::Cow; +use std::collections::HashMap; use std::num::{ParseFloatError, ParseIntError}; use std::ops::Bound; use std::str::FromStr; @@ -144,7 +145,6 @@ fn trim_ast(logical_ast: LogicalAST) -> Option { /// /// * must terms: By prepending a term by a `+`, a term can be made required for the search. 
/// -/// * phrase terms: Quoted terms become phrase searches on fields that have positions indexed. /// e.g., `title:"Barack Obama"` will only find documents that have "barack" immediately followed /// by "obama". @@ -158,12 +158,20 @@ fn trim_ast(logical_ast: LogicalAST) -> Option { /// /// * all docs query: A plain `*` will match all documents in the index. /// +/// Parts of the queries can be boosted by appending `^boostfactor`. +/// For instance, `"SRE"^2.0 OR devops^0.4` will boost documents containing `SRE` more than +/// documents containing `devops`. Negative boosts are not allowed. +/// +/// It is also possible to define a boost for a specific field, at the query parser level. +/// (See [`set_field_boost(...)`](#method.set_field_boost)). Typically you may want to boost a title +/// field. #[derive(Clone)] pub struct QueryParser { schema: Schema, default_fields: Vec, conjunction_by_default: bool, tokenizer_manager: TokenizerManager, + boost: HashMap, } impl QueryParser { @@ -181,6 +189,7 @@ impl QueryParser { default_fields, tokenizer_manager, conjunction_by_default: false, + boost: Default::default(), } } @@ -201,6 +210,17 @@ impl QueryParser { self.conjunction_by_default = true; } + /// Sets a boost for a specific field. + /// + /// The parsed query will automatically boost this field. + /// + /// If the query defines a query boost through the query language (e.g: `country:France^3.0`), + /// the two boosts (the one defined in the query, and the one defined in the `QueryParser`) + /// are multiplied together.
+ pub fn set_field_boost(&mut self, field: Field, boost: f32) { + self.boost.insert(field, boost); + } + /// Parse a query /// /// Note that `parse_query` returns an error if the input @@ -407,6 +427,10 @@ impl QueryParser { self.compute_logical_ast_with_occur(*subquery)?; Ok((Occur::compose(left_occur, right_occur), logical_sub_queries)) } + UserInputAST::Boost(ast, boost) => { + let (occur, ast_without_occur) = self.compute_logical_ast_with_occur(*ast)?; + Ok((occur, ast_without_occur.boost(boost))) + } UserInputAST::Leaf(leaf) => { let result_ast = self.compute_logical_ast_from_leaf(*leaf)?; Ok((Occur::Should, result_ast)) @@ -414,6 +438,10 @@ impl QueryParser { } } + fn field_boost(&self, field: Field) -> f32 { + self.boost.get(&field).cloned().unwrap_or(1.0f32) + } + fn compute_logical_ast_from_leaf( &self, leaf: UserInputLeaf, @@ -439,7 +467,9 @@ impl QueryParser { let mut asts: Vec = Vec::new(); for (field, phrase) in term_phrases { if let Some(ast) = self.compute_logical_ast_for_leaf(field, &phrase)? { - asts.push(LogicalAST::Leaf(Box::new(ast))); + // Apply some field specific boost defined at the query parser level. 
+ let boost = self.field_boost(field); + asts.push(LogicalAST::Leaf(Box::new(ast)).boost(boost)); } } let result_ast: LogicalAST = if asts.len() == 1 { @@ -459,14 +489,16 @@ impl QueryParser { let mut clauses = fields .iter() .map(|&field| { + let boost = self.field_boost(field); let field_entry = self.schema.get_field_entry(field); let value_type = field_entry.field_type().value_type(); - Ok(LogicalAST::Leaf(Box::new(LogicalLiteral::Range { + let logical_ast = LogicalAST::Leaf(Box::new(LogicalLiteral::Range { field, value_type, lower: self.resolve_bound(field, &lower)?, upper: self.resolve_bound(field, &upper)?, - }))) + })); + Ok(logical_ast.boost(boost)) }) .collect::, QueryParserError>>()?; let result_ast = if clauses.len() == 1 { @@ -519,6 +551,11 @@ fn convert_to_query(logical_ast: LogicalAST) -> Box { Some(LogicalAST::Leaf(trimmed_logical_literal)) => { convert_literal_to_query(*trimmed_logical_literal) } + Some(LogicalAST::Boost(ast, boost)) => { + let query = convert_to_query(*ast); + let boosted_query = BoostQuery::new(query, boost); + Box::new(boosted_query) + } None => Box::new(EmptyQuery), } } @@ -538,7 +575,7 @@ mod test { use crate::Index; use matches::assert_matches; - fn make_query_parser() -> QueryParser { + fn make_schema() -> Schema { let mut schema_builder = Schema::builder(); let text_field_indexing = TextFieldIndexing::default() .set_tokenizer("en_with_stop_words") @@ -546,8 +583,8 @@ mod test { let text_options = TextOptions::default() .set_indexing_options(text_field_indexing) .set_stored(); - let title = schema_builder.add_text_field("title", TEXT); - let text = schema_builder.add_text_field("text", TEXT); + schema_builder.add_text_field("title", TEXT); + schema_builder.add_text_field("text", TEXT); schema_builder.add_i64_field("signed", INDEXED); schema_builder.add_u64_field("unsigned", INDEXED); schema_builder.add_text_field("notindexed_text", STORED); @@ -558,8 +595,15 @@ mod test { schema_builder.add_date_field("date", INDEXED); 
schema_builder.add_f64_field("float", INDEXED); schema_builder.add_facet_field("facet"); - let schema = schema_builder.build(); - let default_fields = vec![title, text]; + schema_builder.build() + } + + fn make_query_parser() -> QueryParser { + let schema = make_schema(); + let default_fields: Vec = vec!["title", "text"] + .into_iter() + .flat_map(|field_name| schema.get_field(field_name)) + .collect(); let tokenizer_manager = TokenizerManager::default(); tokenizer_manager.register( "en_with_stop_words", @@ -601,6 +645,45 @@ mod test { ); } + #[test] + pub fn test_parse_query_with_boost() { + let mut query_parser = make_query_parser(); + let schema = make_schema(); + let text_field = schema.get_field("text").unwrap(); + query_parser.set_field_boost(text_field, 2.0f32); + let query = query_parser.parse_query("text:hello").unwrap(); + assert_eq!( + format!("{:?}", query), + "Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2)" + ); + } + + #[test] + pub fn test_parse_query_range_with_boost() { + let mut query_parser = make_query_parser(); + let schema = make_schema(); + let title_field = schema.get_field("title").unwrap(); + query_parser.set_field_boost(title_field, 2.0f32); + let query = query_parser.parse_query("title:[A TO B]").unwrap(); + assert_eq!( + format!("{:?}", query), + "Boost(query=RangeQuery { field: Field(0), value_type: Str, left_bound: Included([97]), right_bound: Included([98]) }, boost=2)" + ); + } + + #[test] + pub fn test_parse_query_with_default_boost_and_custom_boost() { + let mut query_parser = make_query_parser(); + let schema = make_schema(); + let text_field = schema.get_field("text").unwrap(); + query_parser.set_field_boost(text_field, 2.0f32); + let query = query_parser.parse_query("text:hello^2").unwrap(); + assert_eq!( + format!("{:?}", query), + "Boost(query=Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2), boost=2)" + ); + } + #[test] pub fn 
test_parse_nonindexed_field_yields_error() { let query_parser = make_query_parser(); diff --git a/src/query/range_query.rs b/src/query/range_query.rs index c0d5afff9..0440b7e0c 100644 --- a/src/query/range_query.rs +++ b/src/query/range_query.rs @@ -289,7 +289,7 @@ impl RangeWeight { } impl Weight for RangeWeight { - fn scorer(&self, reader: &SegmentReader) -> Result> { + fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result> { let max_doc = reader.max_doc(); let mut doc_bitset = BitSet::with_max_value(max_doc); @@ -307,11 +307,11 @@ impl Weight for RangeWeight { } } let doc_bitset = BitSetDocSet::from(doc_bitset); - Ok(Box::new(ConstScorer::new(doc_bitset))) + Ok(Box::new(ConstScorer::new(doc_bitset, boost))) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { - let mut scorer = self.scorer(reader)?; + let mut scorer = self.scorer(reader, 1.0f32)?; if scorer.skip_next(doc) != SkipResult::Reached { return Err(does_not_match(doc)); } diff --git a/src/query/reqopt_scorer.rs b/src/query/reqopt_scorer.rs index b85a3ee88..16c15a198 100644 --- a/src/query/reqopt_scorer.rs +++ b/src/query/reqopt_scorer.rs @@ -115,8 +115,8 @@ mod tests { let req = vec![1, 3, 7]; let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> = RequiredOptionalScorer::new( - ConstScorer::new(VecDocSet::from(req.clone())), - ConstScorer::new(VecDocSet::from(vec![])), + ConstScorer::from(VecDocSet::from(req.clone())), + ConstScorer::from(VecDocSet::from(vec![])), ); let mut docs = vec![]; while reqoptscorer.advance() { @@ -129,8 +129,8 @@ mod tests { fn test_reqopt_scorer() { let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> = RequiredOptionalScorer::new( - ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15])), - ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15])), + ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15]), 1.0f32), + ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15]), 1.0f32), ); { 
assert!(reqoptscorer.advance()); @@ -183,8 +183,8 @@ mod tests { test_skip_against_unoptimized( || { Box::new(RequiredOptionalScorer::<_, _, DoNothingCombiner>::new( - ConstScorer::new(VecDocSet::from(req_docs.clone())), - ConstScorer::new(VecDocSet::from(opt_docs.clone())), + ConstScorer::from(VecDocSet::from(req_docs.clone())), + ConstScorer::from(VecDocSet::from(opt_docs.clone())), )) }, skip_docs, diff --git a/src/query/scorer.rs b/src/query/scorer.rs index a67d4c634..02a4fb021 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -49,16 +49,14 @@ pub struct ConstScorer { impl ConstScorer { /// Creates a new `ConstScorer`. - pub fn new(docset: TDocSet) -> ConstScorer { - ConstScorer { - docset, - score: 1f32, - } + pub fn new(docset: TDocSet, score: f32) -> ConstScorer { + ConstScorer { docset, score } } +} - /// Sets the constant score to a different value. - pub fn set_score(&mut self, score: Score) { - self.score = score; +impl From for ConstScorer { + fn from(docset: TDocSet) -> Self { + ConstScorer::new(docset, 1.0f32) } } @@ -90,6 +88,6 @@ impl DocSet for ConstScorer { impl Scorer for ConstScorer { fn score(&mut self) -> Score { - 1f32 + self.score } } diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index d5a29f9fd..d38756b7e 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -39,7 +39,7 @@ mod tests { ); let term_weight = term_query.weight(&searcher, true).unwrap(); let segment_reader = searcher.segment_reader(0); - let mut term_scorer = term_weight.scorer(segment_reader).unwrap(); + let mut term_scorer = term_weight.scorer(segment_reader, 1.0f32).unwrap(); assert!(term_scorer.advance()); assert_eq!(term_scorer.doc(), 0); assert_eq!(term_scorer.score(), 0.28768212); diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 7da8dbf78..e7d47847e 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -18,13 +18,13 @@ pub struct 
TermWeight { } impl Weight for TermWeight { - fn scorer(&self, reader: &SegmentReader) -> Result> { - let term_scorer = self.scorer_specialized(reader)?; + fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result> { + let term_scorer = self.scorer_specialized(reader, boost)?; Ok(Box::new(term_scorer)) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { - let mut scorer = self.scorer_specialized(reader)?; + let mut scorer = self.scorer_specialized(reader, 1.0f32)?; if scorer.skip_next(doc) != SkipResult::Reached { return Err(does_not_match(doc)); } @@ -33,7 +33,7 @@ impl Weight for TermWeight { fn count(&self, reader: &SegmentReader) -> Result { if let Some(delete_bitset) = reader.delete_bitset() { - Ok(self.scorer(reader)?.count(delete_bitset)) + Ok(self.scorer(reader, 1.0f32)?.count(delete_bitset)) } else { let field = self.term.field(); Ok(reader @@ -58,11 +58,11 @@ impl TermWeight { } } - fn scorer_specialized(&self, reader: &SegmentReader) -> Result { + fn scorer_specialized(&self, reader: &SegmentReader, boost: f32) -> Result { let field = self.term.field(); let inverted_index = reader.inverted_index(field); let fieldnorm_reader = reader.get_fieldnorms_reader(field); - let similarity_weight = self.similarity_weight.clone(); + let similarity_weight = self.similarity_weight.boost_by(boost); let postings_opt: Option = inverted_index.read_postings(&self.term, self.index_record_option); if let Some(segment_postings) = postings_opt { diff --git a/src/query/union.rs b/src/query/union.rs index 0a7d7bf91..7e27ac877 100644 --- a/src/query/union.rs +++ b/src/query/union.rs @@ -145,26 +145,6 @@ where } } - fn count_including_deleted(&mut self) -> u32 { - let mut count = self.bitsets[self.cursor..HORIZON_NUM_TINYBITSETS] - .iter() - .map(|bitset| bitset.len()) - .sum::(); - for bitset in self.bitsets.iter_mut() { - bitset.clear(); - } - while self.refill() { - count += self.bitsets.iter().map(|bitset| bitset.len()).sum::(); - for bitset in 
self.bitsets.iter_mut() { - bitset.clear(); - } - } - self.cursor = HORIZON_NUM_TINYBITSETS; - count - } - - // TODO implement `count` efficiently. - fn skip_next(&mut self, target: DocId) -> SkipResult { if !self.advance() { return SkipResult::End; @@ -243,6 +223,8 @@ where } } + // TODO implement `count` efficiently. + fn doc(&self) -> DocId { self.doc } @@ -250,6 +232,24 @@ where fn size_hint(&self) -> u32 { 0u32 } + + fn count_including_deleted(&mut self) -> u32 { + let mut count = self.bitsets[self.cursor..HORIZON_NUM_TINYBITSETS] + .iter() + .map(|bitset| bitset.len()) + .sum::(); + for bitset in self.bitsets.iter_mut() { + bitset.clear(); + } + while self.refill() { + count += self.bitsets.iter().map(|bitset| bitset.len()).sum::(); + for bitset in self.bitsets.iter_mut() { + bitset.clear(); + } + } + self.cursor = HORIZON_NUM_TINYBITSETS; + count + } } impl Scorer for Union @@ -290,7 +290,7 @@ mod tests { vals.iter() .cloned() .map(VecDocSet::from) - .map(ConstScorer::new) + .map(|docset| ConstScorer::new(docset, 1.0f32)) .collect::>>(), ) }; @@ -339,7 +339,7 @@ mod tests { .iter() .map(|docs| docs.clone()) .map(VecDocSet::from) - .map(ConstScorer::new) + .map(|docset| ConstScorer::new(docset, 1.0f32)) .collect::>(), )); res @@ -369,8 +369,8 @@ mod tests { #[test] fn test_union_skip_corner_case3() { let mut docset = Union::<_, DoNothingCombiner>::from(vec![ - ConstScorer::new(VecDocSet::from(vec![0u32, 5u32])), - ConstScorer::new(VecDocSet::from(vec![1u32, 4u32])), + ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])), + ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])), ]); assert!(docset.advance()); assert_eq!(docset.doc(), 0u32); diff --git a/src/query/weight.rs b/src/query/weight.rs index c51fea554..821cebfd7 100644 --- a/src/query/weight.rs +++ b/src/query/weight.rs @@ -1,7 +1,7 @@ use super::Scorer; use crate::core::SegmentReader; use crate::query::Explanation; -use crate::{DocId, Result}; +use crate::DocId; /// A Weight is the specialization of a 
Query /// for a given set of segments. @@ -9,15 +9,18 @@ use crate::{DocId, Result}; /// See [`Query`](./trait.Query.html). pub trait Weight: Send + Sync + 'static { /// Returns the scorer for the given segment. + /// + /// `boost` is a multiplier to apply to the score. + /// /// See [`Query`](./trait.Query.html). - fn scorer(&self, reader: &SegmentReader) -> Result>; + fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result>; /// Returns an `Explanation` for the given document. - fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result; + fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result; /// Returns the number documents within the given `SegmentReader`. - fn count(&self, reader: &SegmentReader) -> Result { - let mut scorer = self.scorer(reader)?; + fn count(&self, reader: &SegmentReader) -> crate::Result { + let mut scorer = self.scorer(reader, 1.0f32)?; if let Some(delete_bitset) = reader.delete_bitset() { Ok(scorer.count(delete_bitset)) } else { diff --git a/src/reader/pool.rs b/src/reader/pool.rs index abd5e5a5d..414f00169 100644 --- a/src/reader/pool.rs +++ b/src/reader/pool.rs @@ -68,7 +68,7 @@ impl Pool { /// After publish, all new `Searcher` acquired will be /// of the new generation. pub fn publish_new_generation(&self, items: Vec) { - assert!(items.len() >= 1); + assert!(!items.is_empty()); let next_generation = self.next_generation.fetch_add(1, Ordering::SeqCst) + 1; let num_items = items.len(); for item in items { @@ -93,7 +93,7 @@ impl Pool { // // Half of them are obsoletes. By requesting `(n+1)` fresh searchers, we ensure that all // searcher will be inspected. - for _ in 0..(num_items + 1) { + for _ in 0..=num_items { let _ = self.acquire(); } } diff --git a/src/schema/facet.rs b/src/schema/facet.rs index c42ad834c..192a536cd 100644 --- a/src/schema/facet.rs +++ b/src/schema/facet.rs @@ -125,7 +125,7 @@ impl Facet { /// This function is the inverse of Facet::from(&str). 
pub fn to_path_string(&self) -> String { - format!("{}", self.to_string()) + format!("{}", self) } }