diff --git a/query-grammar/Cargo.toml b/query-grammar/Cargo.toml index 5334cd677..59469ff2b 100644 --- a/query-grammar/Cargo.toml +++ b/query-grammar/Cargo.toml @@ -15,3 +15,5 @@ edition = "2024" nom = "7" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" +ordered-float = "5.0.0" +fnv = "1.0.7" diff --git a/query-grammar/src/lib.rs b/query-grammar/src/lib.rs index c7c41df14..46ea65c84 100644 --- a/query-grammar/src/lib.rs +++ b/query-grammar/src/lib.rs @@ -31,7 +31,17 @@ pub fn parse_query_lenient(query: &str) -> (UserInputAst, Vec) { #[cfg(test)] mod tests { - use crate::{parse_query, parse_query_lenient}; + use crate::{UserInputAst, parse_query, parse_query_lenient}; + + #[test] + fn test_deduplication() { + let ast: UserInputAst = parse_query("a a").unwrap(); + let json = serde_json::to_string(&ast).unwrap(); + assert_eq!( + json, + r#"{"type":"bool","clauses":[[null,{"type":"literal","field_name":null,"phrase":"a","delimiter":"none","slop":0,"prefix":false}]]}"# + ); + } #[test] fn test_parse_query_serialization() { diff --git a/query-grammar/src/query_grammar.rs b/query-grammar/src/query_grammar.rs index 3fac6cce7..879a95725 100644 --- a/query-grammar/src/query_grammar.rs +++ b/query-grammar/src/query_grammar.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use std::iter::once; +use fnv::FnvHashSet; use nom::IResult; use nom::branch::alt; use nom::bytes::complete::tag; @@ -814,7 +815,7 @@ fn boosted_leaf(inp: &str) -> IResult<&str, UserInputAst> { tuple((leaf, fallible(boost))), |(leaf, boost_opt)| match boost_opt { Some(boost) if (boost - 1.0).abs() > f64::EPSILON => { - UserInputAst::Boost(Box::new(leaf), boost) + UserInputAst::Boost(Box::new(leaf), boost.into()) } _ => leaf, }, @@ -826,7 +827,7 @@ fn boosted_leaf_infallible(inp: &str) -> JResult<&str, Option> { tuple_infallible((leaf_infallible, boost)), |((leaf, boost_opt), error)| match boost_opt { Some(boost) if (boost - 1.0).abs() > f64::EPSILON => ( - leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost)), + leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost.into())), error, ), _ => (leaf, error), @@ -1077,12 +1078,25 @@ pub fn parse_to_ast_lenient(query_str: &str) -> (UserInputAst, Vec (rewrite_ast(res), errors) } -/// Removes unnecessary children clauses in AST -/// -/// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433) fn rewrite_ast(mut input: UserInputAst) -> UserInputAst { - if let UserInputAst::Clause(terms) = &mut input { - for term in terms { + if let UserInputAst::Clause(sub_clauses) = &mut input { + // call rewrite_ast recursively on children clauses if applicable + let mut new_clauses = Vec::with_capacity(sub_clauses.len()); + for (occur, clause) in sub_clauses.drain(..) { + let rewritten_clause = rewrite_ast(clause); + new_clauses.push((occur, rewritten_clause)); + } + *sub_clauses = new_clauses; + + // remove duplicate child clauses + // e.g. (+a +b) OR (+c +d) OR (+a +b) => (+a +b) OR (+c +d) + let mut seen = FnvHashSet::default(); + sub_clauses.retain(|term| seen.insert(term.clone())); + + // Removes unnecessary children clauses in AST + // + // Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433) + for term in sub_clauses { rewrite_ast_clause(term); } } diff --git a/query-grammar/src/user_input_ast.rs b/query-grammar/src/user_input_ast.rs index 05a5be19c..e6b4858fd 100644 --- a/query-grammar/src/user_input_ast.rs +++ b/query-grammar/src/user_input_ast.rs @@ -5,7 +5,7 @@ use serde::Serialize; use crate::Occur; -#[derive(PartialEq, Clone, Serialize)] +#[derive(PartialEq, Eq, Hash, Clone, Serialize)] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] pub enum UserInputLeaf { @@ -120,7 +120,7 @@ impl Debug for UserInputLeaf { } } -#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize)] +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize)] #[serde(rename_all = "snake_case")] pub enum Delimiter { SingleQuotes, @@ -128,7 +128,7 @@ pub enum Delimiter { None, } -#[derive(PartialEq, Clone, Serialize)] +#[derive(PartialEq, Eq, Hash, Clone, Serialize)] #[serde(rename_all = "snake_case")] pub struct UserInputLiteral { pub field_name: Option, @@ -167,7 +167,7 @@ impl fmt::Debug for UserInputLiteral { } } -#[derive(PartialEq, Debug, Clone, Serialize)] +#[derive(PartialEq, Eq, Hash, Debug, Clone, Serialize)] #[serde(tag = "type", content = "value")] #[serde(rename_all = "snake_case")] pub enum UserInputBound { @@ -204,11 +204,11 @@ impl UserInputBound { } } -#[derive(PartialEq, Clone, Serialize)] +#[derive(PartialEq, Eq, Hash, Clone, Serialize)] #[serde(into = "UserInputAstSerde")] pub enum UserInputAst { Clause(Vec<(Option, UserInputAst)>), - Boost(Box, f64), + Boost(Box, ordered_float::OrderedFloat), Leaf(Box), } @@ -230,9 +230,10 @@ impl From for UserInputAstSerde { fn from(ast: UserInputAst) -> Self { match ast { UserInputAst::Clause(clause) => UserInputAstSerde::Bool { clauses: clause }, - UserInputAst::Boost(underlying, boost) => { - UserInputAstSerde::Boost { underlying, boost } - } + UserInputAst::Boost(underlying, boost) => UserInputAstSerde::Boost { + underlying, + boost: boost.into_inner(), + }, UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf), } } @@ -391,7 +392,7 @@ mod tests { #[test] fn test_boost_serialization() { let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All)); - let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5); + let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5.into()); let json = serde_json::to_string(&boost_ast).unwrap(); assert_eq!( json, @@ -418,7 +419,7 @@ mod tests { }))), ), ])), - 2.5, + 2.5.into(), ); let json = serde_json::to_string(&boost_ast).unwrap(); assert_eq!( diff --git a/src/query/mod.rs b/src/query/mod.rs index 23e64f189..2ba2f4def 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -104,7 +104,7 @@ mod tests { let query = query_parser.parse_query("a a a a a").unwrap(); let mut terms = Vec::new(); query.query_terms(&mut |term, pos| terms.push((term, pos))); - assert_eq!(vec![(&term_a, false); 5], terms); + assert_eq!(vec![(&term_a, false); 1], terms); } { let query = query_parser.parse_query("a -b").unwrap(); diff --git a/src/query/query_parser/logical_ast.rs b/src/query/query_parser/logical_ast.rs index 914dc8fee..17c68a76a 100644 --- a/src/query/query_parser/logical_ast.rs +++ b/src/query/query_parser/logical_ast.rs @@ -45,6 +45,7 @@ impl LogicalAst { } } + // TODO: Move to rewrite_ast in query_grammar pub fn simplify(self) -> LogicalAst { match self { LogicalAst::Clause(clauses) => { diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs index c44de3886..7ebaf3a47 100644 --- a/src/query/query_parser/query_parser.rs +++ b/src/query/query_parser/query_parser.rs @@ -672,7 +672,7 @@ impl QueryParser { } UserInputAst::Boost(ast, boost) => { let (ast, errors) = self.compute_logical_ast_with_occur_lenient(*ast); - (ast.boost(boost as Score), errors) + (ast.boost(boost.into_inner() as Score), errors) } UserInputAst::Leaf(leaf) => { let (ast, errors) = self.compute_logical_ast_from_leaf_lenient(*leaf); @@ -2050,6 +2050,16 @@ mod test { ); } + #[test] + pub fn test_deduplication() { + let query = "be be"; + test_parse_query_to_logical_ast_helper( + query, + "(Term(field=0, type=Str, \"be\") Term(field=1, type=Str, \"be\"))", + false, + ); + } + #[test] pub fn test_regex() { let expected_regex = tantivy_fst::Regex::new(r".*b").unwrap();