diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs
index e9375e96c..36be9a2c7 100644
--- a/src/query/automaton_weight.rs
+++ b/src/query/automaton_weight.rs
@@ -4,6 +4,7 @@ use std::sync::Arc;
use common::BitSet;
use tantivy_fst::Automaton;
+use super::phrase_prefix_query::prefix_end;
use crate::core::SegmentReader;
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
@@ -14,6 +15,10 @@ use crate::{DocId, Score, TantivyError};
pub struct AutomatonWeight<A> {
field: Field,
automaton: Arc<A>,
+ // For JSON fields, the term dictionary includes terms from all paths.
+ // We apply additional filtering based on the given JSON path, when searching within the term
+ // dictionary. This prevents terms from unrelated paths from matching the search criteria.
+ json_path_bytes: Option<Box<[u8]>>,
}
impl<A> AutomatonWeight<A>
@@ -26,6 +31,20 @@ where
AutomatonWeight {
field,
automaton: automaton.into(),
+ json_path_bytes: None,
+ }
+ }
+
+ /// Create a new AutomatonWeight for a json path
+ pub fn new_for_json_path<IntoArcA: Into<Arc<A>>>(
+ field: Field,
+ automaton: IntoArcA,
+ json_path_bytes: &[u8],
+ ) -> AutomatonWeight<A> {
+ AutomatonWeight {
+ field,
+ automaton: automaton.into(),
+ json_path_bytes: Some(json_path_bytes.to_vec().into_boxed_slice()),
}
}
@@ -34,7 +53,15 @@ where
term_dict: &'a TermDictionary,
) -> io::Result<TermStreamer<'a, &'a A>> {
let automaton: &A = &self.automaton;
- let term_stream_builder = term_dict.search(automaton);
+ let mut term_stream_builder = term_dict.search(automaton);
+
+ if let Some(json_path_bytes) = &self.json_path_bytes {
+ term_stream_builder = term_stream_builder.ge(json_path_bytes);
+ if let Some(end) = prefix_end(json_path_bytes) {
+ term_stream_builder = term_stream_builder.lt(&end);
+ }
+ }
+
term_stream_builder.into_stream()
}
}
diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs
index 1c6b1f479..934dfce33 100644
--- a/src/query/fuzzy_query.rs
+++ b/src/query/fuzzy_query.rs
@@ -3,7 +3,7 @@ use once_cell::sync::OnceCell;
use tantivy_fst::Automaton;
use crate::query::{AutomatonWeight, EnableScoring, Query, Weight};
-use crate::schema::Term;
+use crate::schema::{Term, Type};
use crate::TantivyError::InvalidArgument;
pub(crate) struct DfaWrapper(pub DFA);
@@ -132,18 +132,46 @@ impl FuzzyTermQuery {
});
let term_value = self.term.value();
- let term_text = term_value.as_str().ok_or_else(|| {
- InvalidArgument("The fuzzy term query requires a string term.".to_string())
- })?;
+
+ let term_text = if term_value.typ() == Type::Json {
+ if let Some(json_path_type) = term_value.json_path_type() {
+ if json_path_type != Type::Str {
+ return Err(InvalidArgument(format!(
+ "The fuzzy term query requires a string path type for a json term. Found \
+ {:?}",
+ json_path_type
+ )));
+ }
+ }
+
+ std::str::from_utf8(self.term.serialized_value_bytes()).map_err(|_| {
+ InvalidArgument(
+ "Failed to convert json term value bytes to utf8 string.".to_string(),
+ )
+ })?
+ } else {
+ term_value.as_str().ok_or_else(|| {
+ InvalidArgument("The fuzzy term query requires a string term.".to_string())
+ })?
+ };
let automaton = if self.prefix {
automaton_builder.build_prefix_dfa(term_text)
} else {
automaton_builder.build_dfa(term_text)
};
- Ok(AutomatonWeight::new(
- self.term.field(),
- DfaWrapper(automaton),
- ))
+
+ if let Some((json_path_bytes, _)) = term_value.as_json() {
+ Ok(AutomatonWeight::new_for_json_path(
+ self.term.field(),
+ DfaWrapper(automaton),
+ json_path_bytes,
+ ))
+ } else {
+ Ok(AutomatonWeight::new(
+ self.term.field(),
+ DfaWrapper(automaton),
+ ))
+ }
}
}
@@ -157,9 +185,89 @@ impl Query for FuzzyTermQuery {
mod test {
use super::FuzzyTermQuery;
use crate::collector::{Count, TopDocs};
- use crate::schema::{Schema, TEXT};
+ use crate::indexer::NoMergePolicy;
+ use crate::query::QueryParser;
+ use crate::schema::{Schema, STORED, TEXT};
use crate::{assert_nearly_equals, Index, Term};
+ #[test]
+ pub fn test_fuzzy_json_path() -> crate::Result<()> {
+ // # Defining the schema
+ let mut schema_builder = Schema::builder();
+ let attributes = schema_builder.add_json_field("attributes", TEXT | STORED);
+ let schema = schema_builder.build();
+
+ // # Indexing documents
+ let index = Index::create_in_ram(schema.clone());
+
+ let mut index_writer = index.writer_for_tests()?;
+ index_writer.set_merge_policy(Box::new(NoMergePolicy));
+ let doc = schema.parse_document(
+ r#"{
+ "attributes": {
+ "a": "japan"
+ }
+ }"#,
+ )?;
+ index_writer.add_document(doc)?;
+ let doc = schema.parse_document(
+ r#"{
+ "attributes": {
+ "aa": "japan"
+ }
+ }"#,
+ )?;
+ index_writer.add_document(doc)?;
+ index_writer.commit()?;
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ // # Fuzzy search
+ let query_parser = QueryParser::for_index(&index, vec![attributes]);
+
+ let get_json_path_term = |query: &str| -> crate::Result<Term> {
+ let query = query_parser.parse_query(query)?;
+ let mut terms = Vec::new();
+ query.query_terms(&mut |term, _| {
+ terms.push(term.clone());
+ });
+
+ Ok(terms[0].clone())
+ };
+
+ // shall not match the first document due to json path mismatch
+ {
+ let term = get_json_path_term("attributes.aa:japan")?;
+ let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
+ let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+ assert_eq!(top_docs.len(), 1, "Expected only 1 document");
+ assert_eq!(top_docs[0].1.doc_id, 1, "Expected the second document");
+ }
+
+ // shall match the first document because Levenshtein distance is 1 (substitute 'o' with
+ // 'a')
+ {
+ let term = get_json_path_term("attributes.a:japon")?;
+
+ let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
+ let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+ assert_eq!(top_docs.len(), 1, "Expected only 1 document");
+ assert_eq!(top_docs[0].1.doc_id, 0, "Expected the first document");
+ }
+
+ // shall not match because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')
+ {
+ let term = get_json_path_term("attributes.a:jap")?;
+
+ let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
+ let top_docs = searcher.search(&fuzzy_query, &TopDocs::with_limit(2))?;
+ assert_eq!(top_docs.len(), 0, "Expected no document");
+ }
+
+ Ok(())
+ }
+
#[test]
pub fn test_fuzzy_term() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
diff --git a/src/query/phrase_prefix_query/mod.rs b/src/query/phrase_prefix_query/mod.rs
index 260e43311..891ac001f 100644
--- a/src/query/phrase_prefix_query/mod.rs
+++ b/src/query/phrase_prefix_query/mod.rs
@@ -6,7 +6,7 @@ pub use phrase_prefix_query::PhrasePrefixQuery;
pub use phrase_prefix_scorer::PhrasePrefixScorer;
pub use phrase_prefix_weight::PhrasePrefixWeight;
-fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> {
+pub(crate) fn prefix_end(prefix_start: &[u8]) -> Option<Vec<u8>> {
let mut res = prefix_start.to_owned();
while !res.is_empty() {
let end = res.len() - 1;
diff --git a/src/schema/term.rs b/src/schema/term.rs
index 9f9b9c6e8..55829b444 100644
--- a/src/schema/term.rs
+++ b/src/schema/term.rs
@@ -397,20 +397,29 @@ where B: AsRef<[u8]>
Some(Ipv6Addr::from_u128(ip_u128))
}
- /// Returns the json path (without non-human friendly separators),
+ /// Returns the json path type.
+ ///
+ /// Returns `None` if the value is not JSON.
+ pub fn json_path_type(&self) -> Option<Type> {
+ let json_value_bytes = self.as_json_value_bytes()?;
+
+ Some(json_value_bytes.typ())
+ }
+
+ /// Returns the json path bytes (including the JSON_END_OF_PATH byte),
/// and the encoded ValueBytes after the json path.
///
/// Returns `None` if the value is not JSON.
- pub(crate) fn as_json(&self) -> Option<(&str, ValueBytes<&[u8]>)> {
+ pub(crate) fn as_json(&self) -> Option<(&[u8], ValueBytes<&[u8]>)> {
if self.typ() != Type::Json {
return None;
}
let bytes = self.value_bytes();
let pos = bytes.iter().cloned().position(|b| b == JSON_END_OF_PATH)?;
- let (json_path_bytes, term) = bytes.split_at(pos);
- let json_path = str::from_utf8(json_path_bytes).ok()?;
- Some((json_path, ValueBytes::wrap(&term[1..])))
+ // split at pos + 1, so that json_path_bytes includes the JSON_END_OF_PATH byte.
+ let (json_path_bytes, term) = bytes.split_at(pos + 1);
+ Some((json_path_bytes, ValueBytes::wrap(&term)))
}
/// Returns the encoded ValueBytes after the json path.
@@ -469,7 +478,10 @@ where B: AsRef<[u8]>
write_opt(f, self.as_bytes())?;
}
Type::Json => {
- if let Some((path, sub_value_bytes)) = self.as_json() {
+ if let Some((path_bytes, sub_value_bytes)) = self.as_json() {
+ // Remove the JSON_END_OF_PATH byte & convert to utf8.
+ let path = str::from_utf8(&path_bytes[..path_bytes.len() - 1])
+ .map_err(|_| std::fmt::Error)?;
let path_pretty = path.replace(JSON_PATH_SEGMENT_SEP_STR, ".");
write!(f, "path={path_pretty}, ")?;
sub_value_bytes.debug_value_bytes(f)?;