Compare commits

...

4 Commits

Author SHA1 Message Date
Paul Masurel
45da5829bc added unit test 2019-09-10 08:31:27 +09:00
Paul Masurel
e2f7aab39f Added exact fuzzy query 2019-09-07 15:39:04 +09:00
Paul Masurel
1b9cbdb672 blop 2019-09-07 15:05:21 +09:00
Paul Masurel
a8f3cf9679 Added an incremental search crate 2019-09-07 13:23:58 +09:00
8 changed files with 865 additions and 66 deletions

View File

@@ -21,6 +21,7 @@ tantivy-fst = "0.1"
memmap = {version = "0.7", optional=true}
lz4 = {version="1.20", optional=true}
snap = {version="0.2"}
derive_builder = "0.7"
atomicwrites = {version="0.2.2", optional=true}
tempfile = "3.0"
log = "0.4"
@@ -30,7 +31,7 @@ serde_json = "1.0"
num_cpus = "1.2"
fs2={version="0.4", optional=true}
itertools = "0.8"
levenshtein_automata = {version="0.1", features=["fst_automaton"]}
levenshtein_automata = "0.1"
notify = {version="4", optional=true}
bit-set = "0.5"
uuid = { version = "0.7.2", features = ["v4", "serde"] }
@@ -81,7 +82,7 @@ unstable = [] # useful for benches.
wasm-bindgen = ["uuid/wasm-bindgen"]
[workspace]
members = ["query-grammar"]
members = ["query-grammar", "incremental-search"]
[badges]
travis-ci = { repository = "tantivy-search/tantivy" }

View File

@@ -1,3 +1,3 @@
test:
echo "Run test only... No examples."
cargo test --tests --lib
cargo test --all --tests --lib

View File

@@ -0,0 +1,10 @@
[package]
name = "incremental-search"
version = "0.11.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
edition = "2018"
[dependencies]
derive_builder = "0.7"
tantivy = {path = ".."}

View File

@@ -0,0 +1,395 @@
use std::fmt;
use std::u64;
#[derive(Clone, Copy, Eq, PartialEq)]
pub(crate) struct TinySet(u64);
impl fmt::Debug for TinySet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.into_iter().collect::<Vec<u32>>().fmt(f)
}
}
pub struct TinySetIterator(TinySet);
impl Iterator for TinySetIterator {
type Item = u32;
fn next(&mut self) -> Option<Self::Item> {
self.0.pop_lowest()
}
}
impl IntoIterator for TinySet {
type Item = u32;
type IntoIter = TinySetIterator;
fn into_iter(self) -> Self::IntoIter {
TinySetIterator(self)
}
}
impl TinySet {
/// Returns an empty `TinySet`.
pub fn empty() -> TinySet {
TinySet(0u64)
}
/// Returns the complement of the set in `[0, 64[`.
fn complement(self) -> TinySet {
TinySet(!self.0)
}
/// Returns true iff the `TinySet` contains the element `el`.
pub fn contains(self, el: u32) -> bool {
!self.intersect(TinySet::singleton(el)).is_empty()
}
/// Returns the intersection of `self` and `other`
pub fn intersect(self, other: TinySet) -> TinySet {
TinySet(self.0 & other.0)
}
/// Creates a new `TinySet` containing only one element
/// within `[0; 64[`
#[inline(always)]
pub fn singleton(el: u32) -> TinySet {
TinySet(1u64 << u64::from(el))
}
/// Insert a new element within [0..64[
#[inline(always)]
pub fn insert(self, el: u32) -> TinySet {
self.union(TinySet::singleton(el))
}
/// Insert a new element within [0..64[
#[inline(always)]
pub fn insert_mut(&mut self, el: u32) -> bool {
let old = *self;
*self = old.insert(el);
old != *self
}
/// Returns the union of two tinysets
#[inline(always)]
pub fn union(self, other: TinySet) -> TinySet {
TinySet(self.0 | other.0)
}
/// Returns true iff the `TinySet` is empty.
#[inline(always)]
pub fn is_empty(self) -> bool {
self.0 == 0u64
}
/// Returns the lowest element in the `TinySet`
/// and removes it.
#[inline(always)]
pub fn pop_lowest(&mut self) -> Option<u32> {
if self.is_empty() {
None
} else {
let lowest = self.0.trailing_zeros() as u32;
self.0 ^= TinySet::singleton(lowest).0;
Some(lowest)
}
}
/// Returns a `TinySet` than contains all values up
/// to limit excluded.
///
/// The limit is assumed to be strictly lower than 64.
pub fn range_lower(upper_bound: u32) -> TinySet {
TinySet((1u64 << u64::from(upper_bound % 64u32)) - 1u64)
}
/// Returns a `TinySet` that contains all values greater
/// or equal to the given limit, included. (and up to 63)
///
/// The limit is assumed to be strictly lower than 64.
pub fn range_greater_or_equal(from_included: u32) -> TinySet {
TinySet::range_lower(from_included).complement()
}
pub fn clear(&mut self) {
self.0 = 0u64;
}
pub fn len(self) -> u32 {
self.0.count_ones()
}
}
#[derive(Clone)]
pub struct BitSet {
tinysets: Box<[TinySet]>,
len: usize, //< Technically it should be u32, but we
// count multiple inserts.
// `usize` guards us from overflow.
max_value: u32,
}
fn num_buckets(max_val: u32) -> u32 {
(max_val + 63u32) / 64u32
}
impl BitSet {
/// Create a new `BitSet` that may contain elements
/// within `[0, max_val[`.
pub fn with_max_value(max_value: u32) -> BitSet {
let num_buckets = num_buckets(max_value);
let tinybisets = vec![TinySet::empty(); num_buckets as usize].into_boxed_slice();
BitSet {
tinysets: tinybisets,
len: 0,
max_value,
}
}
/// Removes all elements from the `BitSet`.
pub fn clear(&mut self) {
for tinyset in self.tinysets.iter_mut() {
*tinyset = TinySet::empty();
}
}
/// Returns the number of elements in the `BitSet`.
pub fn len(&self) -> usize {
self.len
}
/// Inserts an element in the `BitSet`
pub fn insert(&mut self, el: u32) {
// we do not check saturated els.
let higher = el / 64u32;
let lower = el % 64u32;
self.len += if self.tinysets[higher as usize].insert_mut(lower) {
1
} else {
0
};
}
/// Returns true iff the elements is in the `BitSet`.
pub fn contains(&self, el: u32) -> bool {
self.tinyset(el / 64u32).contains(el % 64)
}
/// Returns the first non-empty `TinySet` associated to a bucket lower
/// or greater than bucket.
///
/// Reminder: the tiny set with the bucket `bucket`, represents the
/// elements from `bucket * 64` to `(bucket+1) * 64`.
pub(crate) fn first_non_empty_bucket(&self, bucket: u32) -> Option<u32> {
self.tinysets[bucket as usize..]
.iter()
.cloned()
.position(|tinyset| !tinyset.is_empty())
.map(|delta_bucket| bucket + delta_bucket as u32)
}
pub fn max_value(&self) -> u32 {
self.max_value
}
/// Returns the tiny bitset representing the
/// the set restricted to the number range from
/// `bucket * 64` to `(bucket + 1) * 64`.
pub(crate) fn tinyset(&self, bucket: u32) -> TinySet {
self.tinysets[bucket as usize]
}
}
#[cfg(test)]
mod tests {
use super::BitSet;
use super::TinySet;
use crate::docset::DocSet;
use crate::query::BitSetDocSet;
use crate::tests;
use crate::tests::generate_nonunique_unsorted;
use std::collections::BTreeSet;
use std::collections::HashSet;
#[test]
fn test_tiny_set() {
assert!(TinySet::empty().is_empty());
{
let mut u = TinySet::empty().insert(1u32);
assert_eq!(u.pop_lowest(), Some(1u32));
assert!(u.pop_lowest().is_none())
}
{
let mut u = TinySet::empty().insert(1u32).insert(1u32);
assert_eq!(u.pop_lowest(), Some(1u32));
assert!(u.pop_lowest().is_none())
}
{
let mut u = TinySet::empty().insert(2u32);
assert_eq!(u.pop_lowest(), Some(2u32));
u.insert_mut(1u32);
assert_eq!(u.pop_lowest(), Some(1u32));
assert!(u.pop_lowest().is_none());
}
{
let mut u = TinySet::empty().insert(63u32);
assert_eq!(u.pop_lowest(), Some(63u32));
assert!(u.pop_lowest().is_none());
}
}
#[test]
fn test_bitset() {
let test_against_hashset = |els: &[u32], max_value: u32| {
let mut hashset: HashSet<u32> = HashSet::new();
let mut bitset = BitSet::with_max_value(max_value);
for &el in els {
assert!(el < max_value);
hashset.insert(el);
bitset.insert(el);
}
for el in 0..max_value {
assert_eq!(hashset.contains(&el), bitset.contains(el));
}
assert_eq!(bitset.max_value(), max_value);
};
test_against_hashset(&[], 0);
test_against_hashset(&[], 1);
test_against_hashset(&[0u32], 1);
test_against_hashset(&[0u32], 100);
test_against_hashset(&[1u32, 2u32], 4);
test_against_hashset(&[99u32], 100);
test_against_hashset(&[63u32], 64);
test_against_hashset(&[62u32, 63u32], 64);
}
#[test]
fn test_bitset_large() {
let arr = generate_nonunique_unsorted(100_000, 5_000);
let mut btreeset: BTreeSet<u32> = BTreeSet::new();
let mut bitset = BitSet::with_max_value(100_000);
for el in arr {
btreeset.insert(el);
bitset.insert(el);
}
for i in 0..100_000 {
assert_eq!(btreeset.contains(&i), bitset.contains(i));
}
assert_eq!(btreeset.len(), bitset.len());
let mut bitset_docset = BitSetDocSet::from(bitset);
for el in btreeset.into_iter() {
bitset_docset.advance();
assert_eq!(bitset_docset.doc(), el);
}
assert!(!bitset_docset.advance());
}
#[test]
fn test_bitset_num_buckets() {
use super::num_buckets;
assert_eq!(num_buckets(0u32), 0);
assert_eq!(num_buckets(1u32), 1);
assert_eq!(num_buckets(64u32), 1);
assert_eq!(num_buckets(65u32), 2);
assert_eq!(num_buckets(128u32), 2);
assert_eq!(num_buckets(129u32), 3);
}
#[test]
fn test_tinyset_range() {
assert_eq!(
TinySet::range_lower(3).into_iter().collect::<Vec<u32>>(),
[0, 1, 2]
);
assert!(TinySet::range_lower(0).is_empty());
assert_eq!(
TinySet::range_lower(63).into_iter().collect::<Vec<u32>>(),
(0u32..63u32).collect::<Vec<_>>()
);
assert_eq!(
TinySet::range_lower(1).into_iter().collect::<Vec<u32>>(),
[0]
);
assert_eq!(
TinySet::range_lower(2).into_iter().collect::<Vec<u32>>(),
[0, 1]
);
assert_eq!(
TinySet::range_greater_or_equal(3)
.into_iter()
.collect::<Vec<u32>>(),
(3u32..64u32).collect::<Vec<_>>()
);
}
#[test]
fn test_bitset_len() {
let mut bitset = BitSet::with_max_value(1_000);
assert_eq!(bitset.len(), 0);
bitset.insert(3u32);
assert_eq!(bitset.len(), 1);
bitset.insert(103u32);
assert_eq!(bitset.len(), 2);
bitset.insert(3u32);
assert_eq!(bitset.len(), 2);
bitset.insert(103u32);
assert_eq!(bitset.len(), 2);
bitset.insert(104u32);
assert_eq!(bitset.len(), 3);
}
#[test]
fn test_bitset_clear() {
let mut bitset = BitSet::with_max_value(1_000);
let els = tests::sample(1_000, 0.01f64);
for &el in &els {
bitset.insert(el);
}
assert!(els.iter().all(|el| bitset.contains(*el)));
bitset.clear();
for el in 0u32..1000u32 {
assert!(!bitset.contains(el));
}
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use super::BitSet;
use super::TinySet;
use test;
#[bench]
fn bench_tinyset_pop(b: &mut test::Bencher) {
b.iter(|| {
let mut tinyset = TinySet::singleton(test::black_box(31u32));
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
});
}
#[bench]
fn bench_tinyset_sum(b: &mut test::Bencher) {
let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32);
b.iter(|| {
assert_eq!(test::black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
});
}
#[bench]
fn bench_tinyarr_sum(b: &mut test::Bencher) {
let v = [10u32, 14u32, 21u32];
b.iter(|| test::black_box(v).iter().cloned().sum::<u32>());
}
#[bench]
fn bench_bitset_initialize(b: &mut test::Bencher) {
b.iter(|| BitSet::with_max_value(1_000_000));
}
}

View File

@@ -0,0 +1,266 @@
use tantivy::query::{BooleanQuery, FuzzyTermQuery, EmptyQuery};
use derive_builder::Builder;
use std::str::FromStr;
use tantivy::query::{FuzzyConfiguration, FuzzyConfigurationBuilder, Query, Occur};
use tantivy::schema::Field;
use tantivy::{Searcher, TantivyError, DocAddress, Term, Document};
use tantivy::collector::TopDocs;
use std::ops::Deref;
#[derive(Debug)]
pub struct IncrementalSearchQuery {
pub terms: Vec<String>,
pub last_is_prefix: bool,
}
impl IncrementalSearchQuery {
pub fn fuzzy_configurations(&self) -> Vec<FuzzyConfigurations> {
if self.terms.is_empty() {
return Vec::default();
}
let single_term_confs: Vec<FuzzyConfigurationBuilder> = (0u8..3u8)
.map(|d: u8| {
let mut builder = FuzzyConfigurationBuilder::default();
builder.distance(d).transposition_cost_one(true);
builder
})
.collect();
let mut configurations: Vec<Vec<FuzzyConfigurationBuilder>> = single_term_confs
.iter()
.map(|conf| vec![conf.clone()])
.collect();
let mut new_configurations = Vec::new();
for _ in 1..self.terms.len() {
new_configurations.clear();
for single_term_conf in &single_term_confs {
for configuration in &configurations {
let mut new_configuration: Vec<FuzzyConfigurationBuilder> = configuration.clone();
new_configuration.push(single_term_conf.clone());
new_configurations.push(new_configuration);
}
}
std::mem::swap(&mut configurations, &mut new_configurations);
}
if self.last_is_prefix {
for configuration in &mut configurations {
if let Some(last_conf) = configuration.last_mut() {
last_conf.prefix(true);
}
}
}
let mut fuzzy_configurations: Vec<FuzzyConfigurations> = configurations
.into_iter()
.map(FuzzyConfigurations::from)
.collect();
fuzzy_configurations.sort_by(|left, right| left.cost.partial_cmp(&right.cost).unwrap());
fuzzy_configurations
}
fn search_query(&self, fields: &[Field], configurations: FuzzyConfigurations) -> Box<dyn Query> {
if self.terms.is_empty() {
Box::new(EmptyQuery)
} else if self.terms.len() == 1 {
build_query_for_fields(fields, &self.terms[0], &configurations.configurations[0])
} else {
Box::new(BooleanQuery::from(self.terms.iter()
.zip(configurations.configurations.iter())
.map(|(term, configuration)|
(Occur::Must, build_query_for_fields(fields, &term, &configuration))
)
.collect::<Vec<_>>()))
}
}
}
#[derive(Debug)]
pub struct FuzzyConfigurations {
configurations: Vec<FuzzyConfiguration>,
cost: f64,
}
fn compute_cost(fuzzy_confs: &[FuzzyConfiguration]) -> f64 {
fuzzy_confs
.iter()
.map(|fuzzy_conf| {
let weight = if fuzzy_conf.prefix { 30f64 } else { 5f64 };
weight * f64::from(fuzzy_conf.distance)
})
.sum()
}
impl From<Vec<FuzzyConfigurationBuilder>> for FuzzyConfigurations {
fn from(fuzzy_conf_builder: Vec<FuzzyConfigurationBuilder>) -> FuzzyConfigurations {
let configurations = fuzzy_conf_builder
.into_iter()
.map(|conf| conf.build().unwrap())
.collect::<Vec<FuzzyConfiguration>>();
let cost = compute_cost(&configurations);
FuzzyConfigurations {
configurations,
cost,
}
}
}
#[derive(Debug)]
pub struct ParseIncrementalQueryError;
impl Into<TantivyError> for ParseIncrementalQueryError {
fn into(self) -> TantivyError {
TantivyError::InvalidArgument(format!("Invalid query: {:?}", self))
}
}
impl FromStr for IncrementalSearchQuery {
type Err = ParseIncrementalQueryError;
fn from_str(query_str: &str) -> Result<Self, Self::Err> {
let terms: Vec<String> = query_str
.split_whitespace()
.map(ToString::to_string)
.collect();
Ok(IncrementalSearchQuery {
terms,
last_is_prefix: query_str
.chars()
.last()
.map(|c| !c.is_whitespace())
.unwrap_or(false),
})
}
}
fn build_query_for_fields(fields: &[Field], term_text: &str, conf: &FuzzyConfiguration) -> Box<dyn Query> {
assert!(fields.len() > 0);
if fields.len() > 1 {
let term_queries: Vec<(Occur, Box<dyn Query>)> = fields
.iter()
.map(|&field| {
let term = Term::from_field_text(field, term_text);
let query = FuzzyTermQuery::new_from_configuration(term, conf.clone());
let boxed_query: Box<dyn Query> = Box::new(query);
(Occur::Must, boxed_query)
})
.collect();
Box::new(BooleanQuery::from(term_queries))
} else {
let term = Term::from_field_text(fields[0], term_text);
Box::new( FuzzyTermQuery::new_from_configuration(term, conf.clone()))
}
}
pub struct IncrementalSearchResult {
pub docs: Vec<Document>
}
#[derive(Builder, Default)]
pub struct IncrementalSearch {
nhits: usize,
#[builder(default)]
search_fields: Vec<Field>,
#[builder(default)]
return_fields: Vec<Field>,
}
impl IncrementalSearch {
pub fn search<S: Deref<Target=Searcher>>(
&self,
query: &str,
searcher: &S,
) -> tantivy::Result<IncrementalSearchResult> {
let searcher = searcher.deref();
let inc_search_query: IncrementalSearchQuery =
FromStr::from_str(query).map_err(Into::<TantivyError>::into)?;
let mut results: Vec<DocAddress> = Vec::default();
let mut remaining = self.nhits;
for fuzzy_conf in inc_search_query.fuzzy_configurations() {
if remaining == 0 {
break;
}
let query = inc_search_query.search_query(&self.search_fields[..], fuzzy_conf);
let new_docs = searcher.search(query.as_ref(), &TopDocs::with_limit(remaining))?;
// TODO(pmasurel) remove already added docs.
results.extend(new_docs.into_iter()
.map(|(_, doc_address)| doc_address));
remaining = self.nhits - results.len();
if remaining == 0 {
break;
}
}
let docs: Vec<Document> = results.into_iter()
.map(|doc_address: DocAddress| searcher.doc(doc_address))
.collect::<tantivy::Result<_>>()?;
Ok(IncrementalSearchResult {
docs
})
}
}
#[cfg(test)]
mod tests {
use tantivy::doc;
use crate::{IncrementalSearch, IncrementalSearchBuilder, IncrementalSearchQuery};
use std::str::FromStr;
use tantivy::schema::{SchemaBuilder, TEXT, STORED};
use tantivy::Index;
#[test]
fn test_incremental_search() {
let incremental_search = IncrementalSearchBuilder::default()
.nhits(10)
.build()
.unwrap();
}
#[test]
fn test_incremental_search_query_parse_empty() {
let query = IncrementalSearchQuery::from_str("").unwrap();
assert_eq!(query.terms, Vec::<String>::new());
assert!(!query.last_is_prefix);
}
#[test]
fn test_incremental_search_query_parse_trailing_whitespace() {
let query = IncrementalSearchQuery::from_str("hello happy tax pa ").unwrap();
assert_eq!(query.terms, vec!["hello", "happy", "tax", "pa"]);
assert!(!query.last_is_prefix);
}
#[test]
fn test_incremental_search_query_parse_unicode_whitespace() {
let query = IncrementalSearchQuery::from_str("hello happy tax pa ").unwrap();
assert_eq!(query.terms, vec!["hello", "happy", "tax", "pa"]);
assert!(!query.last_is_prefix);
}
#[test]
fn test_incremental_search_query_parse() {
let query = IncrementalSearchQuery::from_str("hello happy tax pa").unwrap();
assert_eq!(query.terms, vec!["hello", "happy", "tax", "pa"]);
assert!(query.last_is_prefix);
}
#[test]
fn test_blop() {
let mut schema_builder = SchemaBuilder::new();
let body = schema_builder.add_text_field("body", TEXT | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 30_000_000).unwrap();
index_writer.add_document(doc!(body=> "hello happy tax payer"));
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let incremental_search: IncrementalSearch = IncrementalSearchBuilder::default()
.nhits(1)
.search_fields(vec![body])
.build()
.unwrap();
let top_docs = incremental_search.search("hello hapy t", &searcher).unwrap();
assert_eq!(top_docs.docs.len(), 1);
}
}

View File

@@ -1,12 +1,14 @@
use crate::error::TantivyError::InvalidArgument;
use crate::query::{AutomatonWeight, Query, Weight};
use crate::schema::Term;
use crate::termdict::WrappedDFA;
use crate::Result;
use crate::Searcher;
use levenshtein_automata::{LevenshteinAutomatonBuilder, DFA};
use levenshtein_automata::{Distance, LevenshteinAutomatonBuilder, DFA};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::ops::Range;
use derive_builder::Builder;
/// A range of Levenshtein distances that we will build DFAs for our terms
/// The computation is exponential, so best keep it to low single digits
@@ -24,6 +26,38 @@ static LEV_BUILDER: Lazy<HashMap<(u8, bool), LevenshteinAutomatonBuilder>> = Laz
lev_builder_cache
});
#[derive(Builder, Default, Clone, Debug)]
pub struct FuzzyConfiguration {
/// How many changes are we going to allow
pub distance: u8,
/// Should a transposition cost 1 or 2?
#[builder(default)]
pub transposition_cost_one: bool,
#[builder(default)]
pub prefix: bool,
/// If true, only the term with a levenshtein of exactly `distance` will match.
/// If false, terms at a distance `<=` to `distance` will match.
#[builder(default)]
pub exact_distance: bool,
}
fn build_dfa(fuzzy_configuration: &FuzzyConfiguration, term_text: &str) -> Result<DFA> {
let automaton_builder = LEV_BUILDER
.get(&(fuzzy_configuration.distance, fuzzy_configuration.transposition_cost_one))
.ok_or_else(|| {
InvalidArgument(format!(
"Levenshtein distance of {} is not allowed. Choose a value in the {:?} range",
fuzzy_configuration.distance, VALID_LEVENSHTEIN_DISTANCE_RANGE
))
})?;
if fuzzy_configuration.prefix {
Ok(automaton_builder.build_prefix_dfa(term_text))
} else {
Ok(automaton_builder.build_dfa(term_text))
}
}
/// A Fuzzy Query matches all of the documents
/// containing a specific term that is within
/// Levenshtein distance
@@ -41,32 +75,19 @@ static LEV_BUILDER: Lazy<HashMap<(u8, bool), LevenshteinAutomatonBuilder>> = Laz
/// let index = Index::create_in_ram(schema);
/// {
/// let mut index_writer = index.writer(3_000_000)?;
/// index_writer.add_document(doc!(
/// title => "The Name of the Wind",
/// ));
/// index_writer.add_document(doc!(
/// title => "The Diary of Muadib",
/// ));
/// index_writer.add_document(doc!(
/// title => "A Dairy Cow",
/// ));
/// index_writer.add_document(doc!(
/// title => "The Diary of a Young Girl",
/// ));
/// index_writer.add_document(doc!(title => "The Name of the Wind"));
/// index_writer.add_document(doc!(title => "The Diary of Muadib"));
/// index_writer.add_document(doc!(title => "A Dairy Cow"));
/// index_writer.add_document(doc!(title => "The Diary of a Young Girl"));
/// index_writer.commit().unwrap();
/// }
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
///
/// {
///
/// let term = Term::from_field_text(title, "Diary");
/// let query = FuzzyTermQuery::new(term, 1, true);
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count)).unwrap();
/// assert_eq!(count, 2);
/// assert_eq!(top_docs.len(), 2);
/// }
///
/// let term = Term::from_field_text(title, "Diary");
/// let query = FuzzyTermQuery::new(term, 1, true);
/// let (top_docs, count) = searcher.search(&query, &(TopDocs::with_limit(2), Count)).unwrap();
/// assert_eq!(count, 2);
/// assert_eq!(top_docs.len(), 2);
/// Ok(())
/// }
/// ```
@@ -74,54 +95,58 @@ static LEV_BUILDER: Lazy<HashMap<(u8, bool), LevenshteinAutomatonBuilder>> = Laz
pub struct FuzzyTermQuery {
/// What term are we searching
term: Term,
/// How many changes are we going to allow
distance: u8,
/// Should a transposition cost 1 or 2?
transposition_cost_one: bool,
///
prefix: bool,
configuration: FuzzyConfiguration
}
impl FuzzyTermQuery {
pub fn new_from_configuration(term: Term, configuration: FuzzyConfiguration) -> FuzzyTermQuery {
FuzzyTermQuery {
term,
configuration
}
}
/// Creates a new Fuzzy Query
pub fn new(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
FuzzyTermQuery {
term,
distance,
transposition_cost_one,
prefix: false,
}
}
/// Creates a new Fuzzy Query that treats transpositions as cost one rather than two
pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
FuzzyTermQuery {
term,
distance,
transposition_cost_one,
prefix: true,
}
}
fn specialized_weight(&self) -> Result<AutomatonWeight<DFA>> {
// LEV_BUILDER is a HashMap, whose `get` method returns an Option
match LEV_BUILDER.get(&(self.distance, false)) {
// Unwrap the option and build the Ok(AutomatonWeight)
Some(automaton_builder) => {
let automaton = automaton_builder.build_dfa(self.term.text());
Ok(AutomatonWeight::new(self.term.field(), automaton))
configuration: FuzzyConfiguration {
distance,
transposition_cost_one,
prefix: false,
exact_distance: false
}
None => Err(InvalidArgument(format!(
"Levenshtein distance of {} is not allowed. Choose a value in the {:?} range",
self.distance, VALID_LEVENSHTEIN_DISTANCE_RANGE
))),
}
}
}
impl Query for FuzzyTermQuery {
fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result<Box<dyn Weight>> {
Ok(Box::new(self.specialized_weight()?))
let dfa = build_dfa(&self.configuration, self.term.text())?;
// TODO optimize for distance = 0 and possibly prefix
if self.configuration.exact_distance {
let target_distance = self.configuration.distance;
let wrapped_dfa = WrappedDFA {
dfa,
condition: move |distance: Distance| distance == Distance::Exact(target_distance),
};
Ok(Box::new(AutomatonWeight::new(
self.term.field(),
wrapped_dfa,
)))
} else {
let wrapped_dfa = WrappedDFA {
dfa,
condition: move |distance: Distance| match distance {
Distance::Exact(_) => true,
Distance::AtLeast(_) => false,
},
};
Ok(Box::new(AutomatonWeight::new(
self.term.field(),
wrapped_dfa,
)))
}
}
}
@@ -134,6 +159,7 @@ mod test {
use crate::tests::assert_nearly_equals;
use crate::Index;
use crate::Term;
use super::FuzzyConfigurationBuilder;
#[test]
pub fn test_fuzzy_term() {
@@ -155,7 +181,6 @@ mod test {
let searcher = reader.searcher();
{
let term = Term::from_field_text(country_field, "japon");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
@@ -164,5 +189,73 @@ mod test {
let (score, _) = top_docs[0];
assert_nearly_equals(1f32, score);
}
{
let term = Term::from_field_text(country_field, "japon");
let fuzzy_conf = FuzzyConfigurationBuilder::default()
.distance(2)
.exact_distance(true)
.build()
.unwrap();
let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert!(top_docs.is_empty());
}
{
let term = Term::from_field_text(country_field, "japon");
let fuzzy_conf = FuzzyConfigurationBuilder::default()
.distance(1)
.exact_distance(true)
.build()
.unwrap();
let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 1);
}
{
let term = Term::from_field_text(country_field, "jpp");
let fuzzy_conf = FuzzyConfigurationBuilder::default()
.distance(1)
.prefix(true)
.build()
.unwrap();
let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 1);
}
{
let term = Term::from_field_text(country_field, "jpaan");
let fuzzy_conf = FuzzyConfigurationBuilder::default()
.distance(1)
.exact_distance(true)
.transposition_cost_one(true)
.build()
.unwrap();
let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 1);
}
{
let term = Term::from_field_text(country_field, "jpaan");
let fuzzy_conf = FuzzyConfigurationBuilder::default()
.distance(2)
.exact_distance(true)
.transposition_cost_one(false)
.build()
.unwrap();
let fuzzy_query = FuzzyTermQuery::new_from_configuration(term, fuzzy_conf);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 1);
}
}
}

View File

@@ -40,7 +40,7 @@ pub use self::boolean_query::BooleanQuery;
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
pub use self::exclude::Exclude;
pub use self::explanation::Explanation;
pub use self::fuzzy_query::FuzzyTermQuery;
pub use self::fuzzy_query::{FuzzyTermQuery, FuzzyConfiguration, FuzzyConfigurationBuilder};
pub use self::intersection::intersect_scorers;
pub use self::phrase_query::PhraseQuery;
pub use self::query::Query;

View File

@@ -31,14 +31,43 @@ mod termdict;
pub use self::merger::TermMerger;
pub use self::streamer::{TermStreamer, TermStreamerBuilder};
pub use self::termdict::{TermDictionary, TermDictionaryBuilder};
use levenshtein_automata::{Distance, DFA, SINK_STATE};
use tantivy_fst::Automaton;
pub(crate) struct WrappedDFA<Cond> {
pub dfa: DFA,
pub condition: Cond,
}
impl<Cond: Fn(Distance) -> bool> Automaton for WrappedDFA<Cond> {
type State = u32;
fn start(&self) -> Self::State {
self.dfa.initial_state()
}
fn is_match(&self, state: &Self::State) -> bool {
let distance = self.dfa.distance(*state);
(self.condition)(distance)
}
fn can_match(&self, state: &Self::State) -> bool {
*state != SINK_STATE
}
fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
self.dfa.transition(*state, byte)
}
}
#[cfg(test)]
mod tests {
use super::{TermDictionary, TermDictionaryBuilder, TermStreamer};
use super::{TermDictionary, TermDictionaryBuilder, TermStreamer, WrappedDFA};
use crate::core::Index;
use crate::directory::{Directory, RAMDirectory, ReadOnlySource};
use crate::postings::TermInfo;
use crate::schema::{Document, FieldType, Schema, TEXT};
use levenshtein_automata::Distance;
use std::path::PathBuf;
use std::str;
@@ -423,9 +452,14 @@ mod tests {
// We can now build an entire dfa.
let lev_automaton_builder = LevenshteinAutomatonBuilder::new(2, true);
let automaton = lev_automaton_builder.build_dfa("Spaen");
let mut range = term_dict.search(automaton).into_stream();
let wrapped_dfa = WrappedDFA {
dfa: lev_automaton_builder.build_dfa("Spaen"),
condition: |distance| match distance {
Distance::Exact(_) => true,
Distance::AtLeast(_) => false,
},
};
let mut range = term_dict.search(wrapped_dfa).into_stream();
// get the first finding
assert!(range.advance());