mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-06 01:02:55 +00:00
[WIP] Alternative take on boosted queries (#772)
* Alternative take on boosted queries * Fixing unit test * Added boosting to the query grammar. * Made BoostQuery public. * Added support for boosting field in QueryParser Closes #547
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
use super::logical_ast::*;
|
||||
use crate::core::Index;
|
||||
use crate::query::AllQuery;
|
||||
use crate::query::BooleanQuery;
|
||||
use crate::query::EmptyQuery;
|
||||
use crate::query::Occur;
|
||||
@@ -8,11 +7,13 @@ use crate::query::PhraseQuery;
|
||||
use crate::query::Query;
|
||||
use crate::query::RangeQuery;
|
||||
use crate::query::TermQuery;
|
||||
use crate::query::{AllQuery, BoostQuery};
|
||||
use crate::schema::{Facet, IndexRecordOption};
|
||||
use crate::schema::{Field, Schema};
|
||||
use crate::schema::{FieldType, Term};
|
||||
use crate::tokenizer::TokenizerManager;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::num::{ParseFloatError, ParseIntError};
|
||||
use std::ops::Bound;
|
||||
use std::str::FromStr;
|
||||
@@ -144,7 +145,6 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
|
||||
///
|
||||
/// * must terms: By prepending a term by a `+`, a term can be made required for the search.
|
||||
///
|
||||
///
|
||||
/// * phrase terms: Quoted terms become phrase searches on fields that have positions indexed.
|
||||
/// e.g., `title:"Barack Obama"` will only find documents that have "barack" immediately followed
|
||||
/// by "obama".
|
||||
@@ -158,12 +158,20 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
|
||||
///
|
||||
/// * all docs query: A plain `*` will match all documents in the index.
|
||||
///
|
||||
/// Parts of the queries can be boosted by appending `^boostfactor`.
|
||||
/// For instance, `"SRE"^2.0 OR devops^0.4` will boost documents containing `SRE` instead of
|
||||
/// devops. Negative boosts are not allowed.
|
||||
///
|
||||
/// It is also possible to define a boost for a some specific field, at the query parser level.
|
||||
/// (See [`set_boost(...)`](#method.set_field_boost) ). Typically you may want to boost a title
|
||||
/// field.
|
||||
#[derive(Clone)]
|
||||
pub struct QueryParser {
|
||||
schema: Schema,
|
||||
default_fields: Vec<Field>,
|
||||
conjunction_by_default: bool,
|
||||
tokenizer_manager: TokenizerManager,
|
||||
boost: HashMap<Field, f32>,
|
||||
}
|
||||
|
||||
impl QueryParser {
|
||||
@@ -181,6 +189,7 @@ impl QueryParser {
|
||||
default_fields,
|
||||
tokenizer_manager,
|
||||
conjunction_by_default: false,
|
||||
boost: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,6 +210,17 @@ impl QueryParser {
|
||||
self.conjunction_by_default = true;
|
||||
}
|
||||
|
||||
/// Sets a boost for a specific field.
|
||||
///
|
||||
/// The parse query will automatically boost this field.
|
||||
///
|
||||
/// If the query defines a query boost through the query language (e.g: `country:France^3.0`),
|
||||
/// the two boosts (the one defined in the query, and the one defined in the `QueryParser`)
|
||||
/// are multiplied together.
|
||||
pub fn set_field_boost(&mut self, field: Field, boost: f32) {
|
||||
self.boost.insert(field, boost);
|
||||
}
|
||||
|
||||
/// Parse a query
|
||||
///
|
||||
/// Note that `parse_query` returns an error if the input
|
||||
@@ -407,6 +427,10 @@ impl QueryParser {
|
||||
self.compute_logical_ast_with_occur(*subquery)?;
|
||||
Ok((Occur::compose(left_occur, right_occur), logical_sub_queries))
|
||||
}
|
||||
UserInputAST::Boost(ast, boost) => {
|
||||
let (occur, ast_without_occur) = self.compute_logical_ast_with_occur(*ast)?;
|
||||
Ok((occur, ast_without_occur.boost(boost)))
|
||||
}
|
||||
UserInputAST::Leaf(leaf) => {
|
||||
let result_ast = self.compute_logical_ast_from_leaf(*leaf)?;
|
||||
Ok((Occur::Should, result_ast))
|
||||
@@ -414,6 +438,10 @@ impl QueryParser {
|
||||
}
|
||||
}
|
||||
|
||||
fn field_boost(&self, field: Field) -> f32 {
|
||||
self.boost.get(&field).cloned().unwrap_or(1.0f32)
|
||||
}
|
||||
|
||||
fn compute_logical_ast_from_leaf(
|
||||
&self,
|
||||
leaf: UserInputLeaf,
|
||||
@@ -439,7 +467,9 @@ impl QueryParser {
|
||||
let mut asts: Vec<LogicalAST> = Vec::new();
|
||||
for (field, phrase) in term_phrases {
|
||||
if let Some(ast) = self.compute_logical_ast_for_leaf(field, &phrase)? {
|
||||
asts.push(LogicalAST::Leaf(Box::new(ast)));
|
||||
// Apply some field specific boost defined at the query parser level.
|
||||
let boost = self.field_boost(field);
|
||||
asts.push(LogicalAST::Leaf(Box::new(ast)).boost(boost));
|
||||
}
|
||||
}
|
||||
let result_ast: LogicalAST = if asts.len() == 1 {
|
||||
@@ -459,14 +489,16 @@ impl QueryParser {
|
||||
let mut clauses = fields
|
||||
.iter()
|
||||
.map(|&field| {
|
||||
let boost = self.field_boost(field);
|
||||
let field_entry = self.schema.get_field_entry(field);
|
||||
let value_type = field_entry.field_type().value_type();
|
||||
Ok(LogicalAST::Leaf(Box::new(LogicalLiteral::Range {
|
||||
let logical_ast = LogicalAST::Leaf(Box::new(LogicalLiteral::Range {
|
||||
field,
|
||||
value_type,
|
||||
lower: self.resolve_bound(field, &lower)?,
|
||||
upper: self.resolve_bound(field, &upper)?,
|
||||
})))
|
||||
}));
|
||||
Ok(logical_ast.boost(boost))
|
||||
})
|
||||
.collect::<Result<Vec<_>, QueryParserError>>()?;
|
||||
let result_ast = if clauses.len() == 1 {
|
||||
@@ -519,6 +551,11 @@ fn convert_to_query(logical_ast: LogicalAST) -> Box<dyn Query> {
|
||||
Some(LogicalAST::Leaf(trimmed_logical_literal)) => {
|
||||
convert_literal_to_query(*trimmed_logical_literal)
|
||||
}
|
||||
Some(LogicalAST::Boost(ast, boost)) => {
|
||||
let query = convert_to_query(*ast);
|
||||
let boosted_query = BoostQuery::new(query, boost);
|
||||
Box::new(boosted_query)
|
||||
}
|
||||
None => Box::new(EmptyQuery),
|
||||
}
|
||||
}
|
||||
@@ -538,7 +575,7 @@ mod test {
|
||||
use crate::Index;
|
||||
use matches::assert_matches;
|
||||
|
||||
fn make_query_parser() -> QueryParser {
|
||||
fn make_schema() -> Schema {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text_field_indexing = TextFieldIndexing::default()
|
||||
.set_tokenizer("en_with_stop_words")
|
||||
@@ -546,8 +583,8 @@ mod test {
|
||||
let text_options = TextOptions::default()
|
||||
.set_indexing_options(text_field_indexing)
|
||||
.set_stored();
|
||||
let title = schema_builder.add_text_field("title", TEXT);
|
||||
let text = schema_builder.add_text_field("text", TEXT);
|
||||
schema_builder.add_text_field("title", TEXT);
|
||||
schema_builder.add_text_field("text", TEXT);
|
||||
schema_builder.add_i64_field("signed", INDEXED);
|
||||
schema_builder.add_u64_field("unsigned", INDEXED);
|
||||
schema_builder.add_text_field("notindexed_text", STORED);
|
||||
@@ -558,8 +595,15 @@ mod test {
|
||||
schema_builder.add_date_field("date", INDEXED);
|
||||
schema_builder.add_f64_field("float", INDEXED);
|
||||
schema_builder.add_facet_field("facet");
|
||||
let schema = schema_builder.build();
|
||||
let default_fields = vec![title, text];
|
||||
schema_builder.build()
|
||||
}
|
||||
|
||||
fn make_query_parser() -> QueryParser {
|
||||
let schema = make_schema();
|
||||
let default_fields: Vec<Field> = vec!["title", "text"]
|
||||
.into_iter()
|
||||
.flat_map(|field_name| schema.get_field(field_name))
|
||||
.collect();
|
||||
let tokenizer_manager = TokenizerManager::default();
|
||||
tokenizer_manager.register(
|
||||
"en_with_stop_words",
|
||||
@@ -601,6 +645,45 @@ mod test {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_query_with_boost() {
|
||||
let mut query_parser = make_query_parser();
|
||||
let schema = make_schema();
|
||||
let text_field = schema.get_field("text").unwrap();
|
||||
query_parser.set_field_boost(text_field, 2.0f32);
|
||||
let query = query_parser.parse_query("text:hello").unwrap();
|
||||
assert_eq!(
|
||||
format!("{:?}", query),
|
||||
"Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_query_range_with_boost() {
|
||||
let mut query_parser = make_query_parser();
|
||||
let schema = make_schema();
|
||||
let title_field = schema.get_field("title").unwrap();
|
||||
query_parser.set_field_boost(title_field, 2.0f32);
|
||||
let query = query_parser.parse_query("title:[A TO B]").unwrap();
|
||||
assert_eq!(
|
||||
format!("{:?}", query),
|
||||
"Boost(query=RangeQuery { field: Field(0), value_type: Str, left_bound: Included([97]), right_bound: Included([98]) }, boost=2)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_query_with_default_boost_and_custom_boost() {
|
||||
let mut query_parser = make_query_parser();
|
||||
let schema = make_schema();
|
||||
let text_field = schema.get_field("text").unwrap();
|
||||
query_parser.set_field_boost(text_field, 2.0f32);
|
||||
let query = query_parser.parse_query("text:hello^2").unwrap();
|
||||
assert_eq!(
|
||||
format!("{:?}", query),
|
||||
"Boost(query=Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2), boost=2)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_nonindexed_field_yields_error() {
|
||||
let query_parser = make_query_parser();
|
||||
|
||||
Reference in New Issue
Block a user