From 45da5829bc7ea6dad5a0dcaaa16041802a05ce99 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 7 Sep 2019 15:45:05 +0900 Subject: [PATCH] added unit test --- Cargo.toml | 1 + Makefile | 2 +- incremental-search/src/bitset.rs | 395 +++++++++++++++++++++++++++++++ incremental-search/src/lib.rs | 227 +++++++++++++++--- src/query/fuzzy_query.rs | 210 ++++++++++------ src/query/mod.rs | 2 +- 6 files changed, 731 insertions(+), 106 deletions(-) create mode 100644 incremental-search/src/bitset.rs diff --git a/Cargo.toml b/Cargo.toml index c5d3b6694..e9f117661 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ tantivy-fst = "0.1" memmap = {version = "0.7", optional=true} lz4 = {version="1.20", optional=true} snap = {version="0.2"} +derive_builder = "0.7" atomicwrites = {version="0.2.2", optional=true} tempfile = "3.0" log = "0.4" diff --git a/Makefile b/Makefile index 05f0f4447..4bd8dc413 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,3 @@ test: echo "Run test only... No examples." - cargo test --tests --lib + cargo test --all --tests --lib diff --git a/incremental-search/src/bitset.rs b/incremental-search/src/bitset.rs new file mode 100644 index 000000000..527aa8d4a --- /dev/null +++ b/incremental-search/src/bitset.rs @@ -0,0 +1,395 @@ +use std::fmt; +use std::u64; + +#[derive(Clone, Copy, Eq, PartialEq)] +pub(crate) struct TinySet(u64); + +impl fmt::Debug for TinySet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.into_iter().collect::>().fmt(f) + } +} + +pub struct TinySetIterator(TinySet); +impl Iterator for TinySetIterator { + type Item = u32; + + fn next(&mut self) -> Option { + self.0.pop_lowest() + } +} + +impl IntoIterator for TinySet { + type Item = u32; + type IntoIter = TinySetIterator; + fn into_iter(self) -> Self::IntoIter { + TinySetIterator(self) + } +} + +impl TinySet { + /// Returns an empty `TinySet`. + pub fn empty() -> TinySet { + TinySet(0u64) + } + + /// Returns the complement of the set in `[0, 64[`. + fn complement(self) -> TinySet { + TinySet(!self.0) + } + + /// Returns true iff the `TinySet` contains the element `el`. + pub fn contains(self, el: u32) -> bool { + !self.intersect(TinySet::singleton(el)).is_empty() + } + + /// Returns the intersection of `self` and `other` + pub fn intersect(self, other: TinySet) -> TinySet { + TinySet(self.0 & other.0) + } + + /// Creates a new `TinySet` containing only one element + /// within `[0; 64[` + #[inline(always)] + pub fn singleton(el: u32) -> TinySet { + TinySet(1u64 << u64::from(el)) + } + + /// Insert a new element within [0..64[ + #[inline(always)] + pub fn insert(self, el: u32) -> TinySet { + self.union(TinySet::singleton(el)) + } + + /// Insert a new element within [0..64[ + #[inline(always)] + pub fn insert_mut(&mut self, el: u32) -> bool { + let old = *self; + *self = old.insert(el); + old != *self + } + + /// Returns the union of two tinysets + #[inline(always)] + pub fn union(self, other: TinySet) -> TinySet { + TinySet(self.0 | other.0) + } + + /// Returns true iff the `TinySet` is empty. + #[inline(always)] + pub fn is_empty(self) -> bool { + self.0 == 0u64 + } + + /// Returns the lowest element in the `TinySet` + /// and removes it. + #[inline(always)] + pub fn pop_lowest(&mut self) -> Option { + if self.is_empty() { + None + } else { + let lowest = self.0.trailing_zeros() as u32; + self.0 ^= TinySet::singleton(lowest).0; + Some(lowest) + } + } + + /// Returns a `TinySet` than contains all values up + /// to limit excluded. + /// + /// The limit is assumed to be strictly lower than 64. + pub fn range_lower(upper_bound: u32) -> TinySet { + TinySet((1u64 << u64::from(upper_bound % 64u32)) - 1u64) + } + + /// Returns a `TinySet` that contains all values greater + /// or equal to the given limit, included. (and up to 63) + /// + /// The limit is assumed to be strictly lower than 64. + pub fn range_greater_or_equal(from_included: u32) -> TinySet { + TinySet::range_lower(from_included).complement() + } + + pub fn clear(&mut self) { + self.0 = 0u64; + } + + pub fn len(self) -> u32 { + self.0.count_ones() + } +} + +#[derive(Clone)] +pub struct BitSet { + tinysets: Box<[TinySet]>, + len: usize, //< Technically it should be u32, but we + // count multiple inserts. + // `usize` guards us from overflow. + max_value: u32, +} + +fn num_buckets(max_val: u32) -> u32 { + (max_val + 63u32) / 64u32 +} + +impl BitSet { + /// Create a new `BitSet` that may contain elements + /// within `[0, max_val[`. + pub fn with_max_value(max_value: u32) -> BitSet { + let num_buckets = num_buckets(max_value); + let tinybisets = vec![TinySet::empty(); num_buckets as usize].into_boxed_slice(); + BitSet { + tinysets: tinybisets, + len: 0, + max_value, + } + } + + /// Removes all elements from the `BitSet`. + pub fn clear(&mut self) { + for tinyset in self.tinysets.iter_mut() { + *tinyset = TinySet::empty(); + } + } + + /// Returns the number of elements in the `BitSet`. + pub fn len(&self) -> usize { + self.len + } + + /// Inserts an element in the `BitSet` + pub fn insert(&mut self, el: u32) { + // we do not check saturated els. + let higher = el / 64u32; + let lower = el % 64u32; + self.len += if self.tinysets[higher as usize].insert_mut(lower) { + 1 + } else { + 0 + }; + } + + /// Returns true iff the elements is in the `BitSet`. + pub fn contains(&self, el: u32) -> bool { + self.tinyset(el / 64u32).contains(el % 64) + } + + /// Returns the first non-empty `TinySet` associated to a bucket lower + /// or greater than bucket. + /// + /// Reminder: the tiny set with the bucket `bucket`, represents the + /// elements from `bucket * 64` to `(bucket+1) * 64`. + pub(crate) fn first_non_empty_bucket(&self, bucket: u32) -> Option { + self.tinysets[bucket as usize..] + .iter() + .cloned() + .position(|tinyset| !tinyset.is_empty()) + .map(|delta_bucket| bucket + delta_bucket as u32) + } + + pub fn max_value(&self) -> u32 { + self.max_value + } + + /// Returns the tiny bitset representing the + /// the set restricted to the number range from + /// `bucket * 64` to `(bucket + 1) * 64`. + pub(crate) fn tinyset(&self, bucket: u32) -> TinySet { + self.tinysets[bucket as usize] + } +} + +#[cfg(test)] +mod tests { + + use super::BitSet; + use super::TinySet; + use crate::docset::DocSet; + use crate::query::BitSetDocSet; + use crate::tests; + use crate::tests::generate_nonunique_unsorted; + use std::collections::BTreeSet; + use std::collections::HashSet; + + #[test] + fn test_tiny_set() { + assert!(TinySet::empty().is_empty()); + { + let mut u = TinySet::empty().insert(1u32); + assert_eq!(u.pop_lowest(), Some(1u32)); + assert!(u.pop_lowest().is_none()) + } + { + let mut u = TinySet::empty().insert(1u32).insert(1u32); + assert_eq!(u.pop_lowest(), Some(1u32)); + assert!(u.pop_lowest().is_none()) + } + { + let mut u = TinySet::empty().insert(2u32); + assert_eq!(u.pop_lowest(), Some(2u32)); + u.insert_mut(1u32); + assert_eq!(u.pop_lowest(), Some(1u32)); + assert!(u.pop_lowest().is_none()); + } + { + let mut u = TinySet::empty().insert(63u32); + assert_eq!(u.pop_lowest(), Some(63u32)); + assert!(u.pop_lowest().is_none()); + } + } + + #[test] + fn test_bitset() { + let test_against_hashset = |els: &[u32], max_value: u32| { + let mut hashset: HashSet = HashSet::new(); + let mut bitset = BitSet::with_max_value(max_value); + for &el in els { + assert!(el < max_value); + hashset.insert(el); + bitset.insert(el); + } + for el in 0..max_value { + assert_eq!(hashset.contains(&el), bitset.contains(el)); + } + assert_eq!(bitset.max_value(), max_value); + }; + + test_against_hashset(&[], 0); + test_against_hashset(&[], 1); + test_against_hashset(&[0u32], 1); + test_against_hashset(&[0u32], 100); + test_against_hashset(&[1u32, 2u32], 4); + test_against_hashset(&[99u32], 100); + test_against_hashset(&[63u32], 64); + test_against_hashset(&[62u32, 63u32], 64); + } + + #[test] + fn test_bitset_large() { + let arr = generate_nonunique_unsorted(100_000, 5_000); + let mut btreeset: BTreeSet = BTreeSet::new(); + let mut bitset = BitSet::with_max_value(100_000); + for el in arr { + btreeset.insert(el); + bitset.insert(el); + } + for i in 0..100_000 { + assert_eq!(btreeset.contains(&i), bitset.contains(i)); + } + assert_eq!(btreeset.len(), bitset.len()); + let mut bitset_docset = BitSetDocSet::from(bitset); + for el in btreeset.into_iter() { + bitset_docset.advance(); + assert_eq!(bitset_docset.doc(), el); + } + assert!(!bitset_docset.advance()); + } + + #[test] + fn test_bitset_num_buckets() { + use super::num_buckets; + assert_eq!(num_buckets(0u32), 0); + assert_eq!(num_buckets(1u32), 1); + assert_eq!(num_buckets(64u32), 1); + assert_eq!(num_buckets(65u32), 2); + assert_eq!(num_buckets(128u32), 2); + assert_eq!(num_buckets(129u32), 3); + } + + #[test] + fn test_tinyset_range() { + assert_eq!( + TinySet::range_lower(3).into_iter().collect::>(), + [0, 1, 2] + ); + assert!(TinySet::range_lower(0).is_empty()); + assert_eq!( + TinySet::range_lower(63).into_iter().collect::>(), + (0u32..63u32).collect::>() + ); + assert_eq!( + TinySet::range_lower(1).into_iter().collect::>(), + [0] + ); + assert_eq!( + TinySet::range_lower(2).into_iter().collect::>(), + [0, 1] + ); + assert_eq!( + TinySet::range_greater_or_equal(3) + .into_iter() + .collect::>(), + (3u32..64u32).collect::>() + ); + } + + #[test] + fn test_bitset_len() { + let mut bitset = BitSet::with_max_value(1_000); + assert_eq!(bitset.len(), 0); + bitset.insert(3u32); + assert_eq!(bitset.len(), 1); + bitset.insert(103u32); + assert_eq!(bitset.len(), 2); + bitset.insert(3u32); + assert_eq!(bitset.len(), 2); + bitset.insert(103u32); + assert_eq!(bitset.len(), 2); + bitset.insert(104u32); + assert_eq!(bitset.len(), 3); + } + + #[test] + fn test_bitset_clear() { + let mut bitset = BitSet::with_max_value(1_000); + let els = tests::sample(1_000, 0.01f64); + for &el in &els { + bitset.insert(el); + } + assert!(els.iter().all(|el| bitset.contains(*el))); + bitset.clear(); + for el in 0u32..1000u32 { + assert!(!bitset.contains(el)); + } + } +} + +#[cfg(all(test, feature = "unstable"))] +mod bench { + + use super::BitSet; + use super::TinySet; + use test; + + #[bench] + fn bench_tinyset_pop(b: &mut test::Bencher) { + b.iter(|| { + let mut tinyset = TinySet::singleton(test::black_box(31u32)); + tinyset.pop_lowest(); + tinyset.pop_lowest(); + tinyset.pop_lowest(); + tinyset.pop_lowest(); + tinyset.pop_lowest(); + tinyset.pop_lowest(); + }); + } + + #[bench] + fn bench_tinyset_sum(b: &mut test::Bencher) { + let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32); + b.iter(|| { + assert_eq!(test::black_box(tiny_set).into_iter().sum::(), 45u32); + }); + } + + #[bench] + fn bench_tinyarr_sum(b: &mut test::Bencher) { + let v = [10u32, 14u32, 21u32]; + b.iter(|| test::black_box(v).iter().cloned().sum::()); + } + + #[bench] + fn bench_bitset_initialize(b: &mut test::Bencher) { + b.iter(|| BitSet::with_max_value(1_000_000)); + } +} diff --git a/incremental-search/src/lib.rs b/incremental-search/src/lib.rs index c043e92a1..7f9e5a2e4 100644 --- a/incremental-search/src/lib.rs +++ b/incremental-search/src/lib.rs @@ -1,25 +1,108 @@ +use tantivy::query::{BooleanQuery, FuzzyTermQuery, EmptyQuery}; use derive_builder::Builder; use std::str::FromStr; +use tantivy::query::{FuzzyConfiguration, FuzzyConfigurationBuilder, Query, Occur}; use tantivy::schema::Field; -use tantivy::{Searcher, TantivyError}; +use tantivy::{Searcher, TantivyError, DocAddress, Term, Document}; +use tantivy::collector::TopDocs; +use std::ops::Deref; -#[derive(Builder, Default)] -pub struct IncrementalSearch { - nhits: usize, - #[builder(default)] - search_fields: Vec, - #[builder(default)] - return_fields: Vec, -} #[derive(Debug)] pub struct IncrementalSearchQuery { pub terms: Vec, - pub prefix: Option, + pub last_is_prefix: bool, } -// TODO have a smarter, more robust query parser. -// This is a first stab +impl IncrementalSearchQuery { + pub fn fuzzy_configurations(&self) -> Vec { + if self.terms.is_empty() { + return Vec::default(); + } + let single_term_confs: Vec = (0u8..3u8) + .map(|d: u8| { + let mut builder = FuzzyConfigurationBuilder::default(); + builder.distance(d).transposition_cost_one(true); + builder + }) + .collect(); + let mut configurations: Vec> = single_term_confs + .iter() + .map(|conf| vec![conf.clone()]) + .collect(); + let mut new_configurations = Vec::new(); + for _ in 1..self.terms.len() { + new_configurations.clear(); + for single_term_conf in &single_term_confs { + for configuration in &configurations { + let mut new_configuration: Vec = configuration.clone(); + new_configuration.push(single_term_conf.clone()); + new_configurations.push(new_configuration); + } + } + std::mem::swap(&mut configurations, &mut new_configurations); + } + if self.last_is_prefix { + for configuration in &mut configurations { + if let Some(last_conf) = configuration.last_mut() { + last_conf.prefix(true); + } + } + } + let mut fuzzy_configurations: Vec = configurations + .into_iter() + .map(FuzzyConfigurations::from) + .collect(); + fuzzy_configurations.sort_by(|left, right| left.cost.partial_cmp(&right.cost).unwrap()); + fuzzy_configurations + } + + fn search_query(&self, fields: &[Field], configurations: FuzzyConfigurations) -> Box { + if self.terms.is_empty() { + Box::new(EmptyQuery) + } else if self.terms.len() == 1 { + build_query_for_fields(fields, &self.terms[0], &configurations.configurations[0]) + } else { + Box::new(BooleanQuery::from(self.terms.iter() + .zip(configurations.configurations.iter()) + .map(|(term, configuration)| + (Occur::Must, build_query_for_fields(fields, &term, &configuration)) + ) + .collect::>())) + } + } +} + +#[derive(Debug)] +pub struct FuzzyConfigurations { + configurations: Vec, + cost: f64, +} + + +fn compute_cost(fuzzy_confs: &[FuzzyConfiguration]) -> f64 { + fuzzy_confs + .iter() + .map(|fuzzy_conf| { + let weight = if fuzzy_conf.prefix { 30f64 } else { 5f64 }; + weight * f64::from(fuzzy_conf.distance) + }) + .sum() +} + +impl From> for FuzzyConfigurations { + fn from(fuzzy_conf_builder: Vec) -> FuzzyConfigurations { + let configurations = fuzzy_conf_builder + .into_iter() + .map(|conf| conf.build().unwrap()) + .collect::>(); + let cost = compute_cost(&configurations); + FuzzyConfigurations { + configurations, + cost, + } + } +} #[derive(Debug)] pub struct ParseIncrementalQueryError; @@ -34,42 +117,96 @@ impl FromStr for IncrementalSearchQuery { type Err = ParseIncrementalQueryError; fn from_str(query_str: &str) -> Result { - let mut terms: Vec = query_str + let terms: Vec = query_str .split_whitespace() .map(ToString::to_string) .collect(); - if query_str.ends_with(|c: char| c.is_whitespace()) { - Ok(IncrementalSearchQuery { - terms, - prefix: None, - }) - } else { - let prefix = terms.pop(); - Ok(IncrementalSearchQuery { terms, prefix }) - } + Ok(IncrementalSearchQuery { + terms, + last_is_prefix: query_str + .chars() + .last() + .map(|c| !c.is_whitespace()) + .unwrap_or(false), + }) } } -pub struct IncrementalSearchResult; +fn build_query_for_fields(fields: &[Field], term_text: &str, conf: &FuzzyConfiguration) -> Box { + assert!(fields.len() > 0); + if fields.len() > 1 { + let term_queries: Vec<(Occur, Box)> = fields + .iter() + .map(|&field| { + let term = Term::from_field_text(field, term_text); + let query = FuzzyTermQuery::new_from_configuration(term, conf.clone()); + let boxed_query: Box = Box::new(query); + (Occur::Must, boxed_query) + }) + .collect(); + Box::new(BooleanQuery::from(term_queries)) + } else { + let term = Term::from_field_text(fields[0], term_text); + Box::new( FuzzyTermQuery::new_from_configuration(term, conf.clone())) + } + +} + +pub struct IncrementalSearchResult { + pub docs: Vec +} + +#[derive(Builder, Default)] +pub struct IncrementalSearch { + nhits: usize, + #[builder(default)] + search_fields: Vec, + #[builder(default)] + return_fields: Vec, +} impl IncrementalSearch { - pub fn search( + + pub fn search>( &self, query: &str, - searcher: &Searcher, + searcher: &S, ) -> tantivy::Result { - let query: IncrementalSearchQuery = + let searcher = searcher.deref(); + let inc_search_query: IncrementalSearchQuery = FromStr::from_str(query).map_err(Into::::into)?; - let result = IncrementalSearchResult; - Ok(result) + let mut results: Vec = Vec::default(); + let mut remaining = self.nhits; + for fuzzy_conf in inc_search_query.fuzzy_configurations() { + if remaining == 0 { + break; + } + let query = inc_search_query.search_query(&self.search_fields[..], fuzzy_conf); + let new_docs = searcher.search(query.as_ref(), &TopDocs::with_limit(remaining))?; + // TODO(pmasurel) remove already added docs. + results.extend(new_docs.into_iter() + .map(|(_, doc_address)| doc_address)); + remaining = self.nhits - results.len(); + if remaining == 0 { + break; + } + } + let docs: Vec = results.into_iter() + .map(|doc_address: DocAddress| searcher.doc(doc_address)) + .collect::>()?; + Ok(IncrementalSearchResult { + docs + }) } } #[cfg(test)] mod tests { - use super::{IncrementalSearch, IncrementalSearchBuilder}; - use crate::IncrementalSearchQuery; + use tantivy::doc; + use crate::{IncrementalSearch, IncrementalSearchBuilder, IncrementalSearchQuery}; use std::str::FromStr; + use tantivy::schema::{SchemaBuilder, TEXT, STORED}; + use tantivy::Index; #[test] fn test_incremental_search() { @@ -83,27 +220,47 @@ mod tests { fn test_incremental_search_query_parse_empty() { let query = IncrementalSearchQuery::from_str("").unwrap(); assert_eq!(query.terms, Vec::::new()); - assert_eq!(query.prefix, None); + assert!(!query.last_is_prefix); } #[test] fn test_incremental_search_query_parse_trailing_whitespace() { let query = IncrementalSearchQuery::from_str("hello happy tax pa ").unwrap(); assert_eq!(query.terms, vec!["hello", "happy", "tax", "pa"]); - assert_eq!(query.prefix, None); + assert!(!query.last_is_prefix); } #[test] fn test_incremental_search_query_parse_unicode_whitespace() { let query = IncrementalSearchQuery::from_str("hello happy tax pa ").unwrap(); assert_eq!(query.terms, vec!["hello", "happy", "tax", "pa"]); - assert_eq!(query.prefix, None); + assert!(!query.last_is_prefix); } #[test] fn test_incremental_search_query_parse() { let query = IncrementalSearchQuery::from_str("hello happy tax pa").unwrap(); - assert_eq!(query.terms, vec!["hello", "happy", "tax"]); - assert_eq!(query.prefix, Some("pa".to_string())); + assert_eq!(query.terms, vec!["hello", "happy", "tax", "pa"]); + assert!(query.last_is_prefix); + } + + #[test] + fn test_blop() { + let mut schema_builder = SchemaBuilder::new(); + let body = schema_builder.add_text_field("body", TEXT | STORED); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_with_num_threads(1, 30_000_000).unwrap(); + index_writer.add_document(doc!(body=> "hello happy tax payer")); + index_writer.commit().unwrap(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let incremental_search: IncrementalSearch = IncrementalSearchBuilder::default() + .nhits(1) + .search_fields(vec![body]) + .build() + .unwrap(); + let top_docs = incremental_search.search("hello hapy t", &searcher).unwrap(); + assert_eq!(top_docs.docs.len(), 1); } } diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index 9f51b815f..b521cb312 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -4,10 +4,11 @@ use crate::schema::Term; use crate::termdict::WrappedDFA; use crate::Result; use crate::Searcher; -use levenshtein_automata::{Distance, LevenshteinAutomatonBuilder}; +use levenshtein_automata::{Distance, LevenshteinAutomatonBuilder, DFA}; use once_cell::sync::Lazy; use std::collections::HashMap; use std::ops::Range; +use derive_builder::Builder; /// A range of Levenshtein distances that we will build DFAs for our terms /// The computation is exponential, so best keep it to low single digits @@ -25,6 +26,38 @@ static LEV_BUILDER: Lazy> = Laz lev_builder_cache }); + +#[derive(Builder, Default, Clone, Debug)] +pub struct FuzzyConfiguration { + /// How many changes are we going to allow + pub distance: u8, + /// Should a transposition cost 1 or 2? + #[builder(default)] + pub transposition_cost_one: bool, + #[builder(default)] + pub prefix: bool, + /// If true, only the term with a levenshtein of exactly `distance` will match. + /// If false, terms at a distance `<=` to `distance` will match. + #[builder(default)] + pub exact_distance: bool, +} + +fn build_dfa(fuzzy_configuration: &FuzzyConfiguration, term_text: &str) -> Result { + let automaton_builder = LEV_BUILDER + .get(&(fuzzy_configuration.distance, fuzzy_configuration.transposition_cost_one)) + .ok_or_else(|| { + InvalidArgument(format!( + "Levenshtein distance of {} is not allowed. Choose a value in the {:?} range", + fuzzy_configuration.distance, VALID_LEVENSHTEIN_DISTANCE_RANGE + )) + })?; + if fuzzy_configuration.prefix { + Ok(automaton_builder.build_prefix_dfa(term_text)) + } else { + Ok(automaton_builder.build_dfa(term_text)) + } +} + /// A Fuzzy Query matches all of the documents /// containing a specific term that is within /// Levenshtein distance @@ -62,86 +95,57 @@ static LEV_BUILDER: Lazy> = Laz pub struct FuzzyTermQuery { /// What term are we searching term: Term, - /// How many changes are we going to allow - distance: u8, - /// Should a transposition cost 1 or 2? - transposition_cost_one: bool, - /// - prefix: bool, - /// If true, only the term with a levenshtein of exactly `distance` will match. - /// If false, terms at a distance `<=` to `distance` will match. - exact_distance: bool + configuration: FuzzyConfiguration } impl FuzzyTermQuery { + pub fn new_from_configuration(term: Term, configuration: FuzzyConfiguration) -> FuzzyTermQuery { + FuzzyTermQuery { + term, + configuration + } + } + /// Creates a new Fuzzy Query pub fn new(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery { FuzzyTermQuery { term, - distance, - transposition_cost_one, - prefix: false, - exact_distance: false - } - } - - /// Creates a new Fuzzy Query in which term matching are exactly matching the - /// given distance. - pub fn new_exact(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery { - FuzzyTermQuery { - term, - distance, - transposition_cost_one, - prefix: false, - exact_distance: true - } - } - - /// Creates a new Fuzzy Query that treats transpositions as cost one rather than two - pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery { - FuzzyTermQuery { - term, - distance, - transposition_cost_one, - prefix: true, - exact_distance: false + configuration: FuzzyConfiguration { + distance, + transposition_cost_one, + prefix: false, + exact_distance: false + } } } } impl Query for FuzzyTermQuery { fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result> { - // LEV_BUILDER is a HashMap, whose `get` method returns an Option - match LEV_BUILDER.get(&(self.distance, false)) { - // Unwrap the option and build the Ok(AutomatonWeight) - Some(automaton_builder) => { - let dfa = automaton_builder.build_dfa(self.term.text()); - let target_distance = self.distance; - if self.exact_distance { - let wrapped_dfa = WrappedDFA { - dfa, - condition: move |distance: Distance| { - distance == Distance::Exact(target_distance) - } - }; - Ok(Box::new(AutomatonWeight::new(self.term.field(), wrapped_dfa))) - } else { - let wrapped_dfa = WrappedDFA { - dfa, - condition: move |distance: Distance| { - match distance { - Distance::Exact(_) => true, - Distance::AtLeast(_) => false, - } - } - }; - Ok(Box::new(AutomatonWeight::new(self.term.field(), wrapped_dfa))) - } - } - None => Err(InvalidArgument(format!( - "Levenshtein distance of {} is not allowed. Choose a value in the {:?} range", - self.distance, VALID_LEVENSHTEIN_DISTANCE_RANGE - ))), + let dfa = build_dfa(&self.configuration, self.term.text())?; + // TODO optimize for distance = 0 and possibly prefix + if self.configuration.exact_distance { + let target_distance = self.configuration.distance; + let wrapped_dfa = WrappedDFA { + dfa, + condition: move |distance: Distance| distance == Distance::Exact(target_distance), + }; + Ok(Box::new(AutomatonWeight::new( + self.term.field(), + wrapped_dfa, + ))) + } else { + let wrapped_dfa = WrappedDFA { + dfa, + condition: move |distance: Distance| match distance { + Distance::Exact(_) => true, + Distance::AtLeast(_) => false, + }, + }; + Ok(Box::new(AutomatonWeight::new( + self.term.field(), + wrapped_dfa, + ))) } } } @@ -155,6 +159,7 @@ mod test { use crate::tests::assert_nearly_equals; use crate::Index; use crate::Term; + use super::FuzzyConfigurationBuilder; #[test] pub fn test_fuzzy_term() { @@ -176,7 +181,6 @@ mod test { let searcher = reader.searcher(); { let term = Term::from_field_text(country_field, "japon"); - let fuzzy_query = FuzzyTermQuery::new(term, 1, true); let top_docs = searcher .search(&fuzzy_query, &TopDocs::with_limit(2)) @@ -185,5 +189,73 @@ mod test { let (score, _) = top_docs[0]; assert_nearly_equals(1f32, score); } + { + let term = Term::from_field_text(country_field, "japon"); + let fuzzy_conf = FuzzyConfigurationBuilder::default() + .distance(2) + .exact_distance(true) + .build() + .unwrap(); + let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf); + let top_docs = searcher + .search(&fuzzy_query, &TopDocs::with_limit(2)) + .unwrap(); + assert!(top_docs.is_empty()); + } + { + let term = Term::from_field_text(country_field, "japon"); + let fuzzy_conf = FuzzyConfigurationBuilder::default() + .distance(1) + .exact_distance(true) + .build() + .unwrap(); + let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf); + let top_docs = searcher + .search(&fuzzy_query, &TopDocs::with_limit(2)) + .unwrap(); + assert_eq!(top_docs.len(), 1); + } + { + let term = Term::from_field_text(country_field, "jpp"); + let fuzzy_conf = FuzzyConfigurationBuilder::default() + .distance(1) + .prefix(true) + .build() + .unwrap(); + let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf); + let top_docs = searcher + .search(&fuzzy_query, &TopDocs::with_limit(2)) + .unwrap(); + assert_eq!(top_docs.len(), 1); + } + { + let term = Term::from_field_text(country_field, "jpaan"); + let fuzzy_conf = FuzzyConfigurationBuilder::default() + .distance(1) + .exact_distance(true) + .transposition_cost_one(true) + .build() + .unwrap(); + let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf); + let top_docs = searcher + .search(&fuzzy_query, &TopDocs::with_limit(2)) + .unwrap(); + assert_eq!(top_docs.len(), 1); + } + { + let term = Term::from_field_text(country_field, "jpaan"); + let fuzzy_conf = FuzzyConfigurationBuilder::default() + .distance(2) + .exact_distance(true) + .transposition_cost_one(false) + .build() + .unwrap(); + let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf); + let top_docs = searcher + .search(&fuzzy_query, &TopDocs::with_limit(2)) + .unwrap(); + assert_eq!(top_docs.len(), 1); + } } + } diff --git a/src/query/mod.rs b/src/query/mod.rs index 82653cb81..50098b3a4 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -40,7 +40,7 @@ pub use self::boolean_query::BooleanQuery; pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight}; pub use self::exclude::Exclude; pub use self::explanation::Explanation; -pub use self::fuzzy_query::FuzzyTermQuery; +pub use self::fuzzy_query::{FuzzyTermQuery, FuzzyConfiguration, FuzzyConfigurationBuilder}; pub use self::intersection::intersect_scorers; pub use self::phrase_query::PhraseQuery; pub use self::query::Query;