mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-05-28 06:00:40 +00:00
Use Levenshtein distance to score documents in fuzzy term queries
This commit is contained in:
@@ -1,15 +1,18 @@
|
||||
use std::any::{Any, TypeId};
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::BitSet;
|
||||
use tantivy_fst::Automaton;
|
||||
|
||||
use super::phrase_prefix_query::prefix_end;
|
||||
use super::BufferedUnionScorer;
|
||||
use crate::index::SegmentReader;
|
||||
use crate::postings::TermInfo;
|
||||
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
|
||||
use crate::query::fuzzy_query::DfaWrapper;
|
||||
use crate::query::score_combiner::SumCombiner;
|
||||
use crate::query::{ConstScorer, Explanation, Scorer, Weight};
|
||||
use crate::schema::{Field, IndexRecordOption};
|
||||
use crate::termdict::{TermDictionary, TermStreamer};
|
||||
use crate::termdict::{TermDictionary, TermWithStateStreamer};
|
||||
use crate::{DocId, Score, TantivyError};
|
||||
|
||||
/// A weight struct for Fuzzy Term and Regex Queries
|
||||
@@ -52,9 +55,9 @@ where
|
||||
fn automaton_stream<'a>(
|
||||
&'a self,
|
||||
term_dict: &'a TermDictionary,
|
||||
) -> io::Result<TermStreamer<'a, &'a A>> {
|
||||
) -> io::Result<TermWithStateStreamer<'a, &'a A>> {
|
||||
let automaton: &A = &self.automaton;
|
||||
let mut term_stream_builder = term_dict.search(automaton);
|
||||
let mut term_stream_builder = term_dict.search_with_state(automaton);
|
||||
|
||||
if let Some(json_path_bytes) = &self.json_path_bytes {
|
||||
term_stream_builder = term_stream_builder.ge(json_path_bytes);
|
||||
@@ -85,35 +88,27 @@ where
|
||||
A::State: Clone,
|
||||
{
|
||||
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
||||
let max_doc = reader.max_doc();
|
||||
let mut doc_bitset = BitSet::with_max_value(max_doc);
|
||||
let inverted_index = reader.inverted_index(self.field)?;
|
||||
let term_dict = inverted_index.terms();
|
||||
let mut term_stream = self.automaton_stream(term_dict)?;
|
||||
while term_stream.advance() {
|
||||
let term_info = term_stream.value();
|
||||
let mut block_segment_postings = inverted_index
|
||||
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
|
||||
loop {
|
||||
let docs = block_segment_postings.docs();
|
||||
if docs.is_empty() {
|
||||
break;
|
||||
}
|
||||
for &doc in docs {
|
||||
doc_bitset.insert(doc);
|
||||
}
|
||||
block_segment_postings.advance();
|
||||
}
|
||||
|
||||
let mut scorers = vec![];
|
||||
while let Some((_term, term_info, state)) = term_stream.next() {
|
||||
let score = automaton_score(self.automaton.as_ref(), state);
|
||||
let segment_postings =
|
||||
inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
|
||||
let scorer = ConstScorer::new(segment_postings, boost * score);
|
||||
scorers.push(scorer);
|
||||
}
|
||||
let doc_bitset = BitSetDocSet::from(doc_bitset);
|
||||
let const_scorer = ConstScorer::new(doc_bitset, boost);
|
||||
Ok(Box::new(const_scorer))
|
||||
|
||||
let scorer = BufferedUnionScorer::build(scorers, SumCombiner::default);
|
||||
Ok(Box::new(scorer))
|
||||
}
|
||||
|
||||
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
|
||||
let mut scorer = self.scorer(reader, 1.0)?;
|
||||
if scorer.seek(doc) == doc {
|
||||
Ok(Explanation::new("AutomatonScorer", 1.0))
|
||||
Ok(Explanation::new("AutomatonScorer", scorer.score()))
|
||||
} else {
|
||||
Err(TantivyError::InvalidArgument(
|
||||
"Document does not exist".to_string(),
|
||||
@@ -122,6 +117,25 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn automaton_score<A>(automaton: &A, state: A::State) -> f32
|
||||
where
|
||||
A: Automaton + Send + Sync + 'static,
|
||||
A::State: Clone,
|
||||
{
|
||||
if TypeId::of::<DfaWrapper>() == automaton.type_id() && TypeId::of::<u32>() == state.type_id() {
|
||||
let dfa = automaton as *const A as *const DfaWrapper;
|
||||
let dfa = unsafe { &*dfa };
|
||||
|
||||
let id = &state as *const A::State as *const u32;
|
||||
let id = unsafe { *id };
|
||||
|
||||
let dist = dfa.0.distance(id).to_u8() as f32;
|
||||
1.0 / (1.0 + dist)
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tantivy_fst::Automaton;
|
||||
|
||||
@@ -299,7 +299,7 @@ mod test {
|
||||
searcher.search(&fuzzy_query, &TopDocs::with_limit(2).order_by_score())?;
|
||||
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
|
||||
let (score, _) = top_docs[0];
|
||||
assert_nearly_equals!(1.0, score);
|
||||
assert_nearly_equals!(0.5, score);
|
||||
}
|
||||
|
||||
// fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')
|
||||
|
||||
@@ -4,6 +4,7 @@ mod term_weight;
|
||||
|
||||
pub use self::term_query::TermQuery;
|
||||
pub use self::term_scorer::TermScorer;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user