use std::io; use std::ops::Bound; use tantivy_fst::automaton::AlwaysMatch; use tantivy_fst::Automaton; use crate::dictionary::Dictionary; use crate::{DeltaReader, SSTable, TermOrdinal}; /// `StreamerBuilder` is a helper object used to define /// a range of terms that should be streamed. pub struct StreamerBuilder<'a, TSSTable, A = AlwaysMatch> where A: Automaton, A::State: Clone, TSSTable: SSTable, { term_dict: &'a Dictionary, automaton: A, lower: Bound>, upper: Bound>, limit: Option, } fn bound_as_byte_slice(bound: &Bound>) -> Bound<&[u8]> { match bound.as_ref() { Bound::Included(key) => Bound::Included(key.as_slice()), Bound::Excluded(key) => Bound::Excluded(key.as_slice()), Bound::Unbounded => Bound::Unbounded, } } impl<'a, TSSTable, A> StreamerBuilder<'a, TSSTable, A> where A: Automaton, A::State: Clone, TSSTable: SSTable, { pub(crate) fn new(term_dict: &'a Dictionary, automaton: A) -> Self { StreamerBuilder { term_dict, automaton, lower: Bound::Unbounded, upper: Bound::Unbounded, limit: None, } } /// Limit the range to terms greater or equal to the bound pub fn ge>(mut self, bound: T) -> Self { self.lower = Bound::Included(bound.as_ref().to_owned()); self } /// Limit the range to terms strictly greater than the bound pub fn gt>(mut self, bound: T) -> Self { self.lower = Bound::Excluded(bound.as_ref().to_owned()); self } /// Limit the range to terms lesser or equal to the bound pub fn le>(mut self, bound: T) -> Self { self.upper = Bound::Included(bound.as_ref().to_owned()); self } /// Limit the range to terms lesser or equal to the bound pub fn lt>(mut self, bound: T) -> Self { self.upper = Bound::Excluded(bound.as_ref().to_owned()); self } /// Load no more data than what's required to to get `limit` /// matching entries. /// /// The resulting [`Streamer`] can still return marginaly /// more than `limit` elements. pub fn limit(mut self, limit: u64) -> Self { self.limit = Some(limit); self } fn delta_reader(&self) -> io::Result> { let key_range = ( bound_as_byte_slice(&self.lower), bound_as_byte_slice(&self.upper), ); self.term_dict .sstable_delta_reader_for_key_range(key_range, self.limit) } async fn delta_reader_async(&self) -> io::Result> { let key_range = ( bound_as_byte_slice(&self.lower), bound_as_byte_slice(&self.upper), ); self.term_dict .sstable_delta_reader_for_key_range_async(key_range, self.limit) .await } fn into_stream_given_delta_reader( self, delta_reader: DeltaReader<::ValueReader>, ) -> io::Result> { let start_state = self.automaton.start(); let start_key = bound_as_byte_slice(&self.lower); let first_term = match start_key { Bound::Included(key) | Bound::Excluded(key) => self .term_dict .sstable_index .get_block_with_key(key) .map(|block| block.first_ordinal) .unwrap_or(0), Bound::Unbounded => 0, }; Ok(Streamer { automaton: self.automaton, states: vec![start_state], delta_reader, key: Vec::new(), term_ord: first_term.checked_sub(1), lower_bound: self.lower, upper_bound: self.upper, _lifetime: std::marker::PhantomData, }) } /// See `into_stream(..)` pub async fn into_stream_async(self) -> io::Result> { let delta_reader = self.delta_reader_async().await?; self.into_stream_given_delta_reader(delta_reader) } /// Creates the stream corresponding to the range /// of terms defined using the `StreamerBuilder`. pub fn into_stream(self) -> io::Result> { let delta_reader = self.delta_reader()?; self.into_stream_given_delta_reader(delta_reader) } } /// `Streamer` acts as a cursor over a range of terms of a segment. /// Terms are guaranteed to be sorted. pub struct Streamer<'a, TSSTable, A = AlwaysMatch> where A: Automaton, A::State: Clone, TSSTable: SSTable, { automaton: A, states: Vec, delta_reader: crate::DeltaReader, key: Vec, term_ord: Option, lower_bound: Bound>, upper_bound: Bound>, // this field is used to please the type-interface of a dictionary in tantivy _lifetime: std::marker::PhantomData<&'a ()>, } impl<'a, TSSTable> Streamer<'a, TSSTable, AlwaysMatch> where TSSTable: SSTable { pub fn empty() -> Self { Streamer { automaton: AlwaysMatch, states: Vec::new(), delta_reader: DeltaReader::empty(), key: Vec::new(), term_ord: None, lower_bound: Bound::Unbounded, upper_bound: Bound::Unbounded, _lifetime: std::marker::PhantomData, } } } impl<'a, TSSTable, A> Streamer<'a, TSSTable, A> where A: Automaton, A::State: Clone, TSSTable: SSTable, { /// Advance position the stream on the next item. /// Before the first call to `.advance()`, the stream /// is an uninitialized state. pub fn advance(&mut self) -> bool { while self.delta_reader.advance().unwrap() { self.term_ord = Some( self.term_ord .map(|term_ord| term_ord + 1u64) .unwrap_or(0u64), ); let common_prefix_len = self.delta_reader.common_prefix_len(); self.states.truncate(common_prefix_len + 1); self.key.truncate(common_prefix_len); let mut state: A::State = self.states.last().unwrap().clone(); for &b in self.delta_reader.suffix() { state = self.automaton.accept(&state, b); self.states.push(state.clone()); } self.key.extend_from_slice(self.delta_reader.suffix()); let match_lower_bound = match &self.lower_bound { Bound::Unbounded => true, Bound::Included(lower_bound_key) => lower_bound_key[..] <= self.key[..], Bound::Excluded(lower_bound_key) => lower_bound_key[..] < self.key[..], }; if !match_lower_bound { continue; } // We match the lower key once. All subsequent keys will pass that bar. self.lower_bound = Bound::Unbounded; let match_upper_bound = match &self.upper_bound { Bound::Unbounded => true, Bound::Included(upper_bound_key) => upper_bound_key[..] >= self.key[..], Bound::Excluded(upper_bound_key) => upper_bound_key[..] > self.key[..], }; if !match_upper_bound { return false; } if self.automaton.is_match(&state) { return true; } } false } /// Returns the `TermOrdinal` of the given term. /// /// May panic if the called as `.advance()` as never /// been called before. pub fn term_ord(&self) -> TermOrdinal { self.term_ord.unwrap_or(0u64) } /// Accesses the current key. /// /// `.key()` should return the key that was returned /// by the `.next()` method. /// /// If the end of the stream as been reached, and `.next()` /// has been called and returned `None`, `.key()` remains /// the value of the last key encountered. /// /// Before any call to `.next()`, `.key()` returns an empty array. pub fn key(&self) -> &[u8] { &self.key } /// Accesses the current value. /// /// Calling `.value()` after the end of the stream will return the /// last `.value()` encountered. /// /// # Panics /// /// Calling `.value()` before the first call to `.advance()` returns /// `V::default()`. pub fn value(&self) -> &TSSTable::Value { self.delta_reader.value() } /// Return the next `(key, value)` pair. #[allow(clippy::should_implement_trait)] pub fn next(&mut self) -> Option<(&[u8], &TSSTable::Value)> { if self.advance() { Some((self.key(), self.value())) } else { None } } } #[cfg(test)] mod tests { use std::io; use common::OwnedBytes; use crate::{Dictionary, MonotonicU64SSTable}; fn create_test_dictionary() -> io::Result> { let mut dict_builder = Dictionary::::builder(Vec::new())?; dict_builder.insert(b"abaisance", &0)?; dict_builder.insert(b"abalation", &1)?; dict_builder.insert(b"abalienate", &2)?; dict_builder.insert(b"abandon", &3)?; let buffer = dict_builder.finish()?; let owned_bytes = OwnedBytes::new(buffer); Dictionary::from_bytes(owned_bytes) } #[test] fn test_sstable_stream() -> io::Result<()> { let dict = create_test_dictionary()?; let mut streamer = dict.stream()?; assert!(streamer.advance()); assert_eq!(streamer.key(), b"abaisance"); assert_eq!(streamer.value(), &0); assert!(streamer.advance()); assert_eq!(streamer.key(), b"abalation"); assert_eq!(streamer.value(), &1); assert!(streamer.advance()); assert_eq!(streamer.key(), b"abalienate"); assert_eq!(streamer.value(), &2); assert!(streamer.advance()); assert_eq!(streamer.key(), b"abandon"); assert_eq!(streamer.value(), &3); assert!(!streamer.advance()); Ok(()) } #[test] fn test_sstable_search() -> io::Result<()> { let term_dict = create_test_dictionary()?; let ptn = tantivy_fst::Regex::new("ab.*t.*").unwrap(); let mut term_streamer = term_dict.search(ptn).into_stream()?; assert!(term_streamer.advance()); assert_eq!(term_streamer.key(), b"abalation"); assert_eq!(term_streamer.value(), &1u64); assert!(term_streamer.advance()); assert_eq!(term_streamer.key(), b"abalienate"); assert_eq!(term_streamer.value(), &2u64); assert!(!term_streamer.advance()); Ok(()) } }