From dee2dd3f2129faf62bb39796674ee3d7630470bc Mon Sep 17 00:00:00 2001 From: Neil Hansen Date: Wed, 31 Jan 2024 16:11:16 -0800 Subject: [PATCH] Use Levenshtein distance to score documents in fuzzy term queries --- src/aggregation/mod.rs | 48 ++++++-- src/query/automaton_weight.rs | 64 ++++++----- src/query/fuzzy_query.rs | 2 +- src/query/term_query/mod.rs | 1 + src/termdict/fst_termdict/mod.rs | 4 +- src/termdict/fst_termdict/streamer.rs | 151 +++++++++++++++++++++++++- src/termdict/fst_termdict/termdict.rs | 13 ++- src/termdict/mod.rs | 13 ++- tests/fuzzy_scoring.rs | 131 ++++++++++++++++++++++ 9 files changed, 385 insertions(+), 42 deletions(-) create mode 100644 tests/fuzzy_scoring.rs diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index ddf60ea4c..9c65bcee6 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -199,7 +199,9 @@ fn parse_str_into_f64(value: &str) -> Result { /// deserialize Option from string or float pub(crate) fn deserialize_option_f64<'de, D>(deserializer: D) -> Result, D::Error> -where D: Deserializer<'de> { +where + D: Deserializer<'de>, +{ struct StringOrFloatVisitor; impl Visitor<'_> for StringOrFloatVisitor { @@ -210,32 +212,44 @@ where D: Deserializer<'de> { } fn visit_str(self, value: &str) -> Result - where E: de::Error { + where + E: de::Error, + { parse_str_into_f64(value).map(Some) } fn visit_f64(self, value: f64) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(Some(value)) } fn visit_i64(self, value: i64) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(Some(value as f64)) } fn visit_u64(self, value: u64) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(Some(value as f64)) } fn visit_none(self) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(None) } fn visit_unit(self) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(None) } } @@ -245,7 +259,9 @@ where D: Deserializer<'de> { /// deserialize f64 from string or float pub(crate) fn deserialize_f64<'de, D>(deserializer: D) -> Result -where D: Deserializer<'de> { +where + D: Deserializer<'de>, +{ struct StringOrFloatVisitor; impl Visitor<'_> for StringOrFloatVisitor { @@ -256,22 +272,30 @@ where D: Deserializer<'de> { } fn visit_str(self, value: &str) -> Result - where E: de::Error { + where + E: de::Error, + { parse_str_into_f64(value) } fn visit_f64(self, value: f64) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(value) } fn visit_i64(self, value: i64) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(value as f64) } fn visit_u64(self, value: u64) -> Result - where E: de::Error { + where + E: de::Error, + { Ok(value as f64) } } diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index 5f1053fb6..4f1f196b0 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -1,15 +1,18 @@ +use std::any::{Any, TypeId}; use std::io; use std::sync::Arc; -use common::BitSet; use tantivy_fst::Automaton; use super::phrase_prefix_query::prefix_end; +use super::BufferedUnionScorer; use crate::index::SegmentReader; use crate::postings::TermInfo; -use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight}; +use crate::query::fuzzy_query::DfaWrapper; +use crate::query::score_combiner::SumCombiner; +use crate::query::{ConstScorer, Explanation, Scorer, Weight}; use crate::schema::{Field, IndexRecordOption}; -use crate::termdict::{TermDictionary, TermStreamer}; +use crate::termdict::{TermDictionary, TermWithStateStreamer}; use crate::{DocId, Score, TantivyError}; /// A weight struct for Fuzzy Term and Regex Queries @@ -52,9 +55,9 @@ where fn automaton_stream<'a>( &'a self, term_dict: &'a TermDictionary, - ) -> io::Result> { + ) -> io::Result> { let automaton: &A = &self.automaton; - let mut term_stream_builder = term_dict.search(automaton); + let mut term_stream_builder = term_dict.search_with_state(automaton); if let Some(json_path_bytes) = &self.json_path_bytes { term_stream_builder = term_stream_builder.ge(json_path_bytes); @@ -85,35 +88,27 @@ where A::State: Clone, { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result> { - let max_doc = reader.max_doc(); - let mut doc_bitset = BitSet::with_max_value(max_doc); let inverted_index = reader.inverted_index(self.field)?; let term_dict = inverted_index.terms(); let mut term_stream = self.automaton_stream(term_dict)?; - while term_stream.advance() { - let term_info = term_stream.value(); - let mut block_segment_postings = inverted_index - .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?; - loop { - let docs = block_segment_postings.docs(); - if docs.is_empty() { - break; - } - for &doc in docs { - doc_bitset.insert(doc); - } - block_segment_postings.advance(); - } + + let mut scorers = vec![]; + while let Some((_term, term_info, state)) = term_stream.next() { + let score = automaton_score(self.automaton.as_ref(), state); + let segment_postings = + inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?; + let scorer = ConstScorer::new(segment_postings, boost * score); + scorers.push(scorer); } - let doc_bitset = BitSetDocSet::from(doc_bitset); - let const_scorer = ConstScorer::new(doc_bitset, boost); - Ok(Box::new(const_scorer)) + + let scorer = BufferedUnionScorer::build(scorers, SumCombiner::default); + Ok(Box::new(scorer)) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { let mut scorer = self.scorer(reader, 1.0)?; if scorer.seek(doc) == doc { - Ok(Explanation::new("AutomatonScorer", 1.0)) + Ok(Explanation::new("AutomatonScorer", scorer.score())) } else { Err(TantivyError::InvalidArgument( "Document does not exist".to_string(), @@ -122,6 +117,25 @@ where } } +fn automaton_score(automaton: &A, state: A::State) -> f32 +where + A: Automaton + Send + Sync + 'static, + A::State: Clone, +{ + if TypeId::of::() == automaton.type_id() && TypeId::of::() == state.type_id() { + let dfa = automaton as *const A as *const DfaWrapper; + let dfa = unsafe { &*dfa }; + + let id = &state as *const A::State as *const u32; + let id = unsafe { *id }; + + let dist = dfa.0.distance(id).to_u8() as f32; + 1.0 / (1.0 + dist) + } else { + 1.0 + } +} + #[cfg(test)] mod tests { use tantivy_fst::Automaton; diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index a0634b96b..d81a4dc42 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -299,7 +299,7 @@ mod test { searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?; assert_eq!(top_docs.len(), 1, "Expected only 1 document"); let (score, _) = top_docs[0]; - assert_nearly_equals!(1.0, score); + assert_nearly_equals!(0.5, score); } // fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n') diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index 0811725be..f06be093b 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -4,6 +4,7 @@ mod term_weight; pub use self::term_query::TermQuery; pub use self::term_scorer::TermScorer; + #[cfg(test)] mod tests { diff --git a/src/termdict/fst_termdict/mod.rs b/src/termdict/fst_termdict/mod.rs index 4201df6a4..673b569c0 100644 --- a/src/termdict/fst_termdict/mod.rs +++ b/src/termdict/fst_termdict/mod.rs @@ -24,5 +24,7 @@ mod term_info_store; mod termdict; pub use self::merger::TermMerger; -pub use self::streamer::{TermStreamer, TermStreamerBuilder}; +pub use self::streamer::{ + TermStreamer, TermStreamerBuilder, TermWithStateStreamer, TermWithStateStreamerBuilder, +}; pub use self::termdict::{TermDictionary, TermDictionaryBuilder}; diff --git a/src/termdict/fst_termdict/streamer.rs b/src/termdict/fst_termdict/streamer.rs index d2e31421f..ea68646ed 100644 --- a/src/termdict/fst_termdict/streamer.rs +++ b/src/termdict/fst_termdict/streamer.rs @@ -1,7 +1,7 @@ use std::io; use tantivy_fst::automaton::AlwaysMatch; -use tantivy_fst::map::{Stream, StreamBuilder}; +use tantivy_fst::map::{Stream, StreamBuilder, StreamWithState}; use tantivy_fst::{Automaton, IntoStreamer, Streamer}; use super::TermDictionary; @@ -145,3 +145,152 @@ where A: Automaton } } } + +/// `TermWithStateStreamerBuilder` is a helper object used to define +/// a range of terms that should be streamed. +pub struct TermWithStateStreamerBuilder<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + fst_map: &'a TermDictionary, + stream_builder: StreamBuilder<'a, A>, +} + +impl<'a, A> TermWithStateStreamerBuilder<'a, A> +where + A: Automaton, + A::State: Clone, +{ + pub(crate) fn new(fst_map: &'a TermDictionary, stream_builder: StreamBuilder<'a, A>) -> Self { + TermWithStateStreamerBuilder { + fst_map, + stream_builder, + } + } + + /// Limit the range to terms greater or equal to the bound + pub fn ge>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.ge(bound); + self + } + + /// Limit the range to terms strictly greater than the bound + pub fn gt>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.gt(bound); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn le>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.le(bound); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn lt>(mut self, bound: T) -> Self { + self.stream_builder = self.stream_builder.lt(bound); + self + } + + /// Iterate over the range backwards. + pub fn backward(mut self) -> Self { + self.stream_builder = self.stream_builder.backward(); + self + } + + /// Creates the stream corresponding to the range + /// of terms defined using the `TermWithStateStreamerBuilder`. + pub fn into_stream(self) -> io::Result> { + Ok(TermWithStateStreamer { + fst_map: self.fst_map, + stream: self.stream_builder.with_state().into_stream(), + term_ord: 0u64, + current_key: Vec::with_capacity(100), + current_value: TermInfo::default(), + current_state: None, + }) + } +} + +/// `TermWithStateStreamer` acts as a cursor over a range of terms of a segment. +/// Terms are guaranteed to be sorted. +pub struct TermWithStateStreamer<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + fst_map: &'a TermDictionary, + stream: StreamWithState<'a, A>, + term_ord: TermOrdinal, + current_key: Vec, + current_value: TermInfo, + current_state: Option, +} + +impl<'a, A> TermWithStateStreamer<'a, A> +where + A: Automaton, + A::State: Clone, +{ + /// Advance position the stream on the next item. + /// Before the first call to `.advance()`, the stream + /// is an unitialized state. + pub fn advance(&mut self) -> bool { + if let Some((term, term_ord, state)) = self.stream.next() { + self.current_key.clear(); + self.current_key.extend_from_slice(term); + self.term_ord = term_ord; + self.current_value = self.fst_map.term_info_from_ord(term_ord); + self.current_state = Some(state); + true + } else { + false + } + } + + /// Returns the `TermOrdinal` of the given term. + /// + /// May panic if the called as `.advance()` as never + /// been called before. + pub fn term_ord(&self) -> TermOrdinal { + self.term_ord + } + + /// Accesses the current key. + /// + /// `.key()` should return the key that was returned + /// by the `.next()` method. + /// + /// If the end of the stream as been reached, and `.next()` + /// has been called and returned `None`, `.key()` remains + /// the value of the last key encountered. + /// + /// Before any call to `.next()`, `.key()` returns an empty array. + pub fn key(&self) -> &[u8] { + &self.current_key + } + + /// Accesses the current value. + /// + /// Calling `.value()` after the end of the stream will return the + /// last `.value()` encountered. + /// + /// # Panics + /// + /// Calling `.value()` before the first call to `.advance()` returns + /// `V::default()`. + pub fn value(&self) -> &TermInfo { + &self.current_value + } + + /// Return the next `(key, value, state)` triplet. + pub fn next(&mut self) -> Option<(&[u8], &TermInfo, A::State)> { + if self.advance() { + let state = self.current_state.take().unwrap(); // always Some(_) after advance + Some((self.key(), self.value(), state)) + } else { + None + } + } +} diff --git a/src/termdict/fst_termdict/termdict.rs b/src/termdict/fst_termdict/termdict.rs index 3f2a5a5f3..d0fbe5f1f 100644 --- a/src/termdict/fst_termdict/termdict.rs +++ b/src/termdict/fst_termdict/termdict.rs @@ -7,7 +7,7 @@ use tantivy_fst::raw::Fst; use tantivy_fst::Automaton; use super::term_info_store::{TermInfoStore, TermInfoStoreWriter}; -use super::{TermStreamer, TermStreamerBuilder}; +use super::{TermStreamer, TermStreamerBuilder, TermWithStateStreamerBuilder}; use crate::directory::{FileSlice, OwnedBytes}; use crate::postings::TermInfo; use crate::termdict::TermOrdinal; @@ -218,4 +218,15 @@ impl TermDictionary { let stream_builder = self.fst_index.search(automaton); TermStreamerBuilder::::new(self, stream_builder) } + + /// Returns a search builder, to stream all of the terms + /// within the Automaton + pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A> + where + A: Automaton + 'a, + A::State: Clone, + { + let stream_builder = self.fst_index.search(automaton); + TermWithStateStreamerBuilder::::new(self, stream_builder) + } } diff --git a/src/termdict/mod.rs b/src/termdict/mod.rs index 4000b08d4..a6dde7162 100644 --- a/src/termdict/mod.rs +++ b/src/termdict/mod.rs @@ -40,11 +40,12 @@ use common::file_slice::FileSlice; use common::BinarySerializable; use tantivy_fst::Automaton; +use self::fst_termdict::TermWithStateStreamerBuilder; use self::termdict::{ TermDictionary as InnerTermDict, TermDictionaryBuilder as InnerTermDictBuilder, TermStreamerBuilder, }; -pub use self::termdict::{TermMerger, TermStreamer}; +pub use self::termdict::{TermMerger, TermStreamer, TermWithStateStreamer}; use crate::postings::TermInfo; #[derive(Debug, Eq, PartialEq)] @@ -178,6 +179,16 @@ impl TermDictionary { ) -> FileSlice { self.0.file_slice_for_range(key_range, limit) } + + /// Returns a search builder, to stream all of the terms + /// within the Automaton + pub fn search_with_state<'a, A>(&'a self, automaton: A) -> TermWithStateStreamerBuilder<'a, A> + where + A: Automaton + 'a, + A::State: Clone, + { + self.0.search_with_state(automaton) + } } /// A TermDictionaryBuilder wrapping either an FST or a SSTable dictionary builder. diff --git a/tests/fuzzy_scoring.rs b/tests/fuzzy_scoring.rs new file mode 100644 index 000000000..4f594f090 --- /dev/null +++ b/tests/fuzzy_scoring.rs @@ -0,0 +1,131 @@ +#[cfg(test)] +mod test { + use maplit::hashmap; + use tantivy::collector::TopDocs; + use tantivy::query::FuzzyTermQuery; + use tantivy::schema::{Schema, Value, STORED, TEXT}; + use tantivy::{doc, Index, TantivyDocument, Term}; + + #[test] + pub fn test_fuzzy_term() { + // Define a list of documents to be indexed. Each entry represents a text + // that will be associated with the field "country" in the index. + let docs = vec![ + "WENN ROT WIE RUBIN", + "WENN ROT WIE ROBIN", + "WHEN RED LIKE ROBIN", + "WENN RED AS ROBIN", + "WHEN ROYAL BLUE ROBIN", + "IF RED LIKE RUBEN", + "WHEN GREEN LIKE ROBIN", + "WENN ROSE LIKE ROBIN", + "IF PINK LIKE ROBIN", + "WENN ROT WIE RABIN", + "WENN BLU WIE ROBIN", + "WHEN YELLOW LIKE RABBIT", + "IF BLUE LIKE ROBIN", + "WHEN ORANGE LIKE RIBBON", + "WENN VIOLET WIE RUBIX", + "WHEN INDIGO LIKE ROBBIE", + "IF TEAL LIKE RUBY", + "WHEN GOLD LIKE ROB", + "WENN SILVER WIE ROBY", + "IF BRONZE LIKE ROBE", + ]; + + // Define the expected scores when queried with "robin" and a fuzziness of 2. + // This map associates each document text with its expected score. + let expected_scores = hashmap! { + "WHEN GREEN LIKE ROBIN" => 1.0, + "WENN RED AS ROBIN" => 1.0, + "WHEN RED LIKE ROBIN" => 1.0, + "WENN ROSE LIKE ROBIN" => 1.0, + "WENN ROT WIE ROBIN" => 1.0, + "WHEN ROYAL BLUE ROBIN" => 1.0, + "IF PINK LIKE ROBIN" => 1.0, + "IF BLUE LIKE ROBIN" => 1.0, + "WENN BLU WIE ROBIN" => 1.0, + "WENN ROT WIE RUBIN" => 0.5, + "WENN ROT WIE RABIN" => 0.5, + "IF RED LIKE RUBEN" => 0.33333334, + "WENN VIOLET WIE RUBIX" => 0.33333334, + "IF BRONZE LIKE ROBE" => 0.33333334, + "WENN SILVER WIE ROBY" => 0.33333334, + "WHEN GOLD LIKE ROB" => 0.33333334, + "WHEN INDIGO LIKE ROBBIE" => 0.33333334, + }; + + // Build a schema for the index. + // The schema determines how documents are indexed and searched. + let mut schema_builder = Schema::builder(); + + // Add a text field named "country" to the schema. This field will store the text and + // is indexed in a way that makes it searchable. + let country_field = schema_builder.add_text_field("country", TEXT | STORED); + // Build the schema based on the provided definitions. + let schema = schema_builder.build(); + // Create a new index in RAM based on the defined schema. + let index = Index::create_in_ram(schema); + { + // Create an index writer with one thread and a certain memory limit. + // The writer allows us to add documents to the index. + let mut index_writer = index.writer_with_num_threads(1, 15_000_000).unwrap(); + + // Index each document in the docs list. + for &doc in &docs { + index_writer + .add_document(doc!(country_field => doc)) + .unwrap(); + } + + // Commit changes to the index. This finalizes the addition of documents. + index_writer.commit().unwrap(); + } + + // Create a reader for the index to search the indexed documents. + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + + { + // Define a term based on the field "country" and the text "robin". + let term = Term::from_field_text(country_field, "robin"); + + // Create a fuzzy query for "robin", a fuzziness of 2, and a prefix length of 0. + let fuzzy_query = FuzzyTermQuery::new(term, 2, true); + + // Search the index with the fuzzy query and retrieve up to 100 top documents. + let top_docs = searcher + .search(&fuzzy_query, &TopDocs::with_limit(100).order_by_score()) + .unwrap(); + + // Print out the scores and documents retrieved by the search. + for (score, adr) in &top_docs { + let doc: TantivyDocument = searcher.doc(*adr).expect("document"); + println!( + "{score}, {:?}", + doc.field_values().next().unwrap().1.as_str() + ); + } + + // Assert that 17 documents match the fuzzy query criteria. + // We don't expect anything that has a larger fuzziness than 2 + // to be returned in the query, leaving us with 17 expected results. + assert_eq!(top_docs.len(), 17, "Expected 17 documents"); + + // Check the scores of the returned documents against the expected scores. + for (score, adr) in &top_docs { + let doc: TantivyDocument = searcher.doc(*adr).expect("document"); + let doc_text = doc.field_values().next().unwrap().1.as_str().unwrap(); + + // Ensure the retrieved score for each document is close to the expected score. + assert!( + (score - expected_scores[doc_text]).abs() < f32::EPSILON, + "Unexpected score for document {}. Expected: {}, Actual: {}", + doc_text, + expected_scores[doc_text], + score + ); + } + } + } +}