Use Levenshtein distance to score documents in fuzzy term queries

Neil Hansen
2024-01-31 16:11:16 -08:00
committed by Stu Hood
parent 794ff1ffc9
commit dee2dd3f21
9 changed files with 385 additions and 42 deletions
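In short: terms matched by the fuzzy query's Levenshtein automaton are no longer all scored a flat 1.0; each matched term is scored by its edit distance from the query term, so closer matches rank higher. A minimal, self-contained sketch of the rule (the real implementation in automaton_weight.rs below reads the distance out of the Levenshtein DFA state):

    // Sketch of the new scoring rule: closer matches score higher.
    // The constants match the assertions in tests/fuzzy_scoring.rs below.
    fn fuzzy_score(edit_distance: u8) -> f32 {
        1.0 / (1.0 + edit_distance as f32)
    }

    fn main() {
        assert_eq!(fuzzy_score(0), 1.0);        // exact match
        assert_eq!(fuzzy_score(1), 0.5);        // one edit away
        assert_eq!(fuzzy_score(2), 0.33333334); // two edits away
    }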

View File

@@ -199,7 +199,9 @@ fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> {
 /// deserialize Option<f64> from string or float
 pub(crate) fn deserialize_option_f64<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
-where D: Deserializer<'de> {
+where
+    D: Deserializer<'de>,
+{
     struct StringOrFloatVisitor;
     impl Visitor<'_> for StringOrFloatVisitor {
@@ -210,32 +212,44 @@ where D: Deserializer<'de> {
         }
         fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             parse_str_into_f64(value).map(Some)
         }
         fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(Some(value))
         }
         fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(Some(value as f64))
        }
         fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(Some(value as f64))
         }
         fn visit_none<E>(self) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(None)
         }
         fn visit_unit<E>(self) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(None)
         }
     }
@@ -245,7 +259,9 @@ where D: Deserializer<'de> {
 /// deserialize f64 from string or float
 pub(crate) fn deserialize_f64<'de, D>(deserializer: D) -> Result<f64, D::Error>
-where D: Deserializer<'de> {
+where
+    D: Deserializer<'de>,
+{
     struct StringOrFloatVisitor;
     impl Visitor<'_> for StringOrFloatVisitor {
@@ -256,22 +272,30 @@ where D: Deserializer<'de> {
         }
         fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             parse_str_into_f64(value)
         }
         fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(value)
         }
         fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(value as f64)
         }
         fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
-        where E: de::Error {
+        where
+            E: de::Error,
+        {
             Ok(value as f64)
         }
     }
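The hunks in this file are rustfmt-only: single-line where clauses are expanded to multi-line form, with no behavior change. For context, these visitors let a float be supplied either as a JSON number or as a string. A self-contained analog of that behavior, using an untagged enum instead of a hand-written Visitor (names hypothetical, not this crate's API):

    use serde::Deserialize;
    use serde_json::json;

    // Accept a float written either as a number or as a string.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum NumberOrString {
        Num(f64),
        Str(String),
    }

    fn to_f64(v: NumberOrString) -> Result<f64, std::num::ParseFloatError> {
        match v {
            NumberOrString::Num(f) => Ok(f),
            NumberOrString::Str(s) => s.parse(),
        }
    }

    fn main() {
        let a: NumberOrString = serde_json::from_value(json!(1.5)).unwrap();
        let b: NumberOrString = serde_json::from_value(json!("2.5")).unwrap();
        assert_eq!(to_f64(a).unwrap(), 1.5);
        assert_eq!(to_f64(b).unwrap(), 2.5);
    }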

View File

@@ -1,15 +1,18 @@
+use std::any::{Any, TypeId};
 use std::io;
 use std::sync::Arc;
-use common::BitSet;
 use tantivy_fst::Automaton;
 use super::phrase_prefix_query::prefix_end;
+use super::BufferedUnionScorer;
 use crate::index::SegmentReader;
+use crate::postings::TermInfo;
-use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
+use crate::query::fuzzy_query::DfaWrapper;
+use crate::query::score_combiner::SumCombiner;
+use crate::query::{ConstScorer, Explanation, Scorer, Weight};
 use crate::schema::{Field, IndexRecordOption};
-use crate::termdict::{TermDictionary, TermStreamer};
+use crate::termdict::{TermDictionary, TermWithStateStreamer};
 use crate::{DocId, Score, TantivyError};
 /// A weight struct for Fuzzy Term and Regex Queries
@@ -52,9 +55,9 @@ where
     fn automaton_stream<'a>(
         &'a self,
         term_dict: &'a TermDictionary,
-    ) -> io::Result<TermStreamer<'a, &'a A>> {
+    ) -> io::Result<TermWithStateStreamer<'a, &'a A>> {
         let automaton: &A = &self.automaton;
-        let mut term_stream_builder = term_dict.search(automaton);
+        let mut term_stream_builder = term_dict.search_with_state(automaton);
         if let Some(json_path_bytes) = &self.json_path_bytes {
             term_stream_builder = term_stream_builder.ge(json_path_bytes);
@@ -85,35 +88,27 @@ where
     A::State: Clone,
 {
     fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
-        let max_doc = reader.max_doc();
-        let mut doc_bitset = BitSet::with_max_value(max_doc);
         let inverted_index = reader.inverted_index(self.field)?;
         let term_dict = inverted_index.terms();
         let mut term_stream = self.automaton_stream(term_dict)?;
-        while term_stream.advance() {
-            let term_info = term_stream.value();
-            let mut block_segment_postings = inverted_index
-                .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
-            loop {
-                let docs = block_segment_postings.docs();
-                if docs.is_empty() {
-                    break;
-                }
-                for &doc in docs {
-                    doc_bitset.insert(doc);
-                }
-                block_segment_postings.advance();
-            }
-        }
+        let mut scorers = vec![];
+        while let Some((_term, term_info, state)) = term_stream.next() {
+            let score = automaton_score(self.automaton.as_ref(), state);
+            let segment_postings =
+                inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
+            let scorer = ConstScorer::new(segment_postings, boost * score);
+            scorers.push(scorer);
+        }
-        let doc_bitset = BitSetDocSet::from(doc_bitset);
-        let const_scorer = ConstScorer::new(doc_bitset, boost);
-        Ok(Box::new(const_scorer))
+        let scorer = BufferedUnionScorer::build(scorers, SumCombiner::default);
+        Ok(Box::new(scorer))
     }
     fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
         let mut scorer = self.scorer(reader, 1.0)?;
         if scorer.seek(doc) == doc {
-            Ok(Explanation::new("AutomatonScorer", 1.0))
+            Ok(Explanation::new("AutomatonScorer", scorer.score()))
         } else {
             Err(TantivyError::InvalidArgument(
                 "Document does not exist".to_string(),
@@ -122,6 +117,25 @@ where
     }
 }
+fn automaton_score<A>(automaton: &A, state: A::State) -> f32
+where
+    A: Automaton + Send + Sync + 'static,
+    A::State: Clone,
+{
+    if TypeId::of::<DfaWrapper>() == automaton.type_id() && TypeId::of::<u32>() == state.type_id() {
+        let dfa = automaton as *const A as *const DfaWrapper;
+        let dfa = unsafe { &*dfa };
+        let id = &state as *const A::State as *const u32;
+        let id = unsafe { *id };
+        let dist = dfa.0.distance(id).to_u8() as f32;
+        1.0 / (1.0 + dist)
+    } else {
+        1.0
+    }
+}
 #[cfg(test)]
 mod tests {
     use tantivy_fst::Automaton;
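The TypeId guards in automaton_score are what make its raw pointer casts sound: the automaton is only reinterpreted as a DfaWrapper (and the state as a u32) after checking that the erased types match exactly. This is a manual version of Any-style downcasting, needed because A is a generic parameter rather than a dyn Any trait object. A self-contained sketch of the pattern (names hypothetical):

    use std::any::TypeId;

    // Manual downcast between generic types: the TypeId equality check proves
    // T and U are the same concrete type, which makes the cast below sound.
    fn manual_downcast<T: 'static, U: 'static>(value: &T) -> Option<&U> {
        if TypeId::of::<T>() == TypeId::of::<U>() {
            // Safety: T and U were just proven to be the same type.
            Some(unsafe { &*(value as *const T as *const U) })
        } else {
            None
        }
    }

    fn main() {
        let x: u32 = 7;
        assert_eq!(manual_downcast::<u32, u32>(&x), Some(&7));
        assert!(manual_downcast::<u32, i64>(&x).is_none());
    }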

View File

@@ -299,7 +299,7 @@ mod test {
            searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
        assert_eq!(top_docs.len(), 1, "Expected only 1 document");
        let (score, _) = top_docs[0];
-        assert_nearly_equals!(1.0, score);
+        assert_nearly_equals!(0.5, score);
     }
     // fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')
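(The asserted score drops from 1.0 to 0.5 because the matched term is one Levenshtein edit away from the query term, and 1 / (1 + 1) = 0.5 under the new scoring rule.)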

View File

@@ -4,6 +4,7 @@ mod term_weight;
pub use self::term_query::TermQuery;
pub use self::term_scorer::TermScorer;
#[cfg(test)]
mod tests {

View File

@@ -24,5 +24,7 @@ mod term_info_store;
 mod termdict;
 pub use self::merger::TermMerger;
-pub use self::streamer::{TermStreamer, TermStreamerBuilder};
+pub use self::streamer::{
+    TermStreamer, TermStreamerBuilder, TermWithStateStreamer, TermWithStateStreamerBuilder,
+};
 pub use self::termdict::{TermDictionary, TermDictionaryBuilder};

View File

@@ -1,7 +1,7 @@
 use std::io;
 use tantivy_fst::automaton::AlwaysMatch;
-use tantivy_fst::map::{Stream, StreamBuilder};
+use tantivy_fst::map::{Stream, StreamBuilder, StreamWithState};
 use tantivy_fst::{Automaton, IntoStreamer, Streamer};
 use super::TermDictionary;
@@ -145,3 +145,152 @@ where A: Automaton
         }
     }
 }
+
+/// `TermWithStateStreamerBuilder` is a helper object used to define
+/// a range of terms that should be streamed.
+pub struct TermWithStateStreamerBuilder<'a, A = AlwaysMatch>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    fst_map: &'a TermDictionary,
+    stream_builder: StreamBuilder<'a, A>,
+}
+
+impl<'a, A> TermWithStateStreamerBuilder<'a, A>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    pub(crate) fn new(fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>) -> Self {
+        TermWithStateStreamerBuilder {
+            fst_map,
+            stream_builder,
+        }
+    }
+
+    /// Limit the range to terms greater than or equal to the bound
+    pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.ge(bound);
+        self
+    }
+
+    /// Limit the range to terms strictly greater than the bound
+    pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.gt(bound);
+        self
+    }
+
+    /// Limit the range to terms less than or equal to the bound
+    pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.le(bound);
+        self
+    }
+
+    /// Limit the range to terms strictly less than the bound
+    pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
+        self.stream_builder = self.stream_builder.lt(bound);
+        self
+    }
+
+    /// Iterate over the range backwards.
+    pub fn backward(mut self) -> Self {
+        self.stream_builder = self.stream_builder.backward();
+        self
+    }
+
+    /// Creates the stream corresponding to the range
+    /// of terms defined using the `TermWithStateStreamerBuilder`.
+    pub fn into_stream(self) -> io::Result<TermWithStateStreamer<'a, A>> {
+        Ok(TermWithStateStreamer {
+            fst_map: self.fst_map,
+            stream: self.stream_builder.with_state().into_stream(),
+            term_ord: 0u64,
+            current_key: Vec::with_capacity(100),
+            current_value: TermInfo::default(),
+            current_state: None,
+        })
+    }
+}
+
+/// `TermWithStateStreamer` acts as a cursor over a range of terms of a segment.
+/// Terms are guaranteed to be sorted.
+pub struct TermWithStateStreamer<'a, A = AlwaysMatch>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    fst_map: &'a TermDictionary,
+    stream: StreamWithState<'a, A>,
+    term_ord: TermOrdinal,
+    current_key: Vec<u8>,
+    current_value: TermInfo,
+    current_state: Option<A::State>,
+}
+
+impl<'a, A> TermWithStateStreamer<'a, A>
+where
+    A: Automaton,
+    A::State: Clone,
+{
+    /// Advances the stream to the next item.
+    /// Before the first call to `.advance()`, the stream
+    /// is in an uninitialized state.
+    pub fn advance(&mut self) -> bool {
+        if let Some((term, term_ord, state)) = self.stream.next() {
+            self.current_key.clear();
+            self.current_key.extend_from_slice(term);
+            self.term_ord = term_ord;
+            self.current_value = self.fst_map.term_info_from_ord(term_ord);
+            self.current_state = Some(state);
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Returns the `TermOrdinal` of the current term.
+    ///
+    /// May panic if `.advance()` has never been called before.
+    pub fn term_ord(&self) -> TermOrdinal {
+        self.term_ord
+    }
+
+    /// Accesses the current key.
+    ///
+    /// `.key()` should return the key that was returned
+    /// by the `.next()` method.
+    ///
+    /// If the end of the stream has been reached, and `.next()`
+    /// has been called and returned `None`, `.key()` remains
+    /// the value of the last key encountered.
+    ///
+    /// Before any call to `.next()`, `.key()` returns an empty array.
+    pub fn key(&self) -> &[u8] {
+        &self.current_key
+    }
+
+    /// Accesses the current value.
+    ///
+    /// Calling `.value()` after the end of the stream will return the
+    /// last `.value()` encountered.
+    ///
+    /// Calling `.value()` before the first call to `.advance()` returns
+    /// `TermInfo::default()`.
+    pub fn value(&self) -> &TermInfo {
+        &self.current_value
+    }
+
+    /// Returns the next `(key, value, state)` triplet.
+    pub fn next(&mut self) -> Option<(&[u8], &TermInfo, A::State)> {
+        if self.advance() {
+            let state = self.current_state.take().unwrap(); // always Some(_) after advance
+            Some((self.key(), self.value(), state))
+        } else {
+            None
+        }
+    }
+}

View File

@@ -7,7 +7,7 @@ use tantivy_fst::raw::Fst;
 use tantivy_fst::Automaton;
 use super::term_info_store::{TermInfoStore, TermInfoStoreWriter};
-use super::{TermStreamer, TermStreamerBuilder};
+use super::{TermStreamer, TermStreamerBuilder, TermWithStateStreamerBuilder};
 use crate::directory::{FileSlice, OwnedBytes};
 use crate::postings::TermInfo;
 use crate::termdict::TermOrdinal;
@@ -218,4 +218,15 @@ impl TermDictionary {
         let stream_builder = self.fst_index.search(automaton);
         TermStreamerBuilder::<A>::new(self, stream_builder)
     }
+    /// Returns a search builder to stream all of the terms matched by
+    /// the automaton, along with the automaton state reached at each term.
+    pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A>
+    where
+        A: Automaton + 'a,
+        A::State: Clone,
+    {
+        let stream_builder = self.fst_index.search(automaton);
+        TermWithStateStreamerBuilder::<A>::new(self, stream_builder)
+    }
 }

View File

@@ -40,11 +40,12 @@ use common::file_slice::FileSlice;
 use common::BinarySerializable;
 use tantivy_fst::Automaton;
+use self::fst_termdict::TermWithStateStreamerBuilder;
 use self::termdict::{
     TermDictionary as InnerTermDict, TermDictionaryBuilder as InnerTermDictBuilder,
     TermStreamerBuilder,
 };
-pub use self::termdict::{TermMerger, TermStreamer};
+pub use self::termdict::{TermMerger, TermStreamer, TermWithStateStreamer};
 use crate::postings::TermInfo;
 #[derive(Debug, Eq, PartialEq)]
@@ -178,6 +179,16 @@ impl TermDictionary {
     ) -> FileSlice {
         self.0.file_slice_for_range(key_range, limit)
     }
+    /// Returns a search builder to stream all of the terms matched by
+    /// the automaton, along with the automaton state reached at each term.
+    pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A>
+    where
+        A: Automaton + 'a,
+        A::State: Clone,
+    {
+        self.0.search_with_state(automaton)
+    }
 }
 /// A TermDictionaryBuilder wrapping either an FST or a SSTable dictionary builder.
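A hedged usage sketch of the new entry point (the function and its setup are hypothetical; it mirrors how AutomatonWeight::scorer drives the streamer, and assumes tantivy's public termdict module):

    use tantivy::termdict::TermDictionary;
    use tantivy_fst::Automaton;

    // Collect the raw bytes of every term matched by `automaton`. The `_state`
    // component is the automaton state reached at the end of each matched term,
    // which is what the fuzzy scorer turns into an edit distance.
    fn matched_terms<'a, A>(
        dict: &'a TermDictionary,
        automaton: A,
    ) -> std::io::Result<Vec<Vec<u8>>>
    where
        A: Automaton + 'a,
        A::State: Clone,
    {
        let mut stream = dict.search_with_state(automaton).into_stream()?;
        let mut terms = Vec::new();
        while let Some((term_bytes, _term_info, _state)) = stream.next() {
            terms.push(term_bytes.to_vec());
        }
        Ok(terms)
    }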

tests/fuzzy_scoring.rs (new file, 131 lines)
View File

@@ -0,0 +1,131 @@
#[cfg(test)]
mod test {
    use maplit::hashmap;
    use tantivy::collector::TopDocs;
    use tantivy::query::FuzzyTermQuery;
    use tantivy::schema::{Schema, Value, STORED, TEXT};
    use tantivy::{doc, Index, TantivyDocument, Term};

    #[test]
    pub fn test_fuzzy_term() {
        // Define a list of documents to be indexed. Each entry represents a text
        // that will be associated with the field "country" in the index.
        let docs = vec![
            "WENN ROT WIE RUBIN",
            "WENN ROT WIE ROBIN",
            "WHEN RED LIKE ROBIN",
            "WENN RED AS ROBIN",
            "WHEN ROYAL BLUE ROBIN",
            "IF RED LIKE RUBEN",
            "WHEN GREEN LIKE ROBIN",
            "WENN ROSE LIKE ROBIN",
            "IF PINK LIKE ROBIN",
            "WENN ROT WIE RABIN",
            "WENN BLU WIE ROBIN",
            "WHEN YELLOW LIKE RABBIT",
            "IF BLUE LIKE ROBIN",
            "WHEN ORANGE LIKE RIBBON",
            "WENN VIOLET WIE RUBIX",
            "WHEN INDIGO LIKE ROBBIE",
            "IF TEAL LIKE RUBY",
            "WHEN GOLD LIKE ROB",
            "WENN SILVER WIE ROBY",
            "IF BRONZE LIKE ROBE",
        ];

        // Define the expected scores when queried with "robin" and a fuzziness of 2.
        // This map associates each document text with its expected score.
        let expected_scores = hashmap! {
            "WHEN GREEN LIKE ROBIN" => 1.0,
            "WENN RED AS ROBIN" => 1.0,
            "WHEN RED LIKE ROBIN" => 1.0,
            "WENN ROSE LIKE ROBIN" => 1.0,
            "WENN ROT WIE ROBIN" => 1.0,
            "WHEN ROYAL BLUE ROBIN" => 1.0,
            "IF PINK LIKE ROBIN" => 1.0,
            "IF BLUE LIKE ROBIN" => 1.0,
            "WENN BLU WIE ROBIN" => 1.0,
            "WENN ROT WIE RUBIN" => 0.5,
            "WENN ROT WIE RABIN" => 0.5,
            "IF RED LIKE RUBEN" => 0.33333334,
            "WENN VIOLET WIE RUBIX" => 0.33333334,
            "IF BRONZE LIKE ROBE" => 0.33333334,
            "WENN SILVER WIE ROBY" => 0.33333334,
            "WHEN GOLD LIKE ROB" => 0.33333334,
            "WHEN INDIGO LIKE ROBBIE" => 0.33333334,
        };

        // Build a schema for the index.
        // The schema determines how documents are indexed and searched.
        let mut schema_builder = Schema::builder();
        // Add a text field named "country" to the schema. This field will store the text and
        // is indexed in a way that makes it searchable.
        let country_field = schema_builder.add_text_field("country", TEXT | STORED);
        // Build the schema based on the provided definitions.
        let schema = schema_builder.build();

        // Create a new index in RAM based on the defined schema.
        let index = Index::create_in_ram(schema);
        {
            // Create an index writer with one thread and a 15 MB memory budget.
            // The writer allows us to add documents to the index.
            let mut index_writer = index.writer_with_num_threads(1, 15_000_000).unwrap();
            // Index each document in the docs list.
            for &doc in &docs {
                index_writer
                    .add_document(doc!(country_field => doc))
                    .unwrap();
            }
            // Commit changes to the index. This finalizes the addition of documents.
            index_writer.commit().unwrap();
        }

        // Create a reader for the index to search the indexed documents.
        let reader = index.reader().unwrap();
        let searcher = reader.searcher();
        {
            // Define a term based on the field "country" and the text "robin".
            let term = Term::from_field_text(country_field, "robin");
            // Create a fuzzy query for "robin" with a fuzziness of 2; the `true`
            // argument makes a transposition count as a single edit.
            let fuzzy_query = FuzzyTermQuery::new(term, 2, true);
            // Search the index with the fuzzy query and retrieve up to 100 top documents.
            let top_docs = searcher
                .search(&fuzzy_query, &TopDocs::with_limit(100).order_by_score())
                .unwrap();

            // Print out the scores and documents retrieved by the search.
            for (score, adr) in &top_docs {
                let doc: TantivyDocument = searcher.doc(*adr).expect("document");
                println!(
                    "{score}, {:?}",
                    doc.field_values().next().unwrap().1.as_str()
                );
            }

            // Assert that 17 documents match the fuzzy query criteria.
            // Nothing with an edit distance larger than 2 should be returned,
            // leaving us with 17 expected results.
            assert_eq!(top_docs.len(), 17, "Expected 17 documents");

            // Check the scores of the returned documents against the expected scores.
            for (score, adr) in &top_docs {
                let doc: TantivyDocument = searcher.doc(*adr).expect("document");
                let doc_text = doc.field_values().next().unwrap().1.as_str().unwrap();
                // Ensure the retrieved score for each document is close to the expected score.
                assert!(
                    (score - expected_scores[doc_text]).abs() < f32::EPSILON,
                    "Unexpected score for document {}. Expected: {}, Actual: {}",
                    doc_text,
                    expected_scores[doc_text],
                    score
                );
            }
        }
    }
}
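The three score tiers in expected_scores follow directly from score = 1 / (1 + distance): distances 0, 1, and 2 give 1.0, 0.5, and roughly 0.333. A standalone sketch that recomputes a few entries with plain Levenshtein distance (the query enables transposition_cost_one, but none of these pairs involve transpositions):

    // Textbook Levenshtein distance (substitutions, insertions, deletions),
    // enough to spot-check the expected_scores table above.
    fn levenshtein(a: &str, b: &str) -> usize {
        let a: Vec<char> = a.chars().collect();
        let b: Vec<char> = b.chars().collect();
        let mut prev: Vec<usize> = (0..=b.len()).collect();
        for (i, &ca) in a.iter().enumerate() {
            let mut cur = vec![i + 1];
            for (j, &cb) in b.iter().enumerate() {
                let cost = if ca == cb { 0 } else { 1 };
                cur.push((prev[j] + cost).min(prev[j + 1] + 1).min(cur[j] + 1));
            }
            prev = cur;
        }
        prev[b.len()]
    }

    fn score(query: &str, term: &str) -> f32 {
        1.0 / (1.0 + levenshtein(query, term) as f32)
    }

    fn main() {
        // The default tokenizer lowercases terms, so compare lowercase forms.
        assert_eq!(score("robin", "robin"), 1.0);        // "WENN ROT WIE ROBIN"
        assert_eq!(score("robin", "rubin"), 0.5);        // "WENN ROT WIE RUBIN"
        assert_eq!(score("robin", "ruben"), 0.33333334); // "IF RED LIKE RUBEN"
        assert_eq!(levenshtein("robin", "ruby"), 3);     // distance > 2: filtered out
    }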