diff --git a/query-grammar/src/infallible.rs b/query-grammar/src/infallible.rs index 9fd5bff69..b7134f868 100644 --- a/query-grammar/src/infallible.rs +++ b/query-grammar/src/infallible.rs @@ -117,6 +117,22 @@ where F: nom::Parser { } } +pub(crate) fn terminated_infallible( + mut first: F, + mut second: G, +) -> impl FnMut(I) -> JResult +where + F: nom::Parser, + G: nom::Parser, +{ + move |input: I| { + let (input, (o1, mut err)) = first.parse(input)?; + let (input, (_, mut err2)) = second.parse(input)?; + err.append(&mut err2); + Ok((input, (o1, err))) + } +} + pub(crate) fn delimited_infallible( mut first: F, mut second: G, diff --git a/query-grammar/src/query_grammar.rs b/query-grammar/src/query_grammar.rs index 252d4300f..1960f219a 100644 --- a/query-grammar/src/query_grammar.rs +++ b/query-grammar/src/query_grammar.rs @@ -367,7 +367,10 @@ fn literal(inp: &str) -> IResult<&str, UserInputAst> { // something (a field name) got parsed before alt(( map( - tuple((opt(field_name), alt((range, set, exists, term_or_phrase)))), + tuple(( + opt(field_name), + alt((range, set, exists, regex, term_or_phrase)), + )), |(field_name, leaf): (Option, UserInputLeaf)| leaf.set_field(field_name).into(), ), term_group, @@ -389,6 +392,10 @@ fn literal_no_group_infallible(inp: &str) -> JResult<&str, Option> value((), peek(one_of("{[><"))), map(range_infallible, |(range, errs)| (Some(range), errs)), ), + ( + value((), peek(one_of("/"))), + map(regex_infallible, |(regex, errs)| (Some(regex), errs)), + ), ), delimited_infallible(space0_infallible, term_or_phrase_infallible, nothing), ), @@ -689,6 +696,61 @@ fn set_infallible(mut inp: &str) -> JResult<&str, UserInputLeaf> { } } +fn regex(inp: &str) -> IResult<&str, UserInputLeaf> { + map( + terminated( + delimited( + char('/'), + many1(alt((preceded(char('\\'), char('/')), none_of("/")))), + char('/'), + ), + peek(alt((multispace1, eof))), + ), + |elements| UserInputLeaf::Regex { + field: None, + pattern: elements.into_iter().collect::(), + }, + )(inp) +} + +fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> { + match terminated_infallible( + delimited_infallible( + opt_i_err(char('/'), "missing delimiter /"), + opt_i(many1(alt((preceded(char('\\'), char('/')), none_of("/"))))), + opt_i_err(char('/'), "missing delimiter /"), + ), + opt_i_err( + peek(alt((multispace1, eof))), + "expected whitespace or end of input", + ), + )(inp) + { + Ok((rest, (elements_part, errors))) => { + let pattern = match elements_part { + Some(elements_part) => elements_part.into_iter().collect(), + None => String::new(), + }; + let res = UserInputLeaf::Regex { + field: None, + pattern, + }; + Ok((rest, (res, errors))) + } + Err(e) => { + let errs = vec![LenientErrorInternal { + pos: inp.len(), + message: e.to_string(), + }]; + let res = UserInputLeaf::Regex { + field: None, + pattern: String::new(), + }; + Ok((inp, (res, errs))) + } + } +} + fn negate(expr: UserInputAst) -> UserInputAst { expr.unary(Occur::MustNot) } @@ -1694,6 +1756,63 @@ mod test { test_is_parse_err(r#"!bc:def"#, "!bc:def"); } + #[test] + fn test_regex_parser() { + let r = parse_to_ast(r#"a:/joh?n(ath[oa]n)/"#); + assert!(r.is_ok(), "Failed to parse custom query: {r:?}"); + let (_, input) = r.unwrap(); + match input { + UserInputAst::Leaf(leaf) => match leaf.as_ref() { + UserInputLeaf::Regex { field, pattern } => { + assert_eq!(field, &Some("a".to_string())); + assert_eq!(pattern, "joh?n(ath[oa]n)"); + } + _ => panic!("Expected a regex leaf, got {leaf:?}"), + }, + _ => panic!("Expected a leaf"), + } + let r = parse_to_ast(r#"a:/\\/cgi-bin\\/luci.*/"#); + assert!(r.is_ok(), "Failed to parse custom query: {r:?}"); + let (_, input) = r.unwrap(); + match input { + UserInputAst::Leaf(leaf) => match leaf.as_ref() { + UserInputLeaf::Regex { field, pattern } => { + assert_eq!(field, &Some("a".to_string())); + assert_eq!(pattern, "\\/cgi-bin\\/luci.*"); + } + _ => panic!("Expected a regex leaf, got {leaf:?}"), + }, + _ => panic!("Expected a leaf"), + } + } + + #[test] + fn test_regex_parser_lenient() { + let literal = |query| literal_infallible(query).unwrap().1; + + let (res, errs) = literal(r#"a:/joh?n(ath[oa]n)/"#); + let expected = UserInputLeaf::Regex { + field: Some("a".to_string()), + pattern: "joh?n(ath[oa]n)".to_string(), + } + .into(); + assert_eq!(res.unwrap(), expected); + assert!(errs.is_empty(), "Expected no errors, got: {errs:?}"); + + let (res, errs) = literal("title:/joh?n(ath[oa]n)"); + let expected = UserInputLeaf::Regex { + field: Some("title".to_string()), + pattern: "joh?n(ath[oa]n)".to_string(), + } + .into(); + assert_eq!(res.unwrap(), expected); + assert_eq!(errs.len(), 1, "Expected 1 error, got: {errs:?}"); + assert_eq!( + errs[0].message, "missing delimiter /", + "Unexpected error message", + ); + } + #[test] fn test_space_before_value() { test_parse_query_to_ast_helper("field : a", r#""field":a"#); diff --git a/query-grammar/src/user_input_ast.rs b/query-grammar/src/user_input_ast.rs index f72f018fc..05a5be19c 100644 --- a/query-grammar/src/user_input_ast.rs +++ b/query-grammar/src/user_input_ast.rs @@ -23,6 +23,10 @@ pub enum UserInputLeaf { Exists { field: String, }, + Regex { + field: Option, + pattern: String, + }, } impl UserInputLeaf { @@ -46,6 +50,7 @@ impl UserInputLeaf { UserInputLeaf::Exists { field: _ } => UserInputLeaf::Exists { field: field.expect("Exist query without a field isn't allowed"), }, + UserInputLeaf::Regex { field: _, pattern } => UserInputLeaf::Regex { field, pattern }, } } @@ -103,6 +108,14 @@ impl Debug for UserInputLeaf { UserInputLeaf::Exists { field } => { write!(formatter, "$exists(\"{field}\")") } + UserInputLeaf::Regex { field, pattern } => { + if let Some(field) = field { + // TODO properly escape field (in case of \") + write!(formatter, "\"{field}\":")?; + } + // TODO properly escape pattern (in case of \") + write!(formatter, "/{pattern}/") + } } } } diff --git a/src/query/query_parser/logical_ast.rs b/src/query/query_parser/logical_ast.rs index b0929f26a..914dc8fee 100644 --- a/src/query/query_parser/logical_ast.rs +++ b/src/query/query_parser/logical_ast.rs @@ -1,8 +1,11 @@ use std::fmt; use std::ops::Bound; +use std::sync::Arc; + +use tantivy_fst::Regex; use crate::query::Occur; -use crate::schema::Term; +use crate::schema::{Field, Term}; use crate::Score; #[derive(Clone)] @@ -21,6 +24,10 @@ pub enum LogicalLiteral { elements: Vec, }, All, + Regex { + pattern: Arc, + field: Field, + }, } pub enum LogicalAst { @@ -147,6 +154,10 @@ impl fmt::Debug for LogicalLiteral { write!(formatter, "]") } LogicalLiteral::All => write!(formatter, "*"), + LogicalLiteral::Regex { + ref pattern, + ref field, + } => write!(formatter, "Regex({field:?}, {pattern:?})"), } } } diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs index 7900d7837..62d15d8f5 100644 --- a/src/query/query_parser/query_parser.rs +++ b/src/query/query_parser/query_parser.rs @@ -2,12 +2,14 @@ use std::net::{AddrParseError, IpAddr}; use std::num::{ParseFloatError, ParseIntError}; use std::ops::Bound; use std::str::{FromStr, ParseBoolError}; +use std::sync::Arc; use base64::engine::general_purpose::STANDARD as BASE64; use base64::Engine; use itertools::Itertools; use query_grammar::{UserInputAst, UserInputBound, UserInputLeaf, UserInputLiteral}; use rustc_hash::FxHashMap; +use tantivy_fst::Regex; use super::logical_ast::*; use crate::index::Index; @@ -15,7 +17,7 @@ use crate::json_utils::convert_to_fast_value_and_append_to_json_term; use crate::query::range_query::{is_type_valid_for_fastfield_range_query, RangeQuery}; use crate::query::{ AllQuery, BooleanQuery, BoostQuery, EmptyQuery, FuzzyTermQuery, Occur, PhrasePrefixQuery, - PhraseQuery, Query, TermQuery, TermSetQuery, + PhraseQuery, Query, RegexQuery, TermQuery, TermSetQuery, }; use crate::schema::{ Facet, FacetParseError, Field, FieldType, IndexRecordOption, IntoIpv6Addr, JsonObjectOptions, @@ -206,6 +208,7 @@ pub struct QueryParser { tokenizer_manager: TokenizerManager, boost: FxHashMap, fuzzy: FxHashMap, + regexes_allowed: bool, } #[derive(Clone)] @@ -260,6 +263,7 @@ impl QueryParser { conjunction_by_default: false, boost: Default::default(), fuzzy: Default::default(), + regexes_allowed: false, } } @@ -320,6 +324,11 @@ impl QueryParser { ); } + /// Allow regexes in queries + pub fn allow_regexes(&mut self) { + self.regexes_allowed = true; + } + /// Parse a query /// /// Note that `parse_query` returns an error if the input @@ -860,6 +869,51 @@ impl QueryParser { "Range query need to target a specific field.".to_string(), )], ), + UserInputLeaf::Regex { field, pattern } => { + if !self.regexes_allowed { + return ( + None, + vec![QueryParserError::UnsupportedQuery( + "Regex queries are not allowed.".to_string(), + )], + ); + } + let full_path = try_tuple!(field.ok_or_else(|| { + QueryParserError::UnsupportedQuery( + "Regex query need to target a specific field.".to_string(), + ) + })); + let (field, json_path) = try_tuple!(self + .split_full_path(&full_path) + .ok_or_else(|| QueryParserError::FieldDoesNotExist(full_path.clone()))); + if !json_path.is_empty() { + return ( + None, + vec![QueryParserError::UnsupportedQuery( + "Regex query does not support json paths.".to_string(), + )], + ); + } + if !matches!( + self.schema.get_field_entry(field).field_type(), + FieldType::Str(_) + ) { + return ( + None, + vec![QueryParserError::UnsupportedQuery( + "Regex query only supported on text fields".to_string(), + )], + ); + } + let pattern = try_tuple!(Regex::new(&pattern).map_err(|e| { + QueryParserError::UnsupportedQuery(format!("Invalid regex: {e}")) + })); + let logical_ast = LogicalAst::Leaf(Box::new(LogicalLiteral::Regex { + pattern: Arc::new(pattern), + field, + })); + (Some(logical_ast), Vec::new()) + } } } } @@ -902,6 +956,9 @@ fn convert_literal_to_query( LogicalLiteral::Range { lower, upper } => Box::new(RangeQuery::new(lower, upper)), LogicalLiteral::Set { elements, .. } => Box::new(TermSetQuery::new(elements)), LogicalLiteral::All => Box::new(AllQuery), + LogicalLiteral::Regex { pattern, field } => { + Box::new(RegexQuery::from_regex(pattern, field)) + } } } @@ -1100,11 +1157,15 @@ mod test { query: &str, default_conjunction: bool, default_fields: &[&'static str], + allow_regexes: bool, ) -> Result { let mut query_parser = make_query_parser_with_default_fields(default_fields); if default_conjunction { query_parser.set_conjunction_by_default(); } + if allow_regexes { + query_parser.allow_regexes(); + } query_parser.parse_query_to_logical_ast(query) } @@ -1116,6 +1177,7 @@ mod test { query, default_conjunction, &["title", "text"], + true, ) } @@ -1130,6 +1192,7 @@ mod test { query, default_conjunction, default_fields, + true, ) .unwrap(); let query_str = format!("{query:?}"); @@ -1993,4 +2056,56 @@ mod test { Err(QueryParserError::ExpectedInt(_)) ); } + + #[test] + pub fn test_regex() { + let expected_regex = tantivy_fst::Regex::new(r".*b").unwrap(); + test_parse_query_to_logical_ast_helper( + "title:/.*b/", + format!("Regex(Field(0), {:#?})", expected_regex).as_str(), + false, + ); + + // Invalid field + let err = parse_query_to_logical_ast("float:/.*b/", false).unwrap_err(); + assert_eq!( + err.to_string(), + "Unsupported query: Regex query only supported on text fields" + ); + + // No field specified + let err = parse_query_to_logical_ast("/.*b/", false).unwrap_err(); + assert_eq!( + err.to_string(), + "Unsupported query: Regex query need to target a specific field." + ); + + // Regex on a json path + let err = parse_query_to_logical_ast("title.subpath:/.*b/", false).unwrap_err(); + assert_eq!( + err.to_string(), + "Unsupported query: Regex query does not support json paths." + ); + + // Invalid regex + let err = parse_query_to_logical_ast("title:/[A-Z*b/", false).unwrap_err(); + assert_eq!( + err.to_string(), + "Unsupported query: Invalid regex: regex parse error:\n [A-Z*b\n ^\nerror: \ + unclosed character class" + ); + + // Regexes not allowed + let err = parse_query_to_logical_ast_with_default_fields( + "title:/.*b/", + false, + &["title", "text"], + false, + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "Unsupported query: Regex queries are not allowed." + ); + } }