perf: deduplicate queries (#2698)

* deduplicate queries

Deduplicate queries in the UserInputAst after parsing queries

* add return type
This commit is contained in:
PSeitz-dd
2025-09-22 12:16:58 +02:00
committed by GitHub
parent 85010b589a
commit 70da310b2d
7 changed files with 59 additions and 21 deletions

View File

@@ -15,3 +15,5 @@ edition = "2024"
nom = "7" nom = "7"
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140" serde_json = "1.0.140"
ordered-float = "5.0.0"
fnv = "1.0.7"

View File

@@ -31,7 +31,17 @@ pub fn parse_query_lenient(query: &str) -> (UserInputAst, Vec<LenientError>) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{parse_query, parse_query_lenient}; use crate::{UserInputAst, parse_query, parse_query_lenient};
#[test]
fn test_deduplication() {
let ast: UserInputAst = parse_query("a a").unwrap();
let json = serde_json::to_string(&ast).unwrap();
assert_eq!(
json,
r#"{"type":"bool","clauses":[[null,{"type":"literal","field_name":null,"phrase":"a","delimiter":"none","slop":0,"prefix":false}]]}"#
);
}
#[test] #[test]
fn test_parse_query_serialization() { fn test_parse_query_serialization() {

View File

@@ -1,6 +1,7 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::iter::once; use std::iter::once;
use fnv::FnvHashSet;
use nom::IResult; use nom::IResult;
use nom::branch::alt; use nom::branch::alt;
use nom::bytes::complete::tag; use nom::bytes::complete::tag;
@@ -814,7 +815,7 @@ fn boosted_leaf(inp: &str) -> IResult<&str, UserInputAst> {
tuple((leaf, fallible(boost))), tuple((leaf, fallible(boost))),
|(leaf, boost_opt)| match boost_opt { |(leaf, boost_opt)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => { Some(boost) if (boost - 1.0).abs() > f64::EPSILON => {
UserInputAst::Boost(Box::new(leaf), boost) UserInputAst::Boost(Box::new(leaf), boost.into())
} }
_ => leaf, _ => leaf,
}, },
@@ -826,7 +827,7 @@ fn boosted_leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
tuple_infallible((leaf_infallible, boost)), tuple_infallible((leaf_infallible, boost)),
|((leaf, boost_opt), error)| match boost_opt { |((leaf, boost_opt), error)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => ( Some(boost) if (boost - 1.0).abs() > f64::EPSILON => (
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost)), leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost.into())),
error, error,
), ),
_ => (leaf, error), _ => (leaf, error),
@@ -1077,12 +1078,25 @@ pub fn parse_to_ast_lenient(query_str: &str) -> (UserInputAst, Vec<LenientError>
(rewrite_ast(res), errors) (rewrite_ast(res), errors)
} }
/// Removes unnecessary children clauses in AST
///
/// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
fn rewrite_ast(mut input: UserInputAst) -> UserInputAst { fn rewrite_ast(mut input: UserInputAst) -> UserInputAst {
if let UserInputAst::Clause(terms) = &mut input { if let UserInputAst::Clause(sub_clauses) = &mut input {
for term in terms { // call rewrite_ast recursively on children clauses if applicable
let mut new_clauses = Vec::with_capacity(sub_clauses.len());
for (occur, clause) in sub_clauses.drain(..) {
let rewritten_clause = rewrite_ast(clause);
new_clauses.push((occur, rewritten_clause));
}
*sub_clauses = new_clauses;
// remove duplicate child clauses
// e.g. (+a +b) OR (+c +d) OR (+a +b) => (+a +b) OR (+c +d)
let mut seen = FnvHashSet::default();
sub_clauses.retain(|term| seen.insert(term.clone()));
// Removes unnecessary children clauses in AST
//
// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
for term in sub_clauses {
rewrite_ast_clause(term); rewrite_ast_clause(term);
} }
} }

View File

@@ -5,7 +5,7 @@ use serde::Serialize;
use crate::Occur; use crate::Occur;
#[derive(PartialEq, Clone, Serialize)] #[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum UserInputLeaf { pub enum UserInputLeaf {
@@ -120,7 +120,7 @@ impl Debug for UserInputLeaf {
} }
} }
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize)] #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum Delimiter { pub enum Delimiter {
SingleQuotes, SingleQuotes,
@@ -128,7 +128,7 @@ pub enum Delimiter {
None, None,
} }
#[derive(PartialEq, Clone, Serialize)] #[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct UserInputLiteral { pub struct UserInputLiteral {
pub field_name: Option<String>, pub field_name: Option<String>,
@@ -167,7 +167,7 @@ impl fmt::Debug for UserInputLiteral {
} }
} }
#[derive(PartialEq, Debug, Clone, Serialize)] #[derive(PartialEq, Eq, Hash, Debug, Clone, Serialize)]
#[serde(tag = "type", content = "value")] #[serde(tag = "type", content = "value")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum UserInputBound { pub enum UserInputBound {
@@ -204,11 +204,11 @@ impl UserInputBound {
} }
} }
#[derive(PartialEq, Clone, Serialize)] #[derive(PartialEq, Eq, Hash, Clone, Serialize)]
#[serde(into = "UserInputAstSerde")] #[serde(into = "UserInputAstSerde")]
pub enum UserInputAst { pub enum UserInputAst {
Clause(Vec<(Option<Occur>, UserInputAst)>), Clause(Vec<(Option<Occur>, UserInputAst)>),
Boost(Box<UserInputAst>, f64), Boost(Box<UserInputAst>, ordered_float::OrderedFloat<f64>),
Leaf(Box<UserInputLeaf>), Leaf(Box<UserInputLeaf>),
} }
@@ -230,9 +230,10 @@ impl From<UserInputAst> for UserInputAstSerde {
fn from(ast: UserInputAst) -> Self { fn from(ast: UserInputAst) -> Self {
match ast { match ast {
UserInputAst::Clause(clause) => UserInputAstSerde::Bool { clauses: clause }, UserInputAst::Clause(clause) => UserInputAstSerde::Bool { clauses: clause },
UserInputAst::Boost(underlying, boost) => { UserInputAst::Boost(underlying, boost) => UserInputAstSerde::Boost {
UserInputAstSerde::Boost { underlying, boost } underlying,
} boost: boost.into_inner(),
},
UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf), UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf),
} }
} }
@@ -391,7 +392,7 @@ mod tests {
#[test] #[test]
fn test_boost_serialization() { fn test_boost_serialization() {
let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All)); let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All));
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5); let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5.into());
let json = serde_json::to_string(&boost_ast).unwrap(); let json = serde_json::to_string(&boost_ast).unwrap();
assert_eq!( assert_eq!(
json, json,
@@ -418,7 +419,7 @@ mod tests {
}))), }))),
), ),
])), ])),
2.5, 2.5.into(),
); );
let json = serde_json::to_string(&boost_ast).unwrap(); let json = serde_json::to_string(&boost_ast).unwrap();
assert_eq!( assert_eq!(

View File

@@ -104,7 +104,7 @@ mod tests {
let query = query_parser.parse_query("a a a a a").unwrap(); let query = query_parser.parse_query("a a a a a").unwrap();
let mut terms = Vec::new(); let mut terms = Vec::new();
query.query_terms(&mut |term, pos| terms.push((term, pos))); query.query_terms(&mut |term, pos| terms.push((term, pos)));
assert_eq!(vec![(&term_a, false); 5], terms); assert_eq!(vec![(&term_a, false); 1], terms);
} }
{ {
let query = query_parser.parse_query("a -b").unwrap(); let query = query_parser.parse_query("a -b").unwrap();

View File

@@ -45,6 +45,7 @@ impl LogicalAst {
} }
} }
// TODO: Move to rewrite_ast in query_grammar
pub fn simplify(self) -> LogicalAst { pub fn simplify(self) -> LogicalAst {
match self { match self {
LogicalAst::Clause(clauses) => { LogicalAst::Clause(clauses) => {

View File

@@ -672,7 +672,7 @@ impl QueryParser {
} }
UserInputAst::Boost(ast, boost) => { UserInputAst::Boost(ast, boost) => {
let (ast, errors) = self.compute_logical_ast_with_occur_lenient(*ast); let (ast, errors) = self.compute_logical_ast_with_occur_lenient(*ast);
(ast.boost(boost as Score), errors) (ast.boost(boost.into_inner() as Score), errors)
} }
UserInputAst::Leaf(leaf) => { UserInputAst::Leaf(leaf) => {
let (ast, errors) = self.compute_logical_ast_from_leaf_lenient(*leaf); let (ast, errors) = self.compute_logical_ast_from_leaf_lenient(*leaf);
@@ -2050,6 +2050,16 @@ mod test {
); );
} }
#[test]
pub fn test_deduplication() {
let query = "be be";
test_parse_query_to_logical_ast_helper(
query,
"(Term(field=0, type=Str, \"be\") Term(field=1, type=Str, \"be\"))",
false,
);
}
#[test] #[test]
pub fn test_regex() { pub fn test_regex() {
let expected_regex = tantivy_fst::Regex::new(r".*b").unwrap(); let expected_regex = tantivy_fst::Regex::new(r".*b").unwrap();