mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-01-08 01:52:54 +00:00
221 lines
7.1 KiB
Rust
221 lines
7.1 KiB
Rust
use std::io;
|
|
use std::sync::Arc;
|
|
|
|
use common::BitSet;
|
|
use tantivy_fst::Automaton;
|
|
|
|
use super::phrase_prefix_query::prefix_end;
|
|
use crate::index::SegmentReader;
|
|
use crate::postings::TermInfo;
|
|
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
|
|
use crate::schema::{Field, IndexRecordOption};
|
|
use crate::termdict::{TermDictionary, TermStreamer};
|
|
use crate::{DocId, Score, TantivyError};
|
|
|
|
/// A weight struct for Fuzzy Term and Regex Queries
|
|
pub struct AutomatonWeight<A> {
|
|
field: Field,
|
|
automaton: Arc<A>,
|
|
// For JSON fields, the term dictionary include terms from all paths.
|
|
// We apply additional filtering based on the given JSON path, when searching within the term
|
|
// dictionary. This prevents terms from unrelated paths from matching the search criteria.
|
|
json_path_bytes: Option<Box<[u8]>>,
|
|
}
|
|
|
|
impl<A> AutomatonWeight<A>
|
|
where
|
|
A: Automaton + Send + Sync + 'static,
|
|
A::State: Clone,
|
|
{
|
|
/// Create a new AutomationWeight
|
|
pub fn new<IntoArcA: Into<Arc<A>>>(field: Field, automaton: IntoArcA) -> AutomatonWeight<A> {
|
|
AutomatonWeight {
|
|
field,
|
|
automaton: automaton.into(),
|
|
json_path_bytes: None,
|
|
}
|
|
}
|
|
|
|
/// Create a new AutomationWeight for a json path
|
|
pub fn new_for_json_path<IntoArcA: Into<Arc<A>>>(
|
|
field: Field,
|
|
automaton: IntoArcA,
|
|
json_path_bytes: &[u8],
|
|
) -> AutomatonWeight<A> {
|
|
AutomatonWeight {
|
|
field,
|
|
automaton: automaton.into(),
|
|
json_path_bytes: Some(json_path_bytes.to_vec().into_boxed_slice()),
|
|
}
|
|
}
|
|
|
|
fn automaton_stream<'a>(
|
|
&'a self,
|
|
term_dict: &'a TermDictionary,
|
|
) -> io::Result<TermStreamer<'a, &'a A>> {
|
|
let automaton: &A = &self.automaton;
|
|
let mut term_stream_builder = term_dict.search(automaton);
|
|
|
|
if let Some(json_path_bytes) = &self.json_path_bytes {
|
|
term_stream_builder = term_stream_builder.ge(json_path_bytes);
|
|
if let Some(end) = prefix_end(json_path_bytes) {
|
|
term_stream_builder = term_stream_builder.lt(&end);
|
|
}
|
|
}
|
|
|
|
term_stream_builder.into_stream()
|
|
}
|
|
|
|
/// Returns the term infos that match the automaton
|
|
pub fn get_match_term_infos(&self, reader: &SegmentReader) -> crate::Result<Vec<TermInfo>> {
|
|
let inverted_index = reader.inverted_index(self.field)?;
|
|
let term_dict = inverted_index.terms();
|
|
let mut term_stream = self.automaton_stream(term_dict)?;
|
|
let mut term_infos = Vec::new();
|
|
while term_stream.advance() {
|
|
term_infos.push(term_stream.value().clone());
|
|
}
|
|
Ok(term_infos)
|
|
}
|
|
}
|
|
|
|
impl<A> Weight for AutomatonWeight<A>
|
|
where
|
|
A: Automaton + Send + Sync + 'static,
|
|
A::State: Clone,
|
|
{
|
|
fn scorer(
|
|
&self,
|
|
reader: &SegmentReader,
|
|
boost: Score,
|
|
seek_doc: DocId,
|
|
) -> crate::Result<Box<dyn Scorer>> {
|
|
let max_doc = reader.max_doc();
|
|
let mut doc_bitset = BitSet::with_max_value(max_doc);
|
|
let inverted_index = reader.inverted_index(self.field)?;
|
|
let term_dict = inverted_index.terms();
|
|
let mut term_stream = self.automaton_stream(term_dict)?;
|
|
while term_stream.advance() {
|
|
let term_info = term_stream.value();
|
|
let (mut block_segment_postings, _) = inverted_index
|
|
.read_block_postings_from_terminfo_with_seek(
|
|
term_info,
|
|
IndexRecordOption::Basic,
|
|
seek_doc,
|
|
)?;
|
|
loop {
|
|
let docs = block_segment_postings.docs();
|
|
if docs.is_empty() {
|
|
break;
|
|
}
|
|
for &doc in docs {
|
|
doc_bitset.insert(doc);
|
|
}
|
|
block_segment_postings.advance();
|
|
}
|
|
}
|
|
let doc_bitset = BitSetDocSet::from(doc_bitset);
|
|
let const_scorer = ConstScorer::new(doc_bitset, boost);
|
|
Ok(Box::new(const_scorer))
|
|
}
|
|
|
|
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
|
|
let mut scorer = self.scorer(reader, 1.0, 0)?;
|
|
if scorer.seek(doc) == doc {
|
|
Ok(Explanation::new("AutomatonScorer", 1.0))
|
|
} else {
|
|
Err(TantivyError::InvalidArgument(
|
|
"Document does not exist".to_string(),
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use tantivy_fst::Automaton;

    use super::AutomatonWeight;
    use crate::docset::TERMINATED;
    use crate::query::Weight;
    use crate::schema::{Schema, STRING};
    use crate::{Index, IndexWriter};

    /// Builds an in-RAM index with a single `title` field and three documents, added in
    /// the order "abc", "bcd", "abcd" and committed once.
    fn create_index() -> crate::Result<Index> {
        let mut schema_builder = Schema::builder();
        let title = schema_builder.add_text_field("title", STRING);
        let index = Index::create_in_ram(schema_builder.build());
        let mut writer: IndexWriter = index.writer_for_tests()?;
        for text in ["abc", "bcd", "abcd"] {
            writer.add_document(doc!(title => text))?;
        }
        writer.commit()?;
        Ok(index)
    }

    /// States of the `PrefixedByA` automaton.
    #[derive(Clone, Copy)]
    enum State {
        Start,
        NotMatching,
        AfterA,
    }

    /// Test automaton accepting any non-empty input whose first byte is `b'a'`.
    struct PrefixedByA;

    impl Automaton for PrefixedByA {
        type State = State;

        fn start(&self) -> Self::State {
            State::Start
        }

        fn is_match(&self, state: &Self::State) -> bool {
            matches!(*state, State::AfterA)
        }

        fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
            // Only the first byte decides; afterwards the state never changes.
            match (*state, byte) {
                (State::Start, b'a') | (State::AfterA, _) => State::AfterA,
                (State::Start, _) | (State::NotMatching, _) => State::NotMatching,
            }
        }
    }

    /// Docs 0 ("abc") and 2 ("abcd") match, each with the score given as boost.
    #[test]
    fn test_automaton_weight() -> crate::Result<()> {
        let index = create_index()?;
        let title_field = index.schema().get_field("title").unwrap();
        let weight = AutomatonWeight::new(title_field, PrefixedByA);
        let reader = index.reader()?;
        let searcher = reader.searcher();
        let segment_reader = searcher.segment_reader(0u32);
        let mut scorer = weight.scorer(segment_reader, 1.0, 0)?;
        assert_eq!(scorer.doc(), 0u32);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), 2u32);
        assert_eq!(scorer.doc(), 2u32);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), TERMINATED);
        Ok(())
    }

    /// The boost handed to `scorer` is reported as the score of a matching doc.
    #[test]
    fn test_automaton_weight_boost() -> crate::Result<()> {
        let index = create_index()?;
        let title_field = index.schema().get_field("title").unwrap();
        let weight = AutomatonWeight::new(title_field, PrefixedByA);
        let reader = index.reader()?;
        let searcher = reader.searcher();
        let segment_reader = searcher.segment_reader(0u32);
        let mut scorer = weight.scorer(segment_reader, 1.32, 0)?;
        assert_eq!(scorer.doc(), 0u32);
        assert_eq!(scorer.score(), 1.32);
        Ok(())
    }
}
|