// Mirror of https://github.com/quickwit-oss/tantivy.git
// Synced 2026-01-06 01:02:55 +00:00 (172 lines, 5.3 KiB, Rust)
use std::io;
|
|
use std::sync::Arc;
|
|
|
|
use common::BitSet;
|
|
use tantivy_fst::Automaton;
|
|
|
|
use crate::core::SegmentReader;
|
|
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
|
|
use crate::schema::{Field, IndexRecordOption};
|
|
use crate::termdict::{TermDictionary, TermStreamer};
|
|
use crate::{DocId, Score, TantivyError};
|
|
|
|
/// A weight struct for Fuzzy Term and Regex Queries
|
|
pub struct AutomatonWeight<A> {
|
|
field: Field,
|
|
automaton: Arc<A>,
|
|
}
|
|
|
|
impl<A> AutomatonWeight<A>
|
|
where
|
|
A: Automaton + Send + Sync + 'static,
|
|
A::State: Clone,
|
|
{
|
|
/// Create a new AutomationWeight
|
|
pub fn new<IntoArcA: Into<Arc<A>>>(field: Field, automaton: IntoArcA) -> AutomatonWeight<A> {
|
|
AutomatonWeight {
|
|
field,
|
|
automaton: automaton.into(),
|
|
}
|
|
}
|
|
|
|
fn automaton_stream<'a>(
|
|
&'a self,
|
|
term_dict: &'a TermDictionary,
|
|
) -> io::Result<TermStreamer<'a, &'a A>> {
|
|
let automaton: &A = &*self.automaton;
|
|
let term_stream_builder = term_dict.search(automaton);
|
|
term_stream_builder.into_stream()
|
|
}
|
|
}
|
|
|
|
impl<A> Weight for AutomatonWeight<A>
|
|
where
|
|
A: Automaton + Send + Sync + 'static,
|
|
A::State: Clone,
|
|
{
|
|
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
|
let max_doc = reader.max_doc();
|
|
let mut doc_bitset = BitSet::with_max_value(max_doc);
|
|
let inverted_index = reader.inverted_index(self.field)?;
|
|
let term_dict = inverted_index.terms();
|
|
let mut term_stream = self.automaton_stream(term_dict)?;
|
|
while term_stream.advance() {
|
|
let term_info = term_stream.value();
|
|
let mut block_segment_postings = inverted_index
|
|
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
|
|
loop {
|
|
let docs = block_segment_postings.docs();
|
|
if docs.is_empty() {
|
|
break;
|
|
}
|
|
for &doc in docs {
|
|
doc_bitset.insert(doc);
|
|
}
|
|
block_segment_postings.advance();
|
|
}
|
|
}
|
|
let doc_bitset = BitSetDocSet::from(doc_bitset);
|
|
let const_scorer = ConstScorer::new(doc_bitset, boost);
|
|
Ok(Box::new(const_scorer))
|
|
}
|
|
|
|
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
|
|
let mut scorer = self.scorer(reader, 1.0)?;
|
|
if scorer.seek(doc) == doc {
|
|
Ok(Explanation::new("AutomatonScorer", 1.0))
|
|
} else {
|
|
Err(TantivyError::InvalidArgument(
|
|
"Document does not exist".to_string(),
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use tantivy_fst::Automaton;

    use super::AutomatonWeight;
    use crate::docset::TERMINATED;
    use crate::query::Weight;
    use crate::schema::{Schema, STRING};
    use crate::Index;

    /// Builds an in-RAM index with a single `title` field and three documents, added in
    /// this order: "abc", "bcd", "abcd".
    fn create_index() -> crate::Result<Index> {
        let mut schema_builder = Schema::builder();
        let title = schema_builder.add_text_field("title", STRING);
        let index = Index::create_in_ram(schema_builder.build());
        let mut writer = index.writer_for_tests()?;
        // Insertion order matters: the tests below assert on the resulting doc ids.
        for text in ["abc", "bcd", "abcd"] {
            writer.add_document(doc!(title=>text))?;
        }
        writer.commit()?;
        Ok(index)
    }

    /// States of the [`PrefixedByA`] automaton.
    #[derive(Clone, Copy)]
    enum State {
        /// No byte has been consumed yet.
        Start,
        /// The first byte consumed was not `a`.
        NotMatching,
        /// The first byte consumed was `a`.
        AfterA,
    }

    /// Automaton matching every input whose first byte is `a`.
    struct PrefixedByA;

    impl Automaton for PrefixedByA {
        type State = State;

        fn start(&self) -> Self::State {
            State::Start
        }

        fn is_match(&self, state: &Self::State) -> bool {
            match state {
                State::AfterA => true,
                State::Start | State::NotMatching => false,
            }
        }

        fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
            // Only the very first byte decides; both outcomes are absorbing states.
            match (*state, byte) {
                (State::Start, b'a') | (State::AfterA, _) => State::AfterA,
                (State::Start, _) | (State::NotMatching, _) => State::NotMatching,
            }
        }
    }

    /// Builds the weight under test: `PrefixedByA` applied to the `title` field.
    fn prefixed_by_a_weight(index: &Index) -> AutomatonWeight<PrefixedByA> {
        let field = index.schema().get_field("title").unwrap();
        AutomatonWeight::new(field, PrefixedByA)
    }

    #[test]
    fn test_automaton_weight() -> crate::Result<()> {
        let index = create_index()?;
        let automaton_weight = prefixed_by_a_weight(&index);
        let reader = index.reader()?;
        let searcher = reader.searcher();
        let mut scorer = automaton_weight.scorer(searcher.segment_reader(0u32), 1.0)?;
        // "abc" is doc 0 and "abcd" is doc 2; "bcd" (doc 1) must be skipped.
        assert_eq!(scorer.doc(), 0u32);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), 2u32);
        assert_eq!(scorer.doc(), 2u32);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), TERMINATED);
        Ok(())
    }

    #[test]
    fn test_automaton_weight_boost() -> crate::Result<()> {
        let index = create_index()?;
        let automaton_weight = prefixed_by_a_weight(&index);
        let reader = index.reader()?;
        let searcher = reader.searcher();
        let mut scorer = automaton_weight.scorer(searcher.segment_reader(0u32), 1.32)?;
        // The boost is reported verbatim as the score of a matching document.
        assert_eq!(scorer.doc(), 0u32);
        assert_eq!(scorer.score(), 1.32);
        Ok(())
    }
}
|