Files
tantivy/src/query/automaton_weight.rs
2026-01-06 14:37:10 +01:00

221 lines
7.1 KiB
Rust

use std::io;
use std::sync::Arc;
use common::BitSet;
use tantivy_fst::Automaton;
use super::phrase_prefix_query::prefix_end;
use crate::index::SegmentReader;
use crate::postings::TermInfo;
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::{DocId, Score, TantivyError};
/// A weight struct for Fuzzy Term and Regex Queries
///
/// Generic over the automaton type `A`, which is shared through an `Arc`.
pub struct AutomatonWeight<A> {
// Field whose inverted index (and term dictionary) is searched.
field: Field,
// Automaton that is run against the term dictionary of `field`.
automaton: Arc<A>,
// For JSON fields, the term dictionary includes terms from all paths.
// We apply additional filtering based on the given JSON path, when searching within the term
// dictionary. This prevents terms from unrelated paths from matching the search criteria.
// `None` means that no such path restriction is applied.
json_path_bytes: Option<Box<[u8]>>,
}
impl<A> AutomatonWeight<A>
where
    A: Automaton + Send + Sync + 'static,
    A::State: Clone,
{
    /// Creates a new `AutomatonWeight` searching `field` with `automaton`.
    ///
    /// No JSON path restriction is applied to the term dictionary search.
    pub fn new<IntoArcA: Into<Arc<A>>>(field: Field, automaton: IntoArcA) -> AutomatonWeight<A> {
        let automaton: Arc<A> = automaton.into();
        Self {
            field,
            automaton,
            json_path_bytes: None,
        }
    }

    /// Creates a new `AutomatonWeight` restricted to a JSON path.
    ///
    /// `json_path_bytes` is copied into the weight and later used to bound the range of the
    /// term dictionary that the automaton is run against.
    pub fn new_for_json_path<IntoArcA: Into<Arc<A>>>(
        field: Field,
        automaton: IntoArcA,
        json_path_bytes: &[u8],
    ) -> AutomatonWeight<A> {
        let owned_path: Box<[u8]> = Box::from(json_path_bytes);
        Self {
            field,
            automaton: automaton.into(),
            json_path_bytes: Some(owned_path),
        }
    }

    /// Builds the stream of terms of `term_dict` that are accepted by the automaton.
    ///
    /// When a JSON path is set, the stream is lower-bounded (inclusive) by the path bytes and,
    /// if `prefix_end` yields an end key for them, upper-bounded (exclusive) by that key.
    fn automaton_stream<'a>(
        &'a self,
        term_dict: &'a TermDictionary,
    ) -> io::Result<TermStreamer<'a, &'a A>> {
        let unbounded = term_dict.search(&*self.automaton);
        let builder = match self.json_path_bytes.as_deref() {
            None => unbounded,
            Some(path) => {
                let lower_bounded = unbounded.ge(path);
                match prefix_end(path) {
                    Some(end) => lower_bounded.lt(&end),
                    None => lower_bounded,
                }
            }
        };
        builder.into_stream()
    }

    /// Returns the term infos that match the automaton
    ///
    /// The term infos are collected in the order in which the term stream yields them.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while opening the inverted index of the field or while
    /// building the term stream.
    pub fn get_match_term_infos(&self, reader: &SegmentReader) -> crate::Result<Vec<TermInfo>> {
        let inverted_index = reader.inverted_index(self.field)?;
        let mut stream = self.automaton_stream(inverted_index.terms())?;
        // Pull from the streamer until `advance` reports that it is exhausted.
        let matching: Vec<TermInfo> =
            std::iter::from_fn(|| stream.advance().then(|| stream.value().clone())).collect();
        Ok(matching)
    }
}
impl<A> Weight for AutomatonWeight<A>
where
    A: Automaton + Send + Sync + 'static,
    A::State: Clone,
{
    /// Builds a constant-score scorer over the documents of all matching terms.
    ///
    /// For every term accepted by the automaton, the doc ids exposed by its block postings are
    /// inserted into a bitset sized by `reader.max_doc()`. The bitset is then wrapped into a
    /// `ConstScorer` created with `boost`. `seek_doc` is forwarded as-is to the postings reader.
    ///
    /// # Errors
    ///
    /// Propagates errors from opening the inverted index, building the term stream, or reading
    /// the postings of a term.
    fn scorer(
        &self,
        reader: &SegmentReader,
        boost: Score,
        seek_doc: DocId,
    ) -> crate::Result<Box<dyn Scorer>> {
        let mut matching_docs = BitSet::with_max_value(reader.max_doc());
        let inverted_index = reader.inverted_index(self.field)?;
        let mut matching_terms = self.automaton_stream(inverted_index.terms())?;
        while matching_terms.advance() {
            let (mut postings, _) = inverted_index.read_block_postings_from_terminfo_with_seek(
                matching_terms.value(),
                IndexRecordOption::Basic,
                seek_doc,
            )?;
            // An empty block of docs marks the end of this term's postings.
            while !postings.docs().is_empty() {
                for &doc in postings.docs() {
                    matching_docs.insert(doc);
                }
                postings.advance();
            }
        }
        Ok(Box::new(ConstScorer::new(
            BitSetDocSet::from(matching_docs),
            boost,
        )))
    }

    /// Explains the score of `doc`.
    ///
    /// # Errors
    ///
    /// Returns `TantivyError::InvalidArgument` when seeking the scorer to `doc` does not land
    /// on `doc`, and propagates any error from building the scorer.
    fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
        let mut scorer = self.scorer(reader, 1.0, 0)?;
        if scorer.seek(doc) != doc {
            return Err(TantivyError::InvalidArgument(
                "Document does not exist".to_string(),
            ));
        }
        Ok(Explanation::new("AutomatonScorer", 1.0))
    }
}
#[cfg(test)]
mod tests {
    use tantivy_fst::Automaton;

    use super::AutomatonWeight;
    use crate::docset::TERMINATED;
    use crate::query::Weight;
    use crate::schema::{Schema, STRING};
    use crate::{Index, IndexWriter};

    /// Builds an in-RAM index with a single `title` field holding, in this order,
    /// the documents "abc", "bcd" and "abcd".
    fn create_index() -> crate::Result<Index> {
        let mut schema_builder = Schema::builder();
        let title = schema_builder.add_text_field("title", STRING);
        let index = Index::create_in_ram(schema_builder.build());
        let mut writer: IndexWriter = index.writer_for_tests()?;
        for value in ["abc", "bcd", "abcd"] {
            writer.add_document(doc!(title => value))?;
        }
        writer.commit()?;
        Ok(index)
    }

    /// States of the `PrefixedByA` test automaton.
    #[derive(Clone, Copy)]
    enum State {
        Start,
        NotMatching,
        AfterA,
    }

    /// Automaton accepting every byte sequence whose first byte is `b'a'`.
    struct PrefixedByA;

    impl Automaton for PrefixedByA {
        type State = State;

        fn start(&self) -> Self::State {
            State::Start
        }

        fn is_match(&self, state: &Self::State) -> bool {
            match *state {
                State::AfterA => true,
                State::Start | State::NotMatching => false,
            }
        }

        fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
            // Only the first byte decides; afterwards the state is sticky.
            match (*state, byte) {
                (State::Start, b'a') | (State::AfterA, _) => State::AfterA,
                (State::Start, _) | (State::NotMatching, _) => State::NotMatching,
            }
        }
    }

    #[test]
    fn test_automaton_weight() -> crate::Result<()> {
        let index = create_index()?;
        let title_field = index.schema().get_field("title").unwrap();
        let weight = AutomatonWeight::new(title_field, PrefixedByA);
        let searcher = index.reader()?.searcher();
        let segment_reader = searcher.segment_reader(0u32);
        let mut scorer = weight.scorer(segment_reader, 1.0, 0)?;
        // Docs 0 ("abc") and 2 ("abcd") start with 'a'; doc 1 ("bcd") does not.
        assert_eq!(scorer.doc(), 0u32);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), 2u32);
        assert_eq!(scorer.doc(), 2u32);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), TERMINATED);
        Ok(())
    }

    #[test]
    fn test_automaton_weight_boost() -> crate::Result<()> {
        let index = create_index()?;
        let title_field = index.schema().get_field("title").unwrap();
        let weight = AutomatonWeight::new(title_field, PrefixedByA);
        let searcher = index.reader()?.searcher();
        let segment_reader = searcher.segment_reader(0u32);
        let mut scorer = weight.scorer(segment_reader, 1.32, 0)?;
        // The boost passed to `scorer` is the score reported for matching docs.
        assert_eq!(scorer.doc(), 0u32);
        assert_eq!(scorer.score(), 1.32);
        Ok(())
    }
}