mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-06-05 10:00:41 +00:00
[WIP] Alternative take on boosted queries (#772)
* Alternative take on boosted queries
* Fixing unit test
* Added boosting to the query grammar.
* Made BoostQuery public.
* Added support for boosting field in QueryParser

Closes #547
This commit is contained in:
@@ -55,10 +55,11 @@ impl BooleanWeight {
|
||||
fn per_occur_scorers(
|
||||
&self,
|
||||
reader: &SegmentReader,
|
||||
boost: f32,
|
||||
) -> crate::Result<HashMap<Occur, Vec<Box<dyn Scorer>>>> {
|
||||
let mut per_occur_scorers: HashMap<Occur, Vec<Box<dyn Scorer>>> = HashMap::new();
|
||||
for &(ref occur, ref subweight) in &self.weights {
|
||||
let sub_scorer: Box<dyn Scorer> = subweight.scorer(reader)?;
|
||||
let sub_scorer: Box<dyn Scorer> = subweight.scorer(reader, boost)?;
|
||||
per_occur_scorers
|
||||
.entry(*occur)
|
||||
.or_insert_with(Vec::new)
|
||||
@@ -70,8 +71,9 @@ impl BooleanWeight {
|
||||
fn complex_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
&self,
|
||||
reader: &SegmentReader,
|
||||
boost: f32,
|
||||
) -> crate::Result<Box<dyn Scorer>> {
|
||||
let mut per_occur_scorers = self.per_occur_scorers(reader)?;
|
||||
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
|
||||
|
||||
let should_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::Should)
|
||||
@@ -112,7 +114,7 @@ impl BooleanWeight {
|
||||
}
|
||||
|
||||
impl Weight for BooleanWeight {
|
||||
fn scorer(&self, reader: &SegmentReader) -> crate::Result<Box<dyn Scorer>> {
|
||||
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>> {
|
||||
if self.weights.is_empty() {
|
||||
Ok(Box::new(EmptyScorer))
|
||||
} else if self.weights.len() == 1 {
|
||||
@@ -120,17 +122,17 @@ impl Weight for BooleanWeight {
|
||||
if occur == Occur::MustNot {
|
||||
Ok(Box::new(EmptyScorer))
|
||||
} else {
|
||||
weight.scorer(reader)
|
||||
weight.scorer(reader, boost)
|
||||
}
|
||||
} else if self.scoring_enabled {
|
||||
self.complex_scorer::<SumWithCoordsCombiner>(reader)
|
||||
self.complex_scorer::<SumWithCoordsCombiner>(reader, boost)
|
||||
} else {
|
||||
self.complex_scorer::<DoNothingCombiner>(reader)
|
||||
self.complex_scorer::<DoNothingCombiner>(reader, boost)
|
||||
}
|
||||
}
|
||||
|
||||
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
|
||||
let mut scorer = self.scorer(reader)?;
|
||||
let mut scorer = self.scorer(reader, 1.0f32)?;
|
||||
if scorer.skip_next(doc) != SkipResult::Reached {
|
||||
return Err(does_not_match(doc));
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ mod tests {
|
||||
use crate::query::Scorer;
|
||||
use crate::query::TermQuery;
|
||||
use crate::schema::*;
|
||||
use crate::tests::assert_nearly_equals;
|
||||
use crate::Index;
|
||||
use crate::{DocAddress, DocId};
|
||||
|
||||
@@ -70,7 +71,9 @@ mod tests {
|
||||
let query = query_parser.parse_query("+a").unwrap();
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
let weight = query.weight(&searcher, true).unwrap();
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
|
||||
let scorer = weight
|
||||
.scorer(searcher.segment_reader(0u32), 1.0f32)
|
||||
.unwrap();
|
||||
assert!(scorer.is::<TermScorer>());
|
||||
}
|
||||
|
||||
@@ -82,13 +85,17 @@ mod tests {
|
||||
{
|
||||
let query = query_parser.parse_query("+a +b +c").unwrap();
|
||||
let weight = query.weight(&searcher, true).unwrap();
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
|
||||
let scorer = weight
|
||||
.scorer(searcher.segment_reader(0u32), 1.0f32)
|
||||
.unwrap();
|
||||
assert!(scorer.is::<Intersection<TermScorer>>());
|
||||
}
|
||||
{
|
||||
let query = query_parser.parse_query("+a +(b c)").unwrap();
|
||||
let weight = query.weight(&searcher, true).unwrap();
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
|
||||
let scorer = weight
|
||||
.scorer(searcher.segment_reader(0u32), 1.0f32)
|
||||
.unwrap();
|
||||
assert!(scorer.is::<Intersection<Box<dyn Scorer>>>());
|
||||
}
|
||||
}
|
||||
@@ -101,7 +108,9 @@ mod tests {
|
||||
{
|
||||
let query = query_parser.parse_query("+a b").unwrap();
|
||||
let weight = query.weight(&searcher, true).unwrap();
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
|
||||
let scorer = weight
|
||||
.scorer(searcher.segment_reader(0u32), 1.0f32)
|
||||
.unwrap();
|
||||
assert!(scorer.is::<RequiredOptionalScorer<
|
||||
Box<dyn Scorer>,
|
||||
Box<dyn Scorer>,
|
||||
@@ -111,7 +120,9 @@ mod tests {
|
||||
{
|
||||
let query = query_parser.parse_query("+a b").unwrap();
|
||||
let weight = query.weight(&searcher, false).unwrap();
|
||||
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
|
||||
let scorer = weight
|
||||
.scorer(searcher.segment_reader(0u32), 1.0f32)
|
||||
.unwrap();
|
||||
assert!(scorer.is::<TermScorer>());
|
||||
}
|
||||
}
|
||||
@@ -179,6 +190,50 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that the boost value handed to `scorer` scales the score produced by
// a boolean query: the same weight is scored once with a boost of 1.0 and once
// with a boost of 2.0, and the second expected score is twice the first.
#[test]
|
||||
pub fn test_boolean_query_with_weight() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_field = schema_builder.add_text_field("text", TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
// Index three small documents; only the first one added ("a b c") contains
// both of the terms queried below ("a" and "b").
{
|
||||
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
|
||||
index_writer.add_document(doc!(text_field => "a b c"));
|
||||
index_writer.add_document(doc!(text_field => "a c"));
|
||||
index_writer.add_document(doc!(text_field => "b c"));
|
||||
assert!(index_writer.commit().is_ok());
|
||||
}
|
||||
let term_a: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
Term::from_field_text(text_field, "a"),
|
||||
IndexRecordOption::WithFreqs,
|
||||
));
|
||||
let term_b: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
Term::from_field_text(text_field, "b"),
|
||||
IndexRecordOption::WithFreqs,
|
||||
));
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
// Combine the two term queries as `Occur::Should` clauses.
let boolean_query =
|
||||
BooleanQuery::from(vec![(Occur::Should, term_a), (Occur::Should, term_b)]);
|
||||
// NOTE(review): the `true` argument is presumably the scoring-enabled flag — confirm.
let boolean_weight = boolean_query.weight(&searcher, true).unwrap();
|
||||
// Baseline: score the first matching document of segment 0 with a boost of 1.0.
{
|
||||
let mut boolean_scorer = boolean_weight
|
||||
.scorer(searcher.segment_reader(0u32), 1.0f32)
|
||||
.unwrap();
|
||||
assert!(boolean_scorer.advance());
|
||||
assert_eq!(boolean_scorer.doc(), 0u32);
|
||||
assert_nearly_equals(boolean_scorer.score(), 0.84163445f32);
|
||||
}
|
||||
// Same weight and document with a boost of 2.0: the expected score is
// twice the baseline value (2 * 0.84163445).
{
|
||||
let mut boolean_scorer = boolean_weight
|
||||
.scorer(searcher.segment_reader(0u32), 2.0f32)
|
||||
.unwrap();
|
||||
assert!(boolean_scorer.advance());
|
||||
assert_eq!(boolean_scorer.doc(), 0u32);
|
||||
assert_nearly_equals(boolean_scorer.score(), 1.6832689f32);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_intersection_score() {
|
||||
let (index, text_field) = aux_test_helper();
|
||||
@@ -249,7 +304,9 @@ mod tests {
|
||||
let query_parser = QueryParser::for_index(&index, vec![title, text]);
|
||||
let query = query_parser.parse_query("Оксана Лифенко").unwrap();
|
||||
let weight = query.weight(&searcher, true).unwrap();
|
||||
let mut scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
|
||||
let mut scorer = weight
|
||||
.scorer(searcher.segment_reader(0u32), 1.0f32)
|
||||
.unwrap();
|
||||
scorer.advance();
|
||||
|
||||
let explanation = query.explain(&searcher, DocAddress(0u32, 0u32)).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user