diff --git a/Cargo.toml b/Cargo.toml index cfa706a4f..5195cec66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ snap = "1" tempfile = {version="3", optional=true} log = "0.4" serde = {version="1", features=["derive"]} +serde_cbor = "0.11" serde_json = "1" num_cpus = "1" fs2={version="0.4", optional=true} diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 8d45c1499..971565ead 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -1,3 +1,5 @@ +use rayon::iter::IntoParallelRefIterator; + use crate::core::SegmentReader; use crate::postings::FreqReadingOption; use crate::query::explanation::does_not_match; @@ -22,7 +24,7 @@ enum SpecializedScorer { fn scorer_union(scorers: Vec>) -> SpecializedScorer where - TScoreCombiner: ScoreCombiner, + TScoreCombiner: ScoreCombiner + Send, { assert!(!scorers.is_empty()); if scorers.len() == 1 { @@ -52,7 +54,7 @@ where SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers))) } -fn into_box_scorer(scorer: SpecializedScorer) -> Box { +fn into_box_scorer(scorer: SpecializedScorer) -> Box { match scorer { SpecializedScorer::TermUnion(term_scorers) => { let union_scorer = Union::::from(term_scorers); @@ -80,18 +82,32 @@ impl BooleanWeight { reader: &SegmentReader, boost: Score, ) -> crate::Result>>> { + use rayon::iter::ParallelIterator; + use rayon::iter::IndexedParallelIterator; let mut per_occur_scorers: HashMap>> = HashMap::new(); - for &(ref occur, ref subweight) in &self.weights { - let sub_scorer: Box = subweight.scorer(reader, boost)?; + let mut items_res: Vec)>> = Vec::new(); + let pool = rayon::ThreadPoolBuilder::new().num_threads(self.weights.len()).build().unwrap(); + pool.install(|| { + self.weights.iter() + .collect::>() + .par_iter() + .map(|(occur, subweight)| { + let sub_scorer: Box = subweight.scorer(reader, boost)?; + Ok((*occur, sub_scorer)) + }) + .collect_into_vec(&mut items_res); 
+ }); + for item_res in items_res { + let (occur, sub_scorer) = item_res?; per_occur_scorers - .entry(*occur) + .entry(occur) .or_insert_with(Vec::new) .push(sub_scorer); } Ok(per_occur_scorers) } - fn complex_scorer( + fn complex_scorer( &self, reader: &SegmentReader, boost: Score, diff --git a/src/termdict/mod.rs b/src/termdict/mod.rs index d8b8acb62..475fea7bd 100644 --- a/src/termdict/mod.rs +++ b/src/termdict/mod.rs @@ -22,8 +22,10 @@ A second datastructure makes it possible to access a [`TermInfo`](../postings/st use tantivy_fst::automaton::AlwaysMatch; -mod fst_termdict; -use fst_termdict as termdict; +// mod fst_termdict; +// use fst_termdict as termdict; +mod sstable_termdict; +use sstable_termdict as termdict; mod merger; diff --git a/src/termdict/sstable_termdict/mod.rs b/src/termdict/sstable_termdict/mod.rs new file mode 100644 index 000000000..dc5c868ce --- /dev/null +++ b/src/termdict/sstable_termdict/mod.rs @@ -0,0 +1,148 @@ +use std::io; + +mod sstable; +mod streamer; +mod termdict; + +use self::sstable::value::{ValueReader, ValueWriter}; +use self::sstable::{BlockReader, SSTable}; + +use crate::common::VInt; +use crate::postings::TermInfo; + +pub use self::streamer::{TermStreamer, TermStreamerBuilder}; +pub use self::termdict::{TermDictionary, TermDictionaryBuilder}; + +pub struct TermSSTable; + +impl SSTable for TermSSTable { + type Value = TermInfo; + type Reader = TermInfoReader; + type Writer = TermInfoWriter; +} + +#[derive(Default)] +pub struct TermInfoReader { + term_infos: Vec, +} + +impl ValueReader for TermInfoReader { + type Value = TermInfo; + + fn value(&self, idx: usize) -> &TermInfo { + &self.term_infos[idx] + } + + fn read(&mut self, reader: &mut BlockReader) -> io::Result<()> { + self.term_infos.clear(); + let num_els = VInt::deserialize_u64(reader)?; + let mut start_offset = VInt::deserialize_u64(reader)?; + let mut positions_idx = 0; + for _ in 0..num_els { + let doc_freq = VInt::deserialize_u64(reader)? 
as u32; + let posting_num_bytes = VInt::deserialize_u64(reader)?; + let stop_offset = start_offset + posting_num_bytes; + let delta_positions_idx = VInt::deserialize_u64(reader)?; + positions_idx += delta_positions_idx; + let term_info = TermInfo { + doc_freq, + postings_start_offset: start_offset, + postings_stop_offset: stop_offset, + positions_idx, + }; + self.term_infos.push(term_info); + start_offset = stop_offset; + } + Ok(()) + } +} + +#[derive(Default)] +pub struct TermInfoWriter { + term_infos: Vec, +} + +impl ValueWriter for TermInfoWriter { + type Value = TermInfo; + + fn write(&mut self, term_info: &TermInfo) { + self.term_infos.push(term_info.clone()); + } + + fn write_block(&mut self, buffer: &mut Vec) { + VInt(self.term_infos.len() as u64).serialize_into_vec(buffer); + if self.term_infos.is_empty() { + return; + } + let mut prev_position_idx = 0u64; + VInt(self.term_infos[0].postings_start_offset).serialize_into_vec(buffer); + for term_info in &self.term_infos { + VInt(term_info.doc_freq as u64).serialize_into_vec(buffer); + VInt(term_info.postings_stop_offset - term_info.postings_start_offset) + .serialize_into_vec(buffer); + VInt(term_info.positions_idx - prev_position_idx).serialize_into_vec(buffer); + prev_position_idx = term_info.positions_idx; + } + self.term_infos.clear(); + } +} + +#[cfg(test)] +mod tests { + use std::io; + + use super::BlockReader; + + use crate::directory::OwnedBytes; + use crate::postings::TermInfo; + use crate::termdict::sstable_termdict::sstable::value::{ValueReader, ValueWriter}; + use crate::termdict::sstable_termdict::TermInfoReader; + + #[test] + fn test_block_terminfos() -> io::Result<()> { + let mut term_info_writer = super::TermInfoWriter::default(); + term_info_writer.write(&TermInfo { + doc_freq: 120u32, + postings_start_offset: 17u64, + postings_stop_offset: 45u64, + positions_idx: 10u64, + }); + term_info_writer.write(&TermInfo { + doc_freq: 10u32, + postings_start_offset: 45u64, + postings_stop_offset: 
450u64, + positions_idx: 104u64, + }); + term_info_writer.write(&TermInfo { + doc_freq: 17u32, + postings_start_offset: 450u64, + postings_stop_offset: 462u64, + positions_idx: 210u64, + }); + let mut buffer = Vec::new(); + term_info_writer.write_block(&mut buffer); + let mut block_reader = make_block_reader(&buffer[..]); + let mut term_info_reader = TermInfoReader::default(); + term_info_reader.read(&mut block_reader)?; + assert_eq!( + term_info_reader.value(0), + &TermInfo { + doc_freq: 120u32, + postings_start_offset: 17u64, + postings_stop_offset: 45u64, + positions_idx: 10u64 + } + ); + assert!(block_reader.buffer().is_empty()); + Ok(()) + } + + fn make_block_reader(data: &[u8]) -> BlockReader { + let mut buffer = (data.len() as u32).to_le_bytes().to_vec(); + buffer.extend_from_slice(data); + let owned_bytes = OwnedBytes::new(buffer); + let mut block_reader = BlockReader::new(Box::new(owned_bytes)); + block_reader.read_block().unwrap(); + block_reader + } +} diff --git a/src/termdict/sstable_termdict/sstable/block_reader.rs b/src/termdict/sstable_termdict/sstable/block_reader.rs new file mode 100644 index 000000000..8c2634055 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/block_reader.rs @@ -0,0 +1,84 @@ +use byteorder::{LittleEndian, ReadBytesExt}; +use std::io::{self, Read}; + +pub struct BlockReader<'a> { + buffer: Vec, + reader: Box, + offset: usize, +} + +impl<'a> BlockReader<'a> { + pub fn new(reader: Box) -> BlockReader<'a> { + BlockReader { + buffer: Vec::new(), + reader, + offset: 0, + } + } + + pub fn deserialize_u64(&mut self) -> u64 { + let (num_bytes, val) = super::vint::deserialize_read(self.buffer()); + self.advance(num_bytes); + val + } + + #[inline(always)] + pub fn buffer_from_to(&self, start: usize, end: usize) -> &[u8] { + &self.buffer[start..end] + } + + pub fn buffer_from(&self, start: usize) -> &[u8] { + &self.buffer[start..] 
+    }
+
+    /// Loads the next block from the underlying reader into the internal buffer.
+    ///
+    /// Returns `Ok(true)` when a block was read. Returns `Ok(false)` when the
+    /// stream is exhausted: either the `0u32` length terminator was read, or the
+    /// underlying reader reported `UnexpectedEof` while the block length was
+    /// being read. In both cases the buffer is left empty, so `buffer()` never
+    /// exposes the bytes of a previous block after `false` has been returned.
+    ///
+    /// # Errors
+    /// Any other I/O error of the underlying reader is returned as is.
+    pub fn read_block(&mut self) -> io::Result<bool> {
+        self.offset = 0;
+        let block_len = match self.reader.read_u32::<LittleEndian>() {
+            Ok(0u32) => {
+                self.buffer.clear();
+                return Ok(false);
+            }
+            Ok(block_len) => block_len as usize,
+            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
+                // The offset was just reset to 0: without clearing, the previous
+                // block would become readable a second time.
+                self.buffer.clear();
+                return Ok(false);
+            }
+            Err(err) => return Err(err),
+        };
+        self.buffer.resize(block_len, 0u8);
+        self.reader.read_exact(&mut self.buffer[..])?;
+        Ok(true)
+    }
+
+    /// Position of the read cursor within the current block.
+    pub fn offset(&self) -> usize {
+        self.offset
+    }
+
+    /// Moves the read cursor forward by `num_bytes`.
+    ///
+    /// No bounds check happens here; `buffer()` panics if the cursor ends up
+    /// past the end of the block.
+    pub fn advance(&mut self, num_bytes: usize) {
+        self.offset += num_bytes;
+    }
+
+    /// Bytes of the current block that have not been consumed yet.
+    pub fn buffer(&self) -> &[u8] {
+        &self.buffer[self.offset..]
+    }
+}
+
+/// Reads from the unconsumed part of the current block only; it never pulls a
+/// new block from the underlying reader.
+impl<'a> io::Read for BlockReader<'a> {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        let len = self.buffer().read(buf)?;
+        self.advance(len);
+        Ok(len)
+    }
+
+    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
+        // Only the bytes after the cursor are copied, so only that many may be
+        // reported and skipped. Using the full block length here would push the
+        // cursor past the end of the block.
+        let remaining = self.buffer();
+        let len = remaining.len();
+        buf.extend_from_slice(remaining);
+        self.advance(len);
+        Ok(len)
+    }
+
+    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.buffer().read_exact(buf)?;
+        self.advance(buf.len());
+        Ok(())
+    }
+}
diff --git a/src/termdict/sstable_termdict/sstable/delta.rs b/src/termdict/sstable_termdict/sstable/delta.rs
new file mode 100644
index 000000000..68d5b42fe
--- /dev/null
+++ b/src/termdict/sstable_termdict/sstable/delta.rs
@@ -0,0 +1,203 @@
+use std::io::{self, BufWriter, Write};
+
+use crate::common::CountingWriter;
+
+use super::value::ValueWriter;
+use super::{value, vint, BlockReader};
+
+// Keep/add lengths strictly below this limit are packed into a single byte.
+const FOUR_BIT_LIMITS: usize = 1 << 4;
+// Marker byte announcing that keep/add lengths follow as two vints.
+const VINT_MODE: u8 = 1u8;
+// A block is flushed once its key data grows beyond this many bytes.
+const BLOCK_LEN: usize = 256_000;
+
+/// Writes blocks of delta-encoded keys, each block being preceded by the values
+/// serialized by `TValueWriter`.
+pub struct DeltaWriter<W, TValueWriter>
+where
+    W: io::Write,
+{
+    block: Vec<u8>,
+    write: CountingWriter<BufWriter<W>>,
+    value_writer: TValueWriter,
+}
+
+impl<W, TValueWriter> DeltaWriter<W, TValueWriter>
+where
+    W: io::Write,
+    TValueWriter: ValueWriter,
+{
+    pub fn new(wrt: W) -> Self {
+        DeltaWriter {
+            block: Vec::with_capacity(BLOCK_LEN * 2),
+            write: CountingWriter::wrap(BufWriter::new(wrt)),
+            value_writer:
TValueWriter::default(), + } + } +} + +impl DeltaWriter +where + W: io::Write, + TValueWriter: value::ValueWriter, +{ + pub fn flush_block(&mut self) -> io::Result> { + if self.block.is_empty() { + return Ok(None); + } + let start_offset = self.write.written_bytes(); + // TODO avoid buffer allocation + let mut buffer = Vec::new(); + self.value_writer.write_block(&mut buffer); + let block_len = buffer.len() + self.block.len(); + self.write.write_all(&(block_len as u32).to_le_bytes())?; + self.write.write_all(&buffer[..])?; + self.write.write_all(&mut self.block[..])?; + let end_offset = self.write.written_bytes(); + self.block.clear(); + Ok(Some((start_offset, end_offset))) + } + + fn encode_keep_add(&mut self, keep_len: usize, add_len: usize) { + if keep_len < FOUR_BIT_LIMITS && add_len < FOUR_BIT_LIMITS { + let b = (keep_len | add_len << 4) as u8; + self.block.extend_from_slice(&[b]) + } else { + let mut buf = [VINT_MODE; 20]; + let mut len = 1 + vint::serialize(keep_len as u64, &mut buf[1..]); + len += vint::serialize(add_len as u64, &mut buf[len..]); + self.block.extend_from_slice(&mut buf[..len]) + } + } + + pub(crate) fn write_suffix(&mut self, common_prefix_len: usize, suffix: &[u8]) { + let keep_len = common_prefix_len; + let add_len = suffix.len(); + self.encode_keep_add(keep_len, add_len); + self.block.extend_from_slice(suffix); + } + + pub(crate) fn write_value(&mut self, value: &TValueWriter::Value) { + self.value_writer.write(value); + } + + pub fn write_delta( + &mut self, + common_prefix_len: usize, + suffix: &[u8], + value: &TValueWriter::Value, + ) { + self.write_suffix(common_prefix_len, suffix); + self.write_value(value); + } + + pub fn flush_block_if_required(&mut self) -> io::Result> { + if self.block.len() > BLOCK_LEN { + return self.flush_block(); + } + Ok(None) + } + + pub fn finalize(mut self) -> CountingWriter> { + self.write + } +} + +pub struct DeltaReader<'a, TValueReader> { + common_prefix_len: usize, + suffix_start: usize, + 
suffix_end: usize, + value_reader: TValueReader, + block_reader: BlockReader<'a>, + idx: usize, +} + +impl<'a, TValueReader> DeltaReader<'a, TValueReader> +where + TValueReader: value::ValueReader, +{ + pub fn new(reader: R) -> Self { + DeltaReader { + idx: 0, + common_prefix_len: 0, + suffix_start: 0, + suffix_end: 0, + value_reader: TValueReader::default(), + block_reader: BlockReader::new(Box::new(reader)), + } + } + + fn deserialize_vint(&mut self) -> u64 { + self.block_reader.deserialize_u64() + } + + fn read_keep_add(&mut self) -> Option<(usize, usize)> { + let b = { + let buf = &self.block_reader.buffer(); + if buf.is_empty() { + return None; + } + buf[0] + }; + self.block_reader.advance(1); + match b { + VINT_MODE => { + let keep = self.deserialize_vint() as usize; + let add = self.deserialize_vint() as usize; + Some((keep, add)) + } + b => { + let keep = (b & 0b1111) as usize; + let add = (b >> 4) as usize; + Some((keep, add)) + } + } + } + + fn read_delta_key(&mut self) -> bool { + if let Some((keep, add)) = self.read_keep_add() { + self.common_prefix_len = keep; + self.suffix_start = self.block_reader.offset(); + self.suffix_end = self.suffix_start + add; + self.block_reader.advance(add); + true + } else { + false + } + } + + pub fn advance(&mut self) -> io::Result { + if self.block_reader.buffer().is_empty() { + if !self.block_reader.read_block()? 
{ + return Ok(false); + } + self.value_reader.read(&mut self.block_reader)?; + self.idx = 0; + } else { + self.idx += 1; + } + if !self.read_delta_key() { + return Ok(false); + } + Ok(true) + } + + pub fn common_prefix_len(&self) -> usize { + self.common_prefix_len + } + + pub fn suffix(&self) -> &[u8] { + &self + .block_reader + .buffer_from_to(self.suffix_start, self.suffix_end) + } + + pub fn suffix_from(&self, offset: usize) -> &[u8] { + &self.block_reader.buffer_from_to( + self.suffix_start + .wrapping_add(offset) + .wrapping_sub(self.common_prefix_len), + self.suffix_end, + ) + } + + pub fn value(&self) -> &TValueReader::Value { + self.value_reader.value(self.idx) + } +} diff --git a/src/termdict/sstable_termdict/sstable/merge/heap_merge.rs b/src/termdict/sstable_termdict/sstable/merge/heap_merge.rs new file mode 100644 index 000000000..4693707e8 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/merge/heap_merge.rs @@ -0,0 +1,72 @@ +use crate::termdict::sstable_termdict::sstable::{Reader, SSTable, Writer}; + +use super::SingleValueMerger; +use super::ValueMerger; +use std::cmp::Ordering; +use std::collections::binary_heap::PeekMut; +use std::collections::BinaryHeap; +use std::io; + +struct HeapItem>(B); + +impl> Ord for HeapItem { + fn cmp(&self, other: &Self) -> Ordering { + other.0.as_ref().cmp(self.0.as_ref()) + } +} +impl> PartialOrd for HeapItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(other.0.as_ref().cmp(self.0.as_ref())) + } +} + +impl> Eq for HeapItem {} +impl> PartialEq for HeapItem { + fn eq(&self, other: &Self) -> bool { + self.0.as_ref() == other.0.as_ref() + } +} + +pub fn merge_sstable>( + readers: Vec>, + mut writer: Writer, + mut merger: M, +) -> io::Result<()> { + let mut heap: BinaryHeap>> = + BinaryHeap::with_capacity(readers.len()); + for mut reader in readers { + if reader.advance()? 
{ + heap.push(HeapItem(reader)); + } + } + loop { + let len = heap.len(); + let mut value_merger; + if let Some(mut head) = heap.peek_mut() { + writer.write_key(head.0.key()); + value_merger = merger.new_value(head.0.value()); + if !head.0.advance()? { + PeekMut::pop(head); + } + } else { + break; + } + for _ in 0..len - 1 { + if let Some(mut head) = heap.peek_mut() { + if head.0.key() == writer.current_key() { + value_merger.add(head.0.value()); + if !head.0.advance()? { + PeekMut::pop(head); + } + continue; + } + } + break; + } + let value = value_merger.finish(); + writer.write_value(&value); + writer.flush_block_if_required()?; + } + writer.finalize()?; + Ok(()) +} diff --git a/src/termdict/sstable_termdict/sstable/merge/mod.rs b/src/termdict/sstable_termdict/sstable/merge/mod.rs new file mode 100644 index 000000000..14805ecb6 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/merge/mod.rs @@ -0,0 +1,184 @@ +mod heap_merge; + +pub use self::heap_merge::merge_sstable; + +pub trait SingleValueMerger { + fn add(&mut self, v: &V); + fn finish(self) -> V; +} + +pub trait ValueMerger { + type TSingleValueMerger: SingleValueMerger; + fn new_value(&mut self, v: &V) -> Self::TSingleValueMerger; +} + +#[derive(Default)] +pub struct KeepFirst; + +pub struct FirstVal(V); + +impl ValueMerger for KeepFirst { + type TSingleValueMerger = FirstVal; + + fn new_value(&mut self, v: &V) -> FirstVal { + FirstVal(v.clone()) + } +} + +impl SingleValueMerger for FirstVal { + fn add(&mut self, _: &V) {} + + fn finish(self) -> V { + self.0 + } +} + +pub struct VoidMerge; +impl ValueMerger<()> for VoidMerge { + type TSingleValueMerger = (); + + fn new_value(&mut self, _: &()) -> () { + () + } +} + +pub struct U64Merge; +impl ValueMerger for U64Merge { + type TSingleValueMerger = u64; + + fn new_value(&mut self, val: &u64) -> u64 { + *val + } +} + +impl SingleValueMerger for u64 { + fn add(&mut self, val: &u64) { + *self += *val; + } + + fn finish(self) -> u64 { + self + } +} + 
+impl SingleValueMerger<()> for () { + fn add(&mut self, _: &()) {} + + fn finish(self) -> () { + () + } +} + +#[cfg(test)] +mod tests { + + use super::super::SSTable; + use super::super::{SSTableMonotonicU64, VoidSSTable}; + use super::U64Merge; + use super::VoidMerge; + use std::collections::{BTreeMap, BTreeSet}; + use std::str; + + fn write_sstable(keys: &[&'static str]) -> Vec { + let mut buffer: Vec = vec![]; + { + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + for &key in keys { + assert!(sstable_writer.write(key.as_bytes(), &()).is_ok()); + } + assert!(sstable_writer.finalize().is_ok()); + } + dbg!(&buffer); + buffer + } + + fn write_sstable_u64(keys: &[(&'static str, u64)]) -> Vec { + let mut buffer: Vec = vec![]; + { + let mut sstable_writer = SSTableMonotonicU64::writer(&mut buffer); + for (key, val) in keys { + assert!(sstable_writer.write(key.as_bytes(), val).is_ok()); + } + assert!(sstable_writer.finalize().is_ok()); + } + buffer + } + + fn merge_test_aux(arrs: &[&[&'static str]]) { + let sstables = arrs.iter().cloned().map(write_sstable).collect::>(); + let sstables_ref: Vec<&[u8]> = sstables.iter().map(|s| s.as_ref()).collect(); + let mut merged = BTreeSet::new(); + for &arr in arrs.iter() { + for &s in arr { + merged.insert(s.to_string()); + } + } + let mut w = Vec::new(); + assert!(VoidSSTable::merge(sstables_ref, &mut w, VoidMerge).is_ok()); + let mut reader = VoidSSTable::reader(&w[..]); + for k in merged { + assert!(reader.advance().unwrap()); + assert_eq!(reader.key(), k.as_bytes()); + } + assert!(!reader.advance().unwrap()); + } + + fn merge_test_u64_monotonic_aux(arrs: &[&[(&'static str, u64)]]) { + let sstables = arrs + .iter() + .cloned() + .map(write_sstable_u64) + .collect::>(); + let sstables_ref: Vec<&[u8]> = sstables.iter().map(|s| s.as_ref()).collect(); + let mut merged = BTreeMap::new(); + for &arr in arrs.iter() { + for (key, val) in arr { + let entry = merged.entry(key.to_string()).or_insert(0u64); + *entry += val; + 
} + } + let mut w = Vec::new(); + assert!(SSTableMonotonicU64::merge(sstables_ref, &mut w, U64Merge).is_ok()); + let mut reader = SSTableMonotonicU64::reader(&w[..]); + for (k, v) in merged { + assert!(reader.advance().unwrap()); + assert_eq!(reader.key(), k.as_bytes()); + assert_eq!(reader.value(), &v); + } + assert!(!reader.advance().unwrap()); + } + + #[test] + fn test_merge_simple_reproduce() { + let sstable_data = write_sstable(&["a"]); + let mut reader = VoidSSTable::reader(&sstable_data[..]); + assert!(reader.advance().unwrap()); + assert_eq!(reader.key(), b"a"); + assert!(!reader.advance().unwrap()); + } + + #[test] + fn test_merge() { + merge_test_aux(&[]); + merge_test_aux(&[&["a"]]); + merge_test_aux(&[&["a", "b"], &["ab"]]); // a, ab, b + merge_test_aux(&[&["a", "b"], &["a", "b"]]); + merge_test_aux(&[ + &["happy", "hello", "payer", "tax"], + &["habitat", "hello", "zoo"], + &[], + &["a"], + ]); + merge_test_aux(&[&["a"]]); + merge_test_aux(&[&["a", "b"], &["ab"]]); + merge_test_aux(&[&["a", "b"], &["a", "b"]]); + } + + #[test] + fn test_merge_u64() { + merge_test_u64_monotonic_aux(&[]); + merge_test_u64_monotonic_aux(&[&[("a", 1u64)]]); + merge_test_u64_monotonic_aux(&[&[("a", 1u64), ("b", 3u64)], &[("ab", 2u64)]]); // a, ab, b + merge_test_u64_monotonic_aux(&[&[("a", 1u64), ("b", 2u64)], &[("a", 16u64), ("b", 23u64)]]); + } +} diff --git a/src/termdict/sstable_termdict/sstable/mod.rs b/src/termdict/sstable_termdict/sstable/mod.rs new file mode 100644 index 000000000..0cd6ec5b8 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/mod.rs @@ -0,0 +1,365 @@ +use merge::ValueMerger; +use std::io::{self, Write}; +use std::usize; + +mod delta; +pub mod merge; +pub mod value; + +pub(crate) mod sstable_index; + +pub(crate) use self::sstable_index::{SSTableIndex, SSTableIndexBuilder}; +pub(crate) mod vint; + +mod block_reader; +pub use self::delta::DeltaReader; +use self::delta::DeltaWriter; +use self::value::{U64MonotonicReader, U64MonotonicWriter, 
ValueReader, ValueWriter}; + +pub use self::block_reader::BlockReader; +pub use self::merge::VoidMerge; + +const DEFAULT_KEY_CAPACITY: usize = 50; + +pub(crate) fn common_prefix_len(left: &[u8], right: &[u8]) -> usize { + left.iter() + .cloned() + .zip(right.iter().cloned()) + .take_while(|(left, right)| left == right) + .count() +} + +pub trait SSTable: Sized { + type Value; + type Reader: ValueReader; + type Writer: ValueWriter; + + fn delta_writer(write: W) -> DeltaWriter { + DeltaWriter::new(write) + } + + fn writer(write: W) -> Writer { + Writer { + previous_key: Vec::with_capacity(DEFAULT_KEY_CAPACITY), + num_terms: 0u64, + index_builder: SSTableIndexBuilder::default(), + delta_writer: Self::delta_writer(write), + first_ordinal_of_the_block: 0u64, + } + } + + fn delta_reader<'a, R: io::Read + 'a>(reader: R) -> DeltaReader<'a, Self::Reader> { + DeltaReader::new(reader) + } + + fn reader<'a, R: io::Read + 'a>(reader: R) -> Reader<'a, Self::Reader> { + Reader { + key: Vec::with_capacity(DEFAULT_KEY_CAPACITY), + delta_reader: Self::delta_reader(reader), + } + } + + fn merge>( + io_readers: Vec, + w: W, + merger: M, + ) -> io::Result<()> { + let readers: Vec<_> = io_readers.into_iter().map(Self::reader).collect(); + let writer = Self::writer(w); + merge::merge_sstable::(readers, writer, merger) + } +} + +pub struct VoidSSTable; + +impl SSTable for VoidSSTable { + type Value = (); + type Reader = value::VoidReader; + type Writer = value::VoidWriter; +} + +pub struct SSTableMonotonicU64; + +impl SSTable for SSTableMonotonicU64 { + type Value = u64; + + type Reader = U64MonotonicReader; + + type Writer = U64MonotonicWriter; +} + +pub struct Reader<'a, TValueReader> { + key: Vec, + delta_reader: DeltaReader<'a, TValueReader>, +} + +impl<'a, TValueReader> Reader<'a, TValueReader> +where + TValueReader: ValueReader, +{ + pub fn advance(&mut self) -> io::Result { + if !self.delta_reader.advance()? 
{
+            return Ok(false);
+        }
+        let common_prefix_len = self.delta_reader.common_prefix_len();
+        let suffix = self.delta_reader.suffix();
+        let new_len = self.delta_reader.common_prefix_len() + suffix.len();
+        self.key.resize(new_len, 0u8);
+        self.key[common_prefix_len..].copy_from_slice(suffix);
+        Ok(true)
+    }
+
+    /// Key of the current entry, rebuilt from the kept prefix and the suffix.
+    pub fn key(&self) -> &[u8] {
+        &self.key
+    }
+
+    /// Value of the current entry.
+    pub fn value(&self) -> &TValueReader::Value {
+        self.delta_reader.value()
+    }
+
+    /// Gives back the underlying delta reader.
+    ///
+    /// # Panics
+    /// Panics if a key has already been decoded by this reader.
+    pub(crate) fn into_delta_reader(self) -> DeltaReader<'a, TValueReader> {
+        assert!(self.key.is_empty());
+        self.delta_reader
+    }
+}
+
+impl<'a, TValueReader> AsRef<[u8]> for Reader<'a, TValueReader> {
+    fn as_ref(&self) -> &[u8] {
+        &self.key
+    }
+}
+
+/// Writes strictly increasing keys and their values, and records one index
+/// entry per flushed block.
+pub struct Writer<W, TValueWriter>
+where
+    W: io::Write,
+{
+    previous_key: Vec<u8>,
+    index_builder: SSTableIndexBuilder,
+    delta_writer: DeltaWriter<W, TValueWriter>,
+    num_terms: u64,
+    first_ordinal_of_the_block: u64,
+}
+
+impl<W, TValueWriter> Writer<W, TValueWriter>
+where
+    W: io::Write,
+    TValueWriter: value::ValueWriter,
+{
+    /// Last key passed to `write_key`, or an empty slice at the start of a block.
+    pub(crate) fn current_key(&self) -> &[u8] {
+        &self.previous_key[..]
+    }
+
+    /// Appends `key` to the current block as a (shared prefix, suffix) delta
+    /// against the previous key.
+    ///
+    /// # Panics
+    /// Panics if `key` is not strictly greater (in byte order) than the previous
+    /// key. The check is skipped when there is no previous key, i.e. for the
+    /// first key and for the first key after a block flush.
+    pub fn write_key(&mut self, key: &[u8]) {
+        let keep_len = common_prefix_len(&self.previous_key, key);
+        // Lexicographic slice comparison. Unlike indexing both keys at
+        // `keep_len`, it is also well defined when `key` repeats the previous
+        // key or is a prefix of it, so those cases reach the assertion message
+        // below instead of an index-out-of-bounds panic.
+        let increasing_keys = self.previous_key.is_empty() || &self.previous_key[..] < key;
+        assert!(
+            increasing_keys,
+            "Keys should be increasing. ({:?} > {:?})",
+            self.previous_key, key
+        );
+        self.previous_key.resize(key.len(), 0u8);
+        self.previous_key[keep_len..].copy_from_slice(&key[keep_len..]);
+        self.delta_writer.write_suffix(keep_len, &key[keep_len..]);
+    }
+
+    pub(crate) fn into_delta_writer(self) -> DeltaWriter<W, TValueWriter> {
+        self.delta_writer
+    }
+
+    /// Writes a full entry: the key, then its value.
+    pub fn write(&mut self, key: &[u8], value: &TValueWriter::Value) -> io::Result<()> {
+        self.write_key(key);
+        self.write_value(value)?;
+        Ok(())
+    }
+
+    /// Records the value of the key written last, then flushes the block if it
+    /// has grown large enough.
+    pub fn write_value(&mut self, value: &TValueWriter::Value) -> io::Result<()> {
+        self.delta_writer.write_value(value);
+        self.num_terms += 1u64;
+        self.flush_block_if_required()
+    }
+
+    /// Flushes the current block if the delta writer deems it necessary and, in
+    /// that case, registers the block in the index under its last key.
+    pub fn flush_block_if_required(&mut self) -> io::Result<()> {
+        if let Some((start_offset, end_offset)) = self.delta_writer.flush_block_if_required()? {
+            self.index_builder.add_block(
+                &self.previous_key[..],
+                start_offset,
+                end_offset,
+                self.first_ordinal_of_the_block,
+            );
+            self.first_ordinal_of_the_block = self.num_terms;
+            // Keys of the next block are encoded from scratch.
+            self.previous_key.clear();
+        }
+        Ok(())
+    }
+
+    /// Flushes the last block, then writes the `0u32` end-of-blocks marker, the
+    /// index, the offset of the index and the number of terms.
+    pub fn finalize(mut self) -> io::Result<W> {
+        if let Some((start_offset, end_offset)) = self.delta_writer.flush_block()? {
+            self.index_builder.add_block(
+                &self.previous_key[..],
+                start_offset,
+                end_offset,
+                self.first_ordinal_of_the_block,
+            );
+            self.first_ordinal_of_the_block = self.num_terms;
+        }
+        let mut wrt = self.delta_writer.finalize();
+        wrt.write_all(&0u32.to_le_bytes())?;
+
+        let offset = wrt.written_bytes();
+
+        self.index_builder.serialize(&mut wrt)?;
+        wrt.write_all(&offset.to_le_bytes())?;
+        wrt.write_all(&self.num_terms.to_le_bytes())?;
+        let wrt = wrt.finish();
+        Ok(wrt.into_inner()?)
+ } +} +#[cfg(test)] +mod test { + use std::io; + + use super::SSTable; + use super::VoidMerge; + use super::VoidSSTable; + use super::{common_prefix_len, SSTableMonotonicU64}; + + fn aux_test_common_prefix_len(left: &str, right: &str, expect_len: usize) { + assert_eq!( + common_prefix_len(left.as_bytes(), right.as_bytes()), + expect_len + ); + assert_eq!( + common_prefix_len(right.as_bytes(), left.as_bytes()), + expect_len + ); + } + + #[test] + fn test_common_prefix_len() { + aux_test_common_prefix_len("a", "ab", 1); + aux_test_common_prefix_len("", "ab", 0); + aux_test_common_prefix_len("ab", "abc", 2); + aux_test_common_prefix_len("abde", "abce", 2); + } + + #[test] + fn test_long_key_diff() { + let long_key = (0..1_024).map(|x| (x % 255) as u8).collect::>(); + let long_key2 = (1..300).map(|x| (x % 255) as u8).collect::>(); + let mut buffer = vec![]; + { + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + assert!(sstable_writer.write(&long_key[..], &()).is_ok()); + assert!(sstable_writer.write(&[0, 3, 4], &()).is_ok()); + assert!(sstable_writer.write(&long_key2[..], &()).is_ok()); + assert!(sstable_writer.finalize().is_ok()); + } + let mut sstable_reader = VoidSSTable::reader(&buffer[..]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &long_key[..]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[0, 3, 4]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &long_key2[..]); + assert!(!sstable_reader.advance().unwrap()); + } + + #[test] + fn test_simple_sstable() { + let mut buffer = vec![]; + { + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + assert!(sstable_writer.write(&[17u8], &()).is_ok()); + assert!(sstable_writer.write(&[17u8, 18u8, 19u8], &()).is_ok()); + assert!(sstable_writer.write(&[17u8, 20u8], &()).is_ok()); + assert!(sstable_writer.finalize().is_ok()); + } + assert_eq!( + &buffer, + &[ + // block len + 7u8, 0u8, 0u8, 0u8, // 
keep 0 push 1 | "" + 16u8, 17u8, // keep 1 push 2 | 18 19 + 33u8, 18u8, 19u8, // keep 1 push 1 | 20 + 17u8, 20u8, 0u8, 0u8, 0u8, 0u8, // no more blocks + // index + 161, 102, 98, 108, 111, 99, 107, 115, 129, 162, 104, 108, 97, 115, 116, 95, 107, + 101, 121, 130, 17, 20, 106, 98, 108, 111, 99, 107, 95, 97, 100, 100, 114, 163, 108, + 115, 116, 97, 114, 116, 95, 111, 102, 102, 115, 101, 116, 0, 106, 101, 110, 100, + 95, 111, 102, 102, 115, 101, 116, 11, 109, 102, 105, 114, 115, 116, 95, 111, 114, + 100, 105, 110, 97, 108, 0, 15, 0, 0, 0, 0, 0, 0, 0, // offset for the index + 3u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8 // num terms + ] + ); + let mut sstable_reader = VoidSSTable::reader(&buffer[..]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[17u8]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[17u8, 18u8, 19u8]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[17u8, 20u8]); + assert!(!sstable_reader.advance().unwrap()); + } + + #[test] + #[should_panic] + fn test_simple_sstable_non_increasing_key() { + let mut buffer = vec![]; + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + assert!(sstable_writer.write(&[17u8], &()).is_ok()); + assert!(sstable_writer.write(&[16u8], &()).is_ok()); + } + + #[test] + fn test_merge_abcd_abe() { + let mut buffer = Vec::new(); + { + let mut writer = VoidSSTable::writer(&mut buffer); + writer.write(b"abcd", &()).unwrap(); + writer.write(b"abe", &()).unwrap(); + writer.finalize().unwrap(); + } + let mut output = Vec::new(); + assert!(VoidSSTable::merge(vec![&buffer[..], &buffer[..]], &mut output, VoidMerge).is_ok()); + assert_eq!(&output[..], &buffer[..]); + } + + #[test] + fn test_sstable() { + let mut buffer = Vec::new(); + { + let mut writer = VoidSSTable::writer(&mut buffer); + writer.write(b"abcd", &()).unwrap(); + writer.write(b"abe", &()).unwrap(); + writer.finalize().unwrap(); + } + let mut output = Vec::new(); + 
assert!(VoidSSTable::merge(vec![&buffer[..], &buffer[..]], &mut output, VoidMerge).is_ok());
+        assert_eq!(&output[..], &buffer[..]);
+    }
+
+    #[test]
+    fn test_sstable_u64() -> io::Result<()> {
+        let mut buffer = Vec::new();
+        let mut writer = SSTableMonotonicU64::writer(&mut buffer);
+        writer.write(b"abcd", &1u64)?;
+        writer.write(b"abe", &4u64)?;
+        writer.write(b"gogo", &4324234234234234u64)?;
+        writer.finalize()?;
+        let mut reader = SSTableMonotonicU64::reader(&buffer[..]);
+        assert!(reader.advance()?);
+        assert_eq!(reader.key(), b"abcd");
+        assert_eq!(reader.value(), &1u64);
+        assert!(reader.advance()?);
+        assert_eq!(reader.key(), b"abe");
+        assert_eq!(reader.value(), &4u64);
+        assert!(reader.advance()?);
+        assert_eq!(reader.key(), b"gogo");
+        assert_eq!(reader.value(), &4324234234234234u64);
+        assert!(!reader.advance()?);
+        Ok(())
+    }
+}
diff --git a/src/termdict/sstable_termdict/sstable/sstable_index.rs b/src/termdict/sstable_termdict/sstable/sstable_index.rs
new file mode 100644
index 000000000..a74916fb1
--- /dev/null
+++ b/src/termdict/sstable_termdict/sstable/sstable_index.rs
@@ -0,0 +1,90 @@
+use std::io;
+
+use serde;
+use serde::{Deserialize, Serialize};
+
+/// One entry per block: the last key of the block and where the block lives.
+#[derive(Default, Debug, Serialize, Deserialize)]
+pub struct SSTableIndex {
+    blocks: Vec<BlockMeta>,
+}
+
+impl SSTableIndex {
+    /// Decodes an index from its CBOR representation.
+    ///
+    /// # Panics
+    /// Panics if `data` is not a valid serialized index.
+    pub fn load(data: &[u8]) -> SSTableIndex {
+        // TODO: return an error instead of panicking on corrupted data.
+        serde_cbor::de::from_slice(data).unwrap()
+    }
+
+    /// Address of the first block whose last key is greater than or equal to
+    /// `key`, or `None` if every block ends before `key`.
+    pub fn search(&self, key: &[u8]) -> Option<BlockAddr> {
+        self.blocks
+            .iter()
+            .find(|block| &block.last_key[..] >= &key)
+            .map(|block| block.block_addr)
+    }
+}
+
+/// Byte range of a block, and the ordinal of its first entry.
+#[derive(Clone, Eq, PartialEq, Debug, Copy, Serialize, Deserialize)]
+pub struct BlockAddr {
+    pub start_offset: u64,
+    pub end_offset: u64,
+    pub first_ordinal: u64,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct BlockMeta {
+    pub last_key: Vec<u8>,
+    pub block_addr: BlockAddr,
+}
+
+/// Accumulates block entries, in the order the blocks are written.
+#[derive(Default)]
+pub struct SSTableIndexBuilder {
+    index: SSTableIndex,
+}
+
+impl SSTableIndexBuilder {
+    /// Appends the entry of a block that was just written.
+    pub fn add_block(
+        &mut self,
+        last_key: &[u8],
+        start_offset: u64,
+        stop_offset: u64,
+        first_ordinal: u64,
+    ) {
+        self.index.blocks.push(BlockMeta {
+            last_key: last_key.to_vec(),
+            block_addr: BlockAddr {
+                start_offset,
+                end_offset: stop_offset,
+                first_ordinal,
+            },
+        })
+    }
+
+    /// Writes the index to `wrt`, encoded as CBOR.
+    ///
+    /// # Errors
+    /// Encoding and write failures are returned as an `io::Error` rather than
+    /// panicking, so that a failing writer surfaces through the caller's `?`.
+    pub fn serialize(&self, wrt: &mut dyn io::Write) -> io::Result<()> {
+        serde_cbor::ser::to_writer(wrt, &self.index)
+            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{BlockAddr, SSTableIndex, SSTableIndexBuilder};
+
+    #[test]
+    fn test_sstable_index() {
+        let mut sstable_builder = SSTableIndexBuilder::default();
+        sstable_builder.add_block(b"aaa", 10u64, 20u64, 0u64);
+        sstable_builder.add_block(b"bbbbbbb", 20u64, 30u64, 564);
+        sstable_builder.add_block(b"ccc", 30u64, 40u64, 10u64);
+        sstable_builder.add_block(b"dddd", 40u64, 50u64, 15u64);
+        let mut buffer: Vec<u8> = Vec::new();
+        sstable_builder.serialize(&mut buffer).unwrap();
+        let sstable = SSTableIndex::load(&buffer[..]);
+        assert_eq!(
+            sstable.search(b"bbbde"),
+            Some(BlockAddr {
+                first_ordinal: 10u64,
+                start_offset: 30u64,
+                end_offset: 40u64
+            })
+        );
+    }
+}
diff --git a/src/termdict/sstable_termdict/sstable/value.rs b/src/termdict/sstable_termdict/sstable/value.rs
new file mode 100644
index 000000000..b98b4584c
--- /dev/null
+++ b/src/termdict/sstable_termdict/sstable/value.rs
@@ -0,0 +1,94 @@
+use super::{vint, BlockReader};
+use std::io;
+
+pub trait ValueReader: Default {
+    type Value;
+
+    fn value(&self, idx: usize) -> &Self::Value;
+
+    fn read(&mut
self, reader: &mut BlockReader) -> io::Result<()>; +} + +pub trait ValueWriter: Default { + type Value; + + fn write(&mut self, val: &Self::Value); + + fn write_block(&mut self, writer: &mut Vec); +} + +#[derive(Default)] +pub struct VoidReader; + +impl ValueReader for VoidReader { + type Value = (); + + fn value(&self, _idx: usize) -> &() { + &() + } + + fn read(&mut self, _reader: &mut BlockReader) -> io::Result<()> { + Ok(()) + } +} + +#[derive(Default)] +pub struct VoidWriter; + +impl ValueWriter for VoidWriter { + type Value = (); + + fn write(&mut self, _val: &()) {} + + fn write_block(&mut self, _writer: &mut Vec) {} +} + +#[derive(Default)] +pub struct U64MonotonicWriter { + vals: Vec, +} + +impl ValueWriter for U64MonotonicWriter { + type Value = u64; + + fn write(&mut self, val: &Self::Value) { + self.vals.push(*val); + } + + fn write_block(&mut self, writer: &mut Vec) { + let mut prev_val = 0u64; + vint::serialize_into_vec(self.vals.len() as u64, writer); + for &val in &self.vals { + let delta = val - prev_val; + vint::serialize_into_vec(delta, writer); + prev_val = val; + } + self.vals.clear(); + } +} + +#[derive(Default)] +pub struct U64MonotonicReader { + vals: Vec, +} + +impl ValueReader for U64MonotonicReader { + type Value = u64; + + fn value(&self, idx: usize) -> &Self::Value { + &self.vals[idx] + } + + fn read(&mut self, reader: &mut BlockReader) -> io::Result<()> { + let len = reader.deserialize_u64() as usize; + self.vals.clear(); + let mut prev_val = 0u64; + for _ in 0..len { + let delta = reader.deserialize_u64() as u64; + let val = prev_val + delta; + self.vals.push(val); + prev_val = val; + } + Ok(()) + } +} diff --git a/src/termdict/sstable_termdict/sstable/vint.rs b/src/termdict/sstable_termdict/sstable/vint.rs new file mode 100644 index 000000000..3aaadf357 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/vint.rs @@ -0,0 +1,74 @@ +use super::BlockReader; + +const CONTINUE_BIT: u8 = 128u8; + +pub fn serialize(mut val: u64, 
buffer: &mut [u8]) -> usize { + for (i, b) in buffer.iter_mut().enumerate() { + let next_byte: u8 = (val & 127u64) as u8; + val = val >> 7; + if val == 0u64 { + *b = next_byte; + return i + 1; + } else { + *b = next_byte | CONTINUE_BIT; + } + } + 10 //< actually unreachable +} + +pub fn serialize_into_vec(val: u64, buffer: &mut Vec) { + let mut buf = [0u8; 10]; + let num_bytes = serialize(val, &mut buf[..]); + buffer.extend_from_slice(&buf[..num_bytes]); +} + +// super slow but we don't care +pub fn deserialize_read(buf: &[u8]) -> (usize, u64) { + let mut result = 0u64; + let mut shift = 0u64; + let mut consumed = 0; + + for &b in buf { + consumed += 1; + result |= u64::from(b % 128u8) << shift; + if b < CONTINUE_BIT { + break; + } + shift += 7; + } + (consumed, result) +} + +pub fn deserialize_from_block(block: &mut BlockReader) -> u64 { + let (num_bytes, val) = deserialize_read(block.buffer()); + block.advance(num_bytes); + val +} + +#[cfg(test)] +mod tests { + use super::{deserialize_read, serialize}; + use std::u64; + + fn aux_test_int(val: u64, expect_len: usize) { + let mut buffer = [0u8; 14]; + assert_eq!(serialize(val, &mut buffer[..]), expect_len); + assert_eq!(deserialize_read(&buffer), (expect_len, val)); + } + + #[test] + fn test_vint() { + aux_test_int(0u64, 1); + aux_test_int(17u64, 1); + aux_test_int(127u64, 1); + aux_test_int(128u64, 2); + aux_test_int(123423418u64, 4); + for i in 1..63 { + let power_of_two = 1u64 << i; + aux_test_int(power_of_two + 1, (i / 7) + 1); + aux_test_int(power_of_two, (i / 7) + 1); + aux_test_int(power_of_two - 1, ((i - 1) / 7) + 1); + } + aux_test_int(u64::MAX, 10); + } +} diff --git a/src/termdict/sstable_termdict/streamer.rs b/src/termdict/sstable_termdict/streamer.rs new file mode 100644 index 000000000..de9d49531 --- /dev/null +++ b/src/termdict/sstable_termdict/streamer.rs @@ -0,0 +1,227 @@ +use super::TermDictionary; +use crate::postings::TermInfo; +use crate::termdict::sstable_termdict::TermInfoReader; +use 
crate::termdict::TermOrdinal; +use std::io; +use std::ops::Bound; +use tantivy_fst::automaton::AlwaysMatch; +use tantivy_fst::Automaton; + +/// `TermStreamerBuilder` is a helper object used to define +/// a range of terms that should be streamed. +pub struct TermStreamerBuilder<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + term_dict: &'a TermDictionary, + automaton: A, + lower: Bound>, + upper: Bound>, +} + +impl<'a, A> TermStreamerBuilder<'a, A> +where + A: Automaton, + A::State: Clone, +{ + pub(crate) fn new(term_dict: &'a TermDictionary, automaton: A) -> Self { + TermStreamerBuilder { + term_dict, + automaton, + lower: Bound::Unbounded, + upper: Bound::Unbounded, + } + } + + /// Limit the range to terms greater or equal to the bound + pub fn ge>(mut self, bound: T) -> Self { + self.lower = Bound::Included(bound.as_ref().to_owned()); + self + } + + /// Limit the range to terms strictly greater than the bound + pub fn gt>(mut self, bound: T) -> Self { + self.lower = Bound::Excluded(bound.as_ref().to_owned()); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn le>(mut self, bound: T) -> Self { + self.upper = Bound::Included(bound.as_ref().to_owned()); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn lt>(mut self, bound: T) -> Self { + self.lower = Bound::Excluded(bound.as_ref().to_owned()); + self + } + + pub fn backward(mut self) -> Self { + unimplemented!() + } + + /// Creates the stream corresponding to the range + /// of terms defined using the `TermStreamerBuilder`. + pub fn into_stream(self) -> io::Result> { + let start_state = self.automaton.start(); + let delta_reader = self.term_dict.sstable_delta_reader()?; + Ok(TermStreamer { + automaton: self.automaton, + states: vec![start_state], + delta_reader, + key: Vec::new(), + term_ord: 0u64, + }) + } +} + +/// `TermStreamer` acts as a cursor over a range of terms of a segment. +/// Terms are guaranteed to be sorted. 
+pub struct TermStreamer<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + automaton: A, + states: Vec, + delta_reader: super::sstable::DeltaReader<'a, TermInfoReader>, + key: Vec, + term_ord: TermOrdinal, +} + +impl<'a, A> TermStreamer<'a, A> +where + A: Automaton, + A::State: Clone, +{ + /// Advance position the stream on the next item. + /// Before the first call to `.advance()`, the stream + /// is an unitialized state. + pub fn advance(&mut self) -> bool { + while self.delta_reader.advance().unwrap() { + self.term_ord += 1u64; + let common_prefix_len = self.delta_reader.common_prefix_len(); + self.states.truncate(common_prefix_len + 1); + self.key.truncate(common_prefix_len); + let mut state: A::State = self.states.last().unwrap().clone(); + for &b in self.delta_reader.suffix() { + state = self.automaton.accept(&state, b); + self.states.push(state.clone()); + } + self.key.extend_from_slice(self.delta_reader.suffix()); + if self.automaton.is_match(&state) { + return true; + } + } + false + } + + /// Returns the `TermOrdinal` of the given term. + /// + /// May panic if the called as `.advance()` as never + /// been called before. + pub fn term_ord(&self) -> TermOrdinal { + self.term_ord + } + + /// Accesses the current key. + /// + /// `.key()` should return the key that was returned + /// by the `.next()` method. + /// + /// If the end of the stream as been reached, and `.next()` + /// has been called and returned `None`, `.key()` remains + /// the value of the last key encountered. + /// + /// Before any call to `.next()`, `.key()` returns an empty array. + pub fn key(&self) -> &[u8] { + &self.key + } + + /// Accesses the current value. + /// + /// Calling `.value()` after the end of the stream will return the + /// last `.value()` encountered. + /// + /// # Panics + /// + /// Calling `.value()` before the first call to `.advance()` returns + /// `V::default()`. 
+ pub fn value(&self) -> &TermInfo { + self.delta_reader.value() + } + + /// Return the next `(key, value)` pair. + #[cfg_attr(feature = "cargo-clippy", allow(clippy::should_implement_trait))] + pub fn next(&mut self) -> Option<(&[u8], &TermInfo)> { + if self.advance() { + Some((self.key(), self.value())) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::super::TermDictionary; + use crate::directory::OwnedBytes; + use crate::postings::TermInfo; + + fn make_term_info(i: u64) -> TermInfo { + TermInfo { + doc_freq: 1000u32 + i as u32, + positions_idx: i * 500, + postings_start_offset: (i + 10) * (i * 10), + postings_stop_offset: ((i + 1) + 10) * ((i + 1) * 10), + } + } + + fn create_test_term_dictionary() -> crate::Result { + let mut term_dict_builder = super::super::TermDictionaryBuilder::create(Vec::new())?; + term_dict_builder.insert(b"abaisance", &make_term_info(0u64))?; + term_dict_builder.insert(b"abalation", &make_term_info(1u64))?; + term_dict_builder.insert(b"abalienate", &make_term_info(2u64))?; + term_dict_builder.insert(b"abandon", &make_term_info(3u64))?; + let buffer = term_dict_builder.finish()?; + let owned_bytes = OwnedBytes::new(buffer); + TermDictionary::from_bytes(owned_bytes) + } + + #[test] + fn test_sstable_stream() -> crate::Result<()> { + let term_dict = create_test_term_dictionary()?; + let mut term_streamer = term_dict.stream()?; + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abaisance"); + assert_eq!(term_streamer.value().doc_freq, 1000u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalation"); + assert_eq!(term_streamer.value().doc_freq, 1001u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalienate"); + assert_eq!(term_streamer.value().doc_freq, 1002u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abandon"); + assert_eq!(term_streamer.value().doc_freq, 1003u32); + assert!(!term_streamer.advance()); + 
Ok(()) + } + + #[test] + fn test_sstable_search() -> crate::Result<()> { + let term_dict = create_test_term_dictionary()?; + let ptn = tantivy_fst::Regex::new("ab.*t.*").unwrap(); + let mut term_streamer = term_dict.search(ptn).into_stream()?; + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalation"); + assert_eq!(term_streamer.value().doc_freq, 1001u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalienate"); + assert_eq!(term_streamer.value().doc_freq, 1002u32); + assert!(!term_streamer.advance()); + Ok(()) + } +} diff --git a/src/termdict/sstable_termdict/termdict.rs b/src/termdict/sstable_termdict/termdict.rs new file mode 100644 index 000000000..458462451 --- /dev/null +++ b/src/termdict/sstable_termdict/termdict.rs @@ -0,0 +1,228 @@ +use std::io; + +use crate::common::BinarySerializable; +use crate::directory::{FileSlice, OwnedBytes}; +use crate::postings::TermInfo; +use crate::termdict::sstable_termdict::sstable::sstable_index::BlockAddr; +use crate::termdict::sstable_termdict::sstable::Writer; +use crate::termdict::sstable_termdict::sstable::{DeltaReader, SSTable}; +use crate::termdict::sstable_termdict::sstable::{Reader, SSTableIndex}; +use crate::termdict::sstable_termdict::{ + TermInfoReader, TermInfoWriter, TermSSTable, TermStreamer, TermStreamerBuilder, +}; +use crate::termdict::TermOrdinal; +use crate::HasLen; +use once_cell::sync::Lazy; +use tantivy_fst::automaton::AlwaysMatch; +use tantivy_fst::Automaton; + +pub struct TermInfoSSTable; +impl SSTable for TermInfoSSTable { + type Value = TermInfo; + type Reader = TermInfoReader; + type Writer = TermInfoWriter; +} +pub struct TermDictionaryBuilder { + sstable_writer: Writer, +} + +impl TermDictionaryBuilder { + /// Creates a new `TermDictionaryBuilder` + pub fn create(w: W) -> io::Result { + let sstable_writer = TermSSTable::writer(w); + Ok(TermDictionaryBuilder { sstable_writer }) + } + + /// Inserts a `(key, value)` pair in the term dictionary. 
+ /// + /// *Keys have to be inserted in order.* + pub fn insert>(&mut self, key_ref: K, value: &TermInfo) -> io::Result<()> { + let key = key_ref.as_ref(); + self.insert_key(key)?; + self.insert_value(value)?; + Ok(()) + } + + /// # Warning + /// Horribly dangerous internal API + /// + /// If used, it must be used by systematically alternating calls + /// to insert_key and insert_value. + /// + /// Prefer using `.insert(key, value)` + pub(crate) fn insert_key(&mut self, key: &[u8]) -> io::Result<()> { + self.sstable_writer.write_key(key); + Ok(()) + } + + /// # Warning + /// + /// Horribly dangerous internal API. See `.insert_key(...)`. + pub(crate) fn insert_value(&mut self, term_info: &TermInfo) -> io::Result<()> { + self.sstable_writer.write_value(term_info); + Ok(()) + } + + /// Finalize writing the builder, and returns the underlying + /// `Write` object. + pub fn finish(self) -> io::Result { + self.sstable_writer.finalize() + } +} + +static EMPTY_TERM_DICT_FILE: Lazy = Lazy::new(|| { + let term_dictionary_data: Vec = TermDictionaryBuilder::create(Vec::::new()) + .expect("Creating a TermDictionaryBuilder in a Vec should never fail") + .finish() + .expect("Writing in a Vec should never fail"); + FileSlice::from(term_dictionary_data) +}); + +/// The term dictionary contains all of the terms in +/// `tantivy index` in a sorted manner. +/// +/// The `Fst` crate is used to associate terms to their +/// respective `TermOrdinal`. The `TermInfoStore` then makes it +/// possible to fetch the associated `TermInfo`. 
+pub struct TermDictionary { + sstable_slice: FileSlice, + sstable_index: SSTableIndex, + num_terms: u64, +} + +impl TermDictionary { + pub(crate) fn sstable_reader(&self) -> io::Result> { + let data = self.sstable_slice.read_bytes()?; + Ok(TermInfoSSTable::reader(data)) + } + + pub(crate) fn sstable_reader_block( + &self, + block_addr: BlockAddr, + ) -> io::Result> { + let data = self.sstable_slice.read_bytes_slice( + block_addr.start_offset as usize, + block_addr.end_offset as usize, + )?; + Ok(TermInfoSSTable::reader(data)) + } + + pub(crate) fn sstable_delta_reader(&self) -> io::Result> { + let data = self.sstable_slice.read_bytes()?; + Ok(TermInfoSSTable::delta_reader(data)) + } + + /// Opens a `TermDictionary`. + pub fn open(term_dictionary_file: FileSlice) -> crate::Result { + let (main_slice, footer_len_slice) = term_dictionary_file.split_from_end(16); + let mut footer_len_bytes: OwnedBytes = footer_len_slice.read_bytes()?; + let index_offset = u64::deserialize(&mut footer_len_bytes)?; + let num_terms = u64::deserialize(&mut footer_len_bytes)?; + let (sstable_slice, index_slice) = main_slice.split(index_offset as usize); + // dbg!(index_slice.len()); + let sstable_index_bytes = index_slice.read_bytes()?; + let sstable_index = SSTableIndex::load(sstable_index_bytes.as_slice()); + // dbg!(&sstable_index); + Ok(TermDictionary { + sstable_slice, + sstable_index, + num_terms, + }) + } + + pub fn from_bytes(owned_bytes: OwnedBytes) -> crate::Result { + TermDictionary::open(FileSlice::new(Box::new(owned_bytes))) + } + + /// Creates an empty term dictionary which contains no terms. + pub fn empty() -> Self { + TermDictionary::open(EMPTY_TERM_DICT_FILE.clone()).unwrap() + } + + /// Returns the number of terms in the dictionary. + /// Term ordinals range from 0 to `num_terms() - 1`. + pub fn num_terms(&self) -> usize { + self.num_terms as usize + } + + /// Returns the ordinal associated to a given term. 
+ pub fn term_ord>(&self, key: K) -> io::Result> { + let mut term_ord = 0u64; + let key_bytes = key.as_ref(); + let mut sstable_reader = self.sstable_reader()?; + while sstable_reader.advance().unwrap_or(false) { + if sstable_reader.key() == key_bytes { + return Ok(Some(term_ord)); + } + term_ord += 1; + } + Ok(None) + } + + /// Returns the term associated to a given term ordinal. + /// + /// Term ordinals are defined as the position of the term in + /// the sorted list of terms. + /// + /// Returns true iff the term has been found. + /// + /// Regardless of whether the term is found or not, + /// the buffer may be modified. + pub fn ord_to_term(&self, ord: TermOrdinal, bytes: &mut Vec) -> io::Result { + let mut sstable_reader = self.sstable_reader()?; + bytes.clear(); + for _ in 0..(ord + 1) { + if !sstable_reader.advance().unwrap_or(false) { + return Ok(false); + } + } + bytes.extend_from_slice(sstable_reader.key()); + Ok(true) + } + + /// Returns the number of terms in the dictionary. + pub fn term_info_from_ord(&self, term_ord: TermOrdinal) -> io::Result { + let mut sstable_reader = self.sstable_reader()?; + for _ in 0..(term_ord + 1) { + if !sstable_reader.advance().unwrap_or(false) { + return Ok(TermInfo::default()); + } + } + Ok(sstable_reader.value().clone()) + } + + /// Lookups the value corresponding to the key. + pub fn get>(&self, key: K) -> io::Result> { + if let Some(block_addr) = self.sstable_index.search(key.as_ref()) { + let mut sstable_reader = self.sstable_reader_block(block_addr)?; + let key_bytes = key.as_ref(); + while sstable_reader.advance().unwrap_or(false) { + if sstable_reader.key() == key_bytes { + let term_info = sstable_reader.value().clone(); + return Ok(Some(term_info)); + } + } + } + Ok(None) + } + + // Returns a range builder, to stream all of the terms + // within an interval. + pub fn range(&self) -> TermStreamerBuilder<'_> { + TermStreamerBuilder::new(self, AlwaysMatch) + } + + // A stream of all the sorted terms. 
[See also `.stream_field()`](#method.stream_field) + pub fn stream(&self) -> io::Result> { + self.range().into_stream() + } + + // Returns a search builder, to stream all of the terms + // within the Automaton + pub fn search<'a, A: Automaton + 'a>(&'a self, automaton: A) -> TermStreamerBuilder<'a, A> + where + A::State: Clone, + { + TermStreamerBuilder::::new(self, automaton) + } +}