Use Levenshtein distance to score documents in fuzzy term queries

This commit is contained in:
Neil Hansen
2024-01-31 16:11:16 -08:00
committed by Stu Hood
parent 794ff1ffc9
commit dee2dd3f21
9 changed files with 385 additions and 42 deletions

View File

@@ -1,15 +1,18 @@
use std::any::{Any, TypeId};
use std::io;
use std::sync::Arc;
use common::BitSet;
use tantivy_fst::Automaton;
use super::phrase_prefix_query::prefix_end;
use super::BufferedUnionScorer;
use crate::index::SegmentReader;
use crate::postings::TermInfo;
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
use crate::query::fuzzy_query::DfaWrapper;
use crate::query::score_combiner::SumCombiner;
use crate::query::{ConstScorer, Explanation, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::termdict::{TermDictionary, TermWithStateStreamer};
use crate::{DocId, Score, TantivyError};
/// A weight struct for Fuzzy Term and Regex Queries
@@ -52,9 +55,9 @@ where
fn automaton_stream<'a>(
&'a self,
term_dict: &'a TermDictionary,
) -> io::Result<TermStreamer<'a, &'a A>> {
) -> io::Result<TermWithStateStreamer<'a, &'a A>> {
let automaton: &A = &self.automaton;
let mut term_stream_builder = term_dict.search(automaton);
let mut term_stream_builder = term_dict.search_with_state(automaton);
if let Some(json_path_bytes) = &self.json_path_bytes {
term_stream_builder = term_stream_builder.ge(json_path_bytes);
@@ -85,35 +88,27 @@ where
A::State: Clone,
{
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field)?;
let term_dict = inverted_index.terms();
let mut term_stream = self.automaton_stream(term_dict)?;
while term_stream.advance() {
let term_info = term_stream.value();
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
for &doc in docs {
doc_bitset.insert(doc);
}
block_segment_postings.advance();
}
let mut scorers = vec![];
while let Some((_term, term_info, state)) = term_stream.next() {
let score = automaton_score(self.automaton.as_ref(), state);
let segment_postings =
inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
let scorer = ConstScorer::new(segment_postings, boost * score);
scorers.push(scorer);
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
let const_scorer = ConstScorer::new(doc_bitset, boost);
Ok(Box::new(const_scorer))
let scorer = BufferedUnionScorer::build(scorers, SumCombiner::default);
Ok(Box::new(scorer))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) == doc {
Ok(Explanation::new("AutomatonScorer", 1.0))
Ok(Explanation::new("AutomatonScorer", scorer.score()))
} else {
Err(TantivyError::InvalidArgument(
"Document does not exist".to_string(),
@@ -122,6 +117,25 @@ where
}
}
fn automaton_score<A>(automaton: &A, state: A::State) -> f32
where
A: Automaton + Send + Sync + 'static,
A::State: Clone,
{
if TypeId::of::<DfaWrapper>() == automaton.type_id() && TypeId::of::<u32>() == state.type_id() {
let dfa = automaton as *const A as *const DfaWrapper;
let dfa = unsafe { &*dfa };
let id = &state as *const A::State as *const u32;
let id = unsafe { *id };
let dist = dfa.0.distance(id).to_u8() as f32;
1.0 / (1.0 + dist)
} else {
1.0
}
}
#[cfg(test)]
mod tests {
use tantivy_fst::Automaton;

View File

@@ -299,7 +299,7 @@ mod test {
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1.0, score);
assert_nearly_equals!(0.5, score);
}
// fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')

View File

@@ -4,6 +4,7 @@ mod term_weight;
pub use self::term_query::TermQuery;
pub use self::term_scorer::TermScorer;
#[cfg(test)]
mod tests {