From e4e416ac423d7dfa3140970b1cacf986efedea5c Mon Sep 17 00:00:00 2001 From: Ping Xia <99703971+PingXia-at@users.noreply.github.com> Date: Sun, 10 Sep 2023 23:59:40 -0400 Subject: [PATCH] extend FuzzyTermQuery to support json field (#2173) * extend fuzzy search for json field * comments * comments * fmt fix * comments --- src/query/automaton_weight.rs | 29 +++++- src/query/fuzzy_query.rs | 126 +++++++++++++++++++++++++-- src/query/phrase_prefix_query/mod.rs | 2 +- src/schema/term.rs | 24 +++-- 4 files changed, 164 insertions(+), 17 deletions(-) diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index e9375e96c..36be9a2c7 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use common::BitSet; use tantivy_fst::Automaton; +use super::phrase_prefix_query::prefix_end; use crate::core::SegmentReader; use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight}; use crate::schema::{Field, IndexRecordOption}; @@ -14,6 +15,10 @@ use crate::{DocId, Score, TantivyError}; pub struct AutomatonWeight<A> { field: Field, automaton: Arc<A>, + // For JSON fields, the term dictionary includes terms from all paths. + // We apply additional filtering based on the given JSON path, when searching within the term + // dictionary. This prevents terms from unrelated paths from matching the search criteria. 
+ json_path_bytes: Option<Box<[u8]>>, } impl<A> AutomatonWeight<A> where @@ -26,6 +31,20 @@ where AutomatonWeight { field, automaton: automaton.into(), + json_path_bytes: None, + } + } + + /// Create a new AutomatonWeight for a json path + pub fn new_for_json_path<IntoArcA: Into<Arc<A>>>( + field: Field, + automaton: IntoArcA, + json_path_bytes: &[u8], + ) -> AutomatonWeight<A> { + AutomatonWeight { + field, + automaton: automaton.into(), + json_path_bytes: Some(json_path_bytes.to_vec().into_boxed_slice()), + } } @@ -34,7 +53,15 @@ where term_dict: &'a TermDictionary, ) -> io::Result<TermStreamer<'a, &'a A>> { let automaton: &A = &self.automaton; - let term_stream_builder = term_dict.search(automaton); + let mut term_stream_builder = term_dict.search(automaton); + + if let Some(json_path_bytes) = &self.json_path_bytes { + term_stream_builder = term_stream_builder.ge(json_path_bytes); + if let Some(end) = prefix_end(json_path_bytes) { + term_stream_builder = term_stream_builder.lt(&end); + } + } + term_stream_builder.into_stream() } } diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index 1c6b1f479..934dfce33 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -3,7 +3,7 @@ use once_cell::sync::OnceCell; use tantivy_fst::Automaton; use crate::query::{AutomatonWeight, EnableScoring, Query, Weight}; -use crate::schema::Term; +use crate::schema::{Term, Type}; use crate::TantivyError::InvalidArgument; pub(crate) struct DfaWrapper(pub DFA); @@ -132,18 +132,46 @@ impl FuzzyTermQuery { }); let term_value = self.term.value(); - let term_text = term_value.as_str().ok_or_else(|| { - InvalidArgument("The fuzzy term query requires a string term.".to_string()) - })?; + + let term_text = if term_value.typ() == Type::Json { + if let Some(json_path_type) = term_value.json_path_type() { + if json_path_type != Type::Str { + return Err(InvalidArgument(format!( + "The fuzzy term query requires a string path type for a json term. 
Found \ + {:?}", + json_path_type + ))); + } + } + + std::str::from_utf8(self.term.serialized_value_bytes()).map_err(|_| { + InvalidArgument( + "Failed to convert json term value bytes to utf8 string.".to_string(), + ) + })? + } else { + term_value.as_str().ok_or_else(|| { + InvalidArgument("The fuzzy term query requires a string term.".to_string()) + })? + }; let automaton = if self.prefix { automaton_builder.build_prefix_dfa(term_text) } else { automaton_builder.build_dfa(term_text) }; - Ok(AutomatonWeight::new( - self.term.field(), - DfaWrapper(automaton), - )) + + if let Some((json_path_bytes, _)) = term_value.as_json() { + Ok(AutomatonWeight::new_for_json_path( + self.term.field(), + DfaWrapper(automaton), + json_path_bytes, + )) + } else { + Ok(AutomatonWeight::new( + self.term.field(), + DfaWrapper(automaton), + )) + } } } @@ -157,9 +185,89 @@ impl Query for FuzzyTermQuery { mod test { use super::FuzzyTermQuery; use crate::collector::{Count, TopDocs}; - use crate::schema::{Schema, TEXT}; + use crate::indexer::NoMergePolicy; + use crate::query::QueryParser; + use crate::schema::{Schema, STORED, TEXT}; use crate::{assert_nearly_equals, Index, Term}; + #[test] + pub fn test_fuzzy_json_path() -> crate::Result<()> { + // # Defining the schema + let mut schema_builder = Schema::builder(); + let attributes = schema_builder.add_json_field("attributes", TEXT | STORED); + let schema = schema_builder.build(); + + // # Indexing documents + let index = Index::create_in_ram(schema.clone()); + + let mut index_writer = index.writer_for_tests()?; + index_writer.set_merge_policy(Box::new(NoMergePolicy)); + let doc = schema.parse_document( + r#"{ + "attributes": { + "a": "japan" + } + }"#, + )?; + index_writer.add_document(doc)?; + let doc = schema.parse_document( + r#"{ + "attributes": { + "aa": "japan" + } + }"#, + )?; + index_writer.add_document(doc)?; + index_writer.commit()?; + + let reader = index.reader()?; + let searcher = reader.searcher(); + + // # Fuzzy search + let 
query_parser = QueryParser::for_index(&index, vec![attributes]); + + let get_json_path_term = |query: &str| -> crate::Result { + let query = query_parser.parse_query(query)?; + let mut terms = Vec::new(); + query.query_terms(&mut |term, _| { + terms.push(term.clone()); + }); + + Ok(terms[0].clone()) + }; + + // shall not match the first document due to json path mismatch + { + let term = get_json_path_term("attributes.aa:japan")?; + let fuzzy_query = FuzzyTermQuery::new(term, 2, true); + let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?; + assert_eq!(top_docs.len(), 1, "Expected only 1 document"); + assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document"); + } + + // shall match the first document because Levenshtein distance is 1 (substitute 'o' with + // 'a') + { + let term = get_json_path_term("attributes.a:japon")?; + + let fuzzy_query = FuzzyTermQuery::new(term, 1, true); + let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?; + assert_eq!(top_docs.len(), 1, "Expected only 1 document"); + assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document"); + } + + // shall not match because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n') + { + let term = get_json_path_term("attributes.a:jap")?; + + let fuzzy_query = FuzzyTermQuery::new(term, 1, true); + let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?; + assert_eq!(top_docs.len(), 0, "Expected no document"); + } + + Ok(()) + } + #[test] pub fn test_fuzzy_term() -> crate::Result<()> { let mut schema_builder = Schema::builder(); diff --git a/src/query/phrase_prefix_query/mod.rs b/src/query/phrase_prefix_query/mod.rs index 260e43311..891ac001f 100644 --- a/src/query/phrase_prefix_query/mod.rs +++ b/src/query/phrase_prefix_query/mod.rs @@ -6,7 +6,7 @@ pub use phrase_prefix_query::PhrasePrefixQuery; pub use phrase_prefix_scorer::PhrasePrefixScorer; pub use phrase_prefix_weight::PhrasePrefixWeight; -fn prefix_end(prefix_start: 
&[u8]) -> Option<Vec<u8>> { +pub(crate) fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> { let mut res = prefix_start.to_owned(); while !res.is_empty() { let end = res.len() - 1; diff --git a/src/schema/term.rs index 9f9b9c6e8..55829b444 100644 --- a/src/schema/term.rs +++ b/src/schema/term.rs @@ -397,20 +397,29 @@ where B: AsRef<[u8]> Some(Ipv6Addr::from_u128(ip_u128)) } - /// Returns the json path (without non-human friendly separators), + /// Returns the json path type. + /// + /// Returns `None` if the value is not JSON. + pub fn json_path_type(&self) -> Option<Type> { + let json_value_bytes = self.as_json_value_bytes()?; + + Some(json_value_bytes.typ()) + } + + /// Returns the json path bytes (including the JSON_END_OF_PATH byte), /// and the encoded ValueBytes after the json path. /// /// Returns `None` if the value is not JSON. - pub(crate) fn as_json(&self) -> Option<(&str, ValueBytes<&[u8]>)> { + pub(crate) fn as_json(&self) -> Option<(&[u8], ValueBytes<&[u8]>)> { if self.typ() != Type::Json { return None; } let bytes = self.value_bytes(); let pos = bytes.iter().cloned().position(|b| b == JSON_END_OF_PATH)?; - let (json_path_bytes, term) = bytes.split_at(pos); - let json_path = str::from_utf8(json_path_bytes).ok()?; - Some((json_path, ValueBytes::wrap(&term[1..]))) + // split at pos + 1, so that json_path_bytes includes the JSON_END_OF_PATH byte. + let (json_path_bytes, term) = bytes.split_at(pos + 1); + Some((json_path_bytes, ValueBytes::wrap(&term))) } /// Returns the encoded ValueBytes after the json path. @@ -469,7 +478,10 @@ where B: AsRef<[u8]> write_opt(f, self.as_bytes())?; } Type::Json => { - if let Some((path, sub_value_bytes)) = self.as_json() { + if let Some((path_bytes, sub_value_bytes)) = self.as_json() { + // Remove the JSON_END_OF_PATH byte & convert to utf8. 
+ let path = str::from_utf8(&path_bytes[..path_bytes.len() - 1]) + .map_err(|_| std::fmt::Error)?; let path_pretty = path.replace(JSON_PATH_SEGMENT_SEP_STR, "."); write!(f, "path={path_pretty}, ")?; sub_value_bytes.debug_value_bytes(f)?;