Use Levenshtein distance to score documents in fuzzy term queries

This commit is contained in:
Neil Hansen
2024-01-31 16:11:16 -08:00
committed by Stu Hood
parent 794ff1ffc9
commit dee2dd3f21
9 changed files with 385 additions and 42 deletions

View File

@@ -199,7 +199,9 @@ fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
/// deserialize Option<f64> from string or float
pub(crate) fn deserialize_option_f64<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
where D: Deserializer<'de> {
where
D: Deserializer<'de>,
{
struct StringOrFloatVisitor;
impl Visitor<'_> for StringOrFloatVisitor {
@@ -210,32 +212,44 @@ where D: Deserializer<'de> {
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
parse_str_into_f64(value).map(Some)
}
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(Some(value))
}
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(Some(value as f64))
}
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(Some(value as f64))
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(None)
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(None)
}
}
@@ -245,7 +259,9 @@ where D: Deserializer<'de> {
/// deserialize f64 from string or float
pub(crate) fn deserialize_f64<'de, D>(deserializer: D) -> Result<f64, D::Error>
where D: Deserializer<'de> {
where
D: Deserializer<'de>,
{
struct StringOrFloatVisitor;
impl Visitor<'_> for StringOrFloatVisitor {
@@ -256,22 +272,30 @@ where D: Deserializer<'de> {
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
parse_str_into_f64(value)
}
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(value)
}
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(value as f64)
}
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where E: de::Error {
where
E: de::Error,
{
Ok(value as f64)
}
}

View File

@@ -1,15 +1,18 @@
use std::any::{Any, TypeId};
use std::io;
use std::sync::Arc;
use common::BitSet;
use tantivy_fst::Automaton;
use super::phrase_prefix_query::prefix_end;
use super::BufferedUnionScorer;
use crate::index::SegmentReader;
use crate::postings::TermInfo;
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
use crate::query::fuzzy_query::DfaWrapper;
use crate::query::score_combiner::SumCombiner;
use crate::query::{ConstScorer, Explanation, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::termdict::{TermDictionary, TermWithStateStreamer};
use crate::{DocId, Score, TantivyError};
/// A weight struct for Fuzzy Term and Regex Queries
@@ -52,9 +55,9 @@ where
fn automaton_stream<'a>(
&'a self,
term_dict: &'a TermDictionary,
) -> io::Result<TermStreamer<'a, &'a A>> {
) -> io::Result<TermWithStateStreamer<'a, &'a A>> {
let automaton: &A = &self.automaton;
let mut term_stream_builder = term_dict.search(automaton);
let mut term_stream_builder = term_dict.search_with_state(automaton);
if let Some(json_path_bytes) = &self.json_path_bytes {
term_stream_builder = term_stream_builder.ge(json_path_bytes);
@@ -85,35 +88,27 @@ where
A::State: Clone,
{
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field)?;
let term_dict = inverted_index.terms();
let mut term_stream = self.automaton_stream(term_dict)?;
while term_stream.advance() {
let term_info = term_stream.value();
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
for &doc in docs {
doc_bitset.insert(doc);
}
block_segment_postings.advance();
}
let mut scorers = vec![];
while let Some((_term, term_info, state)) = term_stream.next() {
let score = automaton_score(self.automaton.as_ref(), state);
let segment_postings =
inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
let scorer = ConstScorer::new(segment_postings, boost * score);
scorers.push(scorer);
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
let const_scorer = ConstScorer::new(doc_bitset, boost);
Ok(Box::new(const_scorer))
let scorer = BufferedUnionScorer::build(scorers, SumCombiner::default);
Ok(Box::new(scorer))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) == doc {
Ok(Explanation::new("AutomatonScorer", 1.0))
Ok(Explanation::new("AutomatonScorer", scorer.score()))
} else {
Err(TantivyError::InvalidArgument(
"Document does not exist".to_string(),
@@ -122,6 +117,25 @@ where
}
}
fn automaton_score<A>(automaton: &A, state: A::State) -> f32
where
A: Automaton + Send + Sync + 'static,
A::State: Clone,
{
if TypeId::of::<DfaWrapper>() == automaton.type_id() && TypeId::of::<u32>() == state.type_id() {
let dfa = automaton as *const A as *const DfaWrapper;
let dfa = unsafe { &*dfa };
let id = &state as *const A::State as *const u32;
let id = unsafe { *id };
let dist = dfa.0.distance(id).to_u8() as f32;
1.0 / (1.0 + dist)
} else {
1.0
}
}
#[cfg(test)]
mod tests {
use tantivy_fst::Automaton;

View File

@@ -299,7 +299,7 @@ mod test {
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
assert_nearly_equals!(0.5, score);
}
// fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')

View File

@@ -4,6 +4,7 @@ mod term_weight;
pub use self::term_query::TermQuery;
pub use self::term_scorer::TermScorer;
#[cfg(test)]
mod tests {

View File

@@ -24,5 +24,7 @@ mod term_info_store;
mod termdict;
pub use self::merger::TermMerger;
pub use self::streamer::{TermStreamer, TermStreamerBuilder};
pub use self::streamer::{
TermStreamer, TermStreamerBuilder, TermWithStateStreamer, TermWithStateStreamerBuilder,
};
pub use self::termdict::{TermDictionary, TermDictionaryBuilder};

View File

@@ -1,7 +1,7 @@
use std::io;
use tantivy_fst::automaton::AlwaysMatch;
use tantivy_fst::map::{Stream, StreamBuilder};
use tantivy_fst::map::{Stream, StreamBuilder, StreamWithState};
use tantivy_fst::{Automaton, IntoStreamer, Streamer};
use super::TermDictionary;
@@ -145,3 +145,152 @@ where A: Automaton
}
}
}
/// `TermWithStateStreamerBuilder` is a helper object used to define
/// a range of terms that should be streamed.
pub struct TermWithStateStreamerBuilder<'a, A = AlwaysMatch>
where
A: Automaton,
A::State: Clone,
{
fst_map: &'a TermDictionary,
stream_builder: StreamBuilder<'a, A>,
}
impl<'a, A> TermWithStateStreamerBuilder<'a, A>
where
A: Automaton,
A::State: Clone,
{
pub(crate) fn new(fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>) -> Self {
TermWithStateStreamerBuilder {
fst_map,
stream_builder,
}
}
/// Limit the range to terms greater or equal to the bound
pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.ge(bound);
self
}
/// Limit the range to terms strictly greater than the bound
pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.gt(bound);
self
}
/// Limit the range to terms lesser or equal to the bound
pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.le(bound);
self
}
/// Limit the range to terms lesser or equal to the bound
pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
self.stream_builder = self.stream_builder.lt(bound);
self
}
/// Iterate over the range backwards.
pub fn backward(mut self) -> Self {
self.stream_builder = self.stream_builder.backward();
self
}
/// Creates the stream corresponding to the range
/// of terms defined using the `TermWithStateStreamerBuilder`.
pub fn into_stream(self) -> io::Result<TermWithStateStreamer<'a, A>> {
Ok(TermWithStateStreamer {
fst_map: self.fst_map,
stream: self.stream_builder.with_state().into_stream(),
term_ord: 0u64,
current_key: Vec::with_capacity(100),
current_value: TermInfo::default(),
current_state: None,
})
}
}
/// `TermWithStateStreamer` acts as a cursor over a range of terms of a segment.
/// Terms are guaranteed to be sorted.
pub struct TermWithStateStreamer<'a, A = AlwaysMatch>
where
A: Automaton,
A::State: Clone,
{
fst_map: &'a TermDictionary,
stream: StreamWithState<'a, A>,
term_ord: TermOrdinal,
current_key: Vec<u8>,
current_value: TermInfo,
current_state: Option<A::State>,
}
impl<'a, A> TermWithStateStreamer<'a, A>
where
A: Automaton,
A::State: Clone,
{
/// Advance position the stream on the next item.
/// Before the first call to `.advance()`, the stream
/// is an unitialized state.
pub fn advance(&mut self) -> bool {
if let Some((term, term_ord, state)) = self.stream.next() {
self.current_key.clear();
self.current_key.extend_from_slice(term);
self.term_ord = term_ord;
self.current_value = self.fst_map.term_info_from_ord(term_ord);
self.current_state = Some(state);
true
} else {
false
}
}
/// Returns the `TermOrdinal` of the given term.
///
/// May panic if the called as `.advance()` as never
/// been called before.
pub fn term_ord(&self) -> TermOrdinal {
self.term_ord
}
/// Accesses the current key.
///
/// `.key()` should return the key that was returned
/// by the `.next()` method.
///
/// If the end of the stream as been reached, and `.next()`
/// has been called and returned `None`, `.key()` remains
/// the value of the last key encountered.
///
/// Before any call to `.next()`, `.key()` returns an empty array.
pub fn key(&self) -> &[u8] {
&self.current_key
}
/// Accesses the current value.
///
/// Calling `.value()` after the end of the stream will return the
/// last `.value()` encountered.
///
/// # Panics
///
/// Calling `.value()` before the first call to `.advance()` returns
/// `V::default()`.
pub fn value(&self) -> &TermInfo {
&self.current_value
}
/// Return the next `(key, value, state)` triplet.
pub fn next(&mut self) -> Option<(&[u8], &TermInfo, A::State)> {
if self.advance() {
let state = self.current_state.take().unwrap(); // always Some(_) after advance
Some((self.key(), self.value(), state))
} else {
None
}
}
}

View File

@@ -7,7 +7,7 @@ use tantivy_fst::raw::Fst;
use tantivy_fst::Automaton;
use super::term_info_store::{TermInfoStore, TermInfoStoreWriter};
use super::{TermStreamer, TermStreamerBuilder};
use super::{TermStreamer, TermStreamerBuilder, TermWithStateStreamerBuilder};
use crate::directory::{FileSlice, OwnedBytes};
use crate::postings::TermInfo;
use crate::termdict::TermOrdinal;
@@ -218,4 +218,15 @@ impl TermDictionary {
let stream_builder = self.fst_index.search(automaton);
TermStreamerBuilder::<A>::new(self, stream_builder)
}
/// Returns a search builder, to stream all of the terms
/// within the Automaton
pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A>
where
A: Automaton + 'a,
A::State: Clone,
{
let stream_builder = self.fst_index.search(automaton);
TermWithStateStreamerBuilder::<A>::new(self, stream_builder)
}
}

View File

@@ -40,11 +40,12 @@ use common::file_slice::FileSlice;
use common::BinarySerializable;
use tantivy_fst::Automaton;
use self::fst_termdict::TermWithStateStreamerBuilder;
use self::termdict::{
TermDictionary as InnerTermDict, TermDictionaryBuilder as InnerTermDictBuilder,
TermStreamerBuilder,
};
pub use self::termdict::{TermMerger, TermStreamer};
pub use self::termdict::{TermMerger, TermStreamer, TermWithStateStreamer};
use crate::postings::TermInfo;
#[derive(Debug, Eq, PartialEq)]
@@ -178,6 +179,16 @@ impl TermDictionary {
) -> FileSlice {
self.0.file_slice_for_range(key_range, limit)
}
/// Returns a search builder, to stream all of the terms
/// within the Automaton
pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A>
where
A: Automaton + 'a,
A::State: Clone,
{
self.0.search_with_state(automaton)
}
}
/// A TermDictionaryBuilder wrapping either an FST or a SSTable dictionary builder.