diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7c733d9b4..f2b79c2b7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,16 +24,23 @@ jobs: toolchain: nightly override: true components: rustfmt + - name: Install latest nightly to test also against unstable feature flag uses: actions-rs/toolchain@v1 with: toolchain: stable override: true components: rustfmt, clippy + - name: Run tests run: cargo +stable test --features mmap,brotli-compression,lz4-compression,snappy-compression,failpoints --verbose --workspace + + - name: Run tests quickwit feature + run: cargo +stable test --features mmap,quickwit,failpoints --verbose --workspace + - name: Check Formatting run: cargo +nightly fmt --all -- --check + - uses: actions-rs/clippy-check@v1 with: toolchain: stable diff --git a/Cargo.toml b/Cargo.toml index 6fbd485bb..2c9bb068d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,8 @@ fastdivide = "0.4" itertools = "0.10.0" measure_time = "0.8.0" pretty_assertions = "1.1.0" +serde_cbor = {version="0.11", optional=true} +async-trait = "0.1" [target.'cfg(windows)'.dependencies] winapi = "0.3.9" @@ -94,6 +96,8 @@ snappy-compression = ["snap"] failpoints = ["fail/failpoints"] unstable = [] # useful for benches. 
+quickwit = ["serde_cbor"] + [workspace] members = ["query-grammar", "bitpacker", "common", "fastfield_codecs", "ownedbytes"] diff --git a/README.md b/README.md index a58038cba..0d42430d4 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,3 @@ - [![Docs](https://docs.rs/tantivy/badge.svg)](https://docs.rs/crate/tantivy/) [![Build Status](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml/badge.svg)](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/quickwit-oss/tantivy/branch/main/graph/badge.svg)](https://codecov.io/gh/quickwit-oss/tantivy) diff --git a/src/core/inverted_index_reader.rs b/src/core/inverted_index_reader.rs index d1fb791f6..81482fb7a 100644 --- a/src/core/inverted_index_reader.rs +++ b/src/core/inverted_index_reader.rs @@ -197,3 +197,36 @@ impl InvertedIndexReader { .unwrap_or(0u32)) } } + +#[cfg(feature = "quickwit")] +impl InvertedIndexReader { + pub(crate) async fn get_term_info_async( + &self, + term: &Term, + ) -> crate::AsyncIoResult> { + self.termdict.get_async(term.value_bytes()).await + } + + /// Returns a block postings given a `Term`. + /// This method is for an advanced usage only. + /// + /// Most user should prefer using `read_postings` instead. + pub async fn warm_postings( + &self, + term: &Term, + with_positions: bool, + ) -> crate::AsyncIoResult<()> { + let term_info_opt = self.get_term_info_async(term).await?; + if let Some(term_info) = term_info_opt { + self.postings_file_slice + .read_bytes_slice_async(term_info.postings_range.clone()) + .await?; + if with_positions { + self.positions_file_slice + .read_bytes_slice_async(term_info.positions_range.clone()) + .await?; + } + } + Ok(()) + } +} diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 37bafd35e..1b8f1257e 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -110,6 +110,13 @@ impl Searcher { store_reader.get(doc_address.doc_id) } + /// Fetches a document in an asynchronous manner. 
+ #[cfg(feature = "quickwit")] + pub async fn doc_async(&self, doc_address: DocAddress) -> crate::Result { + let store_reader = &self.store_readers[doc_address.segment_ord as usize]; + store_reader.get_async(doc_address.doc_id).await + } + /// Access the schema associated to the index of this searcher. pub fn schema(&self) -> &Schema { &self.schema diff --git a/src/directory/file_slice.rs b/src/directory/file_slice.rs index 076caeeb1..dac821f84 100644 --- a/src/directory/file_slice.rs +++ b/src/directory/file_slice.rs @@ -2,6 +2,7 @@ use std::ops::{Deref, Range}; use std::sync::{Arc, Weak}; use std::{fmt, io}; +use async_trait::async_trait; use common::HasLen; use stable_deref_trait::StableDeref; @@ -18,18 +19,35 @@ pub type WeakArcBytes = Weak + Send + Sync + 'static>; /// The underlying behavior is therefore specific to the `Directory` that created it. /// Despite its name, a `FileSlice` may or may not directly map to an actual file /// on the filesystem. + +#[async_trait] pub trait FileHandle: 'static + Send + Sync + HasLen + fmt::Debug { /// Reads a slice of bytes. /// /// This method may panic if the range requested is invalid. fn read_bytes(&self, range: Range) -> io::Result; + + #[cfg(feature = "quickwit")] + #[doc(hidden)] + async fn read_bytes_async( + &self, + _byte_range: Range, + ) -> crate::AsyncIoResult { + Err(crate::error::AsyncIoError::AsyncUnsupported) + } } +#[async_trait] impl FileHandle for &'static [u8] { fn read_bytes(&self, range: Range) -> io::Result { let bytes = &self[range]; Ok(OwnedBytes::new(bytes)) } + + #[cfg(feature = "quickwit")] + async fn read_bytes_async(&self, byte_range: Range) -> crate::AsyncIoResult { + Ok(self.read_bytes(byte_range)?) 
+ } } impl From for FileSlice @@ -102,6 +120,12 @@ impl FileSlice { self.data.read_bytes(self.range.clone()) } + #[cfg(feature = "quickwit")] + #[doc(hidden)] + pub async fn read_bytes_async(&self) -> crate::AsyncIoResult { + self.data.read_bytes_async(self.range.clone()).await + } + /// Reads a specific slice of data. /// /// This is equivalent to running `file_slice.slice(from, to).read_bytes()`. @@ -116,6 +140,23 @@ impl FileSlice { .read_bytes(self.range.start + range.start..self.range.start + range.end) } + #[cfg(feature = "quickwit")] + #[doc(hidden)] + pub async fn read_bytes_slice_async( + &self, + byte_range: Range, + ) -> crate::AsyncIoResult { + assert!( + self.range.start + byte_range.end <= self.range.end, + "`to` exceeds the fileslice length" + ); + self.data + .read_bytes_async( + self.range.start + byte_range.start..self.range.start + byte_range.end, + ) + .await + } + /// Splits the FileSlice at the given offset and return two file slices. /// `file_slice[..split_offset]` and `file_slice[split_offset..]`. 
/// @@ -160,10 +201,16 @@ impl FileSlice { } } +#[async_trait] impl FileHandle for FileSlice { fn read_bytes(&self, range: Range) -> io::Result { self.read_bytes_slice(range) } + + #[cfg(feature = "quickwit")] + async fn read_bytes_async(&self, byte_range: Range) -> crate::AsyncIoResult { + self.read_bytes_slice_async(byte_range).await + } } impl HasLen for FileSlice { @@ -172,6 +219,19 @@ impl HasLen for FileSlice { } } +#[async_trait] +impl FileHandle for OwnedBytes { + fn read_bytes(&self, range: Range) -> io::Result { + Ok(self.slice(range)) + } + + #[cfg(feature = "quickwit")] + async fn read_bytes_async(&self, range: Range) -> crate::AsyncIoResult { + let bytes = self.read_bytes(range)?; + Ok(bytes) + } +} + #[cfg(test)] mod tests { use std::io; diff --git a/src/directory/mmap_directory.rs b/src/directory/mmap_directory.rs index e56036dcb..b17c3625f 100644 --- a/src/directory/mmap_directory.rs +++ b/src/directory/mmap_directory.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; -use std::convert::From; use std::fs::{self, File, OpenOptions}; -use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write}; +use std::io::{self, BufWriter, Read, Seek, Write}; use std::ops::Deref; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; @@ -265,7 +264,7 @@ impl Write for SafeFileWriter { } impl Seek for SafeFileWriter { - fn seek(&mut self, pos: SeekFrom) -> io::Result { + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { self.0.seek(pos) } } diff --git a/src/directory/mod.rs b/src/directory/mod.rs index 62ed18bc0..cd7682a9b 100644 --- a/src/directory/mod.rs +++ b/src/directory/mod.rs @@ -9,7 +9,6 @@ mod file_slice; mod file_watcher; mod footer; mod managed_directory; -mod owned_bytes; mod ram_directory; mod watch_event_router; @@ -22,13 +21,13 @@ use std::io::BufWriter; use std::path::PathBuf; pub use common::{AntiCallToken, TerminatingWrite}; +pub use ownedbytes::OwnedBytes; pub(crate) use self::composite_file::{CompositeFile, CompositeWrite}; pub use 
self::directory::{Directory, DirectoryClone, DirectoryLock}; pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK}; pub(crate) use self::file_slice::{ArcBytes, WeakArcBytes}; pub use self::file_slice::{FileHandle, FileSlice}; -pub use self::owned_bytes::OwnedBytes; pub use self::ram_directory::RamDirectory; pub use self::watch_event_router::{WatchCallback, WatchCallbackList, WatchHandle}; diff --git a/src/directory/owned_bytes.rs b/src/directory/owned_bytes.rs deleted file mode 100644 index 39ba93c1a..000000000 --- a/src/directory/owned_bytes.rs +++ /dev/null @@ -1,12 +0,0 @@ -use std::io; -use std::ops::Range; - -pub use ownedbytes::OwnedBytes; - -use crate::directory::FileHandle; - -impl FileHandle for OwnedBytes { - fn read_bytes(&self, range: Range) -> io::Result { - Ok(self.slice(range)) - } -} diff --git a/src/error.rs b/src/error.rs index ba8f520cb..25b5989cc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,6 +4,8 @@ use std::path::PathBuf; use std::sync::PoisonError; use std::{fmt, io}; +use thiserror::Error; + use crate::directory::error::{ Incompatibility, LockError, OpenDirectoryError, OpenReadError, OpenWriteError, }; @@ -38,9 +40,9 @@ impl DataCorruption { impl fmt::Debug for DataCorruption { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(f, "Data corruption: ")?; + write!(f, "Data corruption")?; if let Some(ref filepath) = &self.filepath { - write!(f, "(in file `{:?}`)", filepath)?; + write!(f, " (in file `{:?}`)", filepath)?; } write!(f, ": {}.", self.comment)?; Ok(()) @@ -97,6 +99,28 @@ pub enum TantivyError { IncompatibleIndex(Incompatibility), } +#[cfg(feature = "quickwit")] +#[derive(Error, Debug)] +#[doc(hidden)] +pub enum AsyncIoError { + #[error("io::Error `{0}`")] + Io(#[from] io::Error), + #[error("Asynchronous API is unsupported by this directory")] + AsyncUnsupported, +} + +#[cfg(feature = "quickwit")] +impl From for TantivyError { + fn from(async_io_err: AsyncIoError) -> Self { + match 
async_io_err { + AsyncIoError::Io(io_err) => TantivyError::from(io_err), + AsyncIoError::AsyncUnsupported => { + TantivyError::SystemError(format!("{:?}", async_io_err)) + } + } + } +} + impl From for TantivyError { fn from(data_corruption: DataCorruption) -> TantivyError { TantivyError::DataCorruption(data_corruption) diff --git a/src/fastfield/readers.rs b/src/fastfield/readers.rs index 065772fe8..f0e7bb512 100644 --- a/src/fastfield/readers.rs +++ b/src/fastfield/readers.rs @@ -55,7 +55,8 @@ impl FastFieldReaders { self.fast_fields_composite.space_usage() } - fn fast_field_data(&self, field: Field, idx: usize) -> crate::Result { + #[doc(hidden)] + pub fn fast_field_data(&self, field: Field, idx: usize) -> crate::Result { self.fast_fields_composite .open_read_with_idx(field, idx) .ok_or_else(|| { diff --git a/src/indexer/merger.rs b/src/indexer/merger.rs index 92f724a78..775baca74 100644 --- a/src/indexer/merger.rs +++ b/src/indexer/merger.rs @@ -278,7 +278,7 @@ impl IndexMerger { mut term_ord_mappings: HashMap, doc_id_mapping: &SegmentDocIdMapping, ) -> crate::Result<()> { - debug_time!("write_fast_fields"); + debug_time!("write-fast-fields"); for (field, field_entry) in self.schema.fields() { let field_type = field_entry.field_type(); @@ -597,7 +597,7 @@ impl IndexMerger { fast_field_serializer: &mut CompositeFastFieldSerializer, doc_id_mapping: &SegmentDocIdMapping, ) -> crate::Result<()> { - debug_time!("write_hierarchical_facet_field"); + debug_time!("write-hierarchical-facet-field"); // Multifastfield consists of 2 fastfields. // The first serves as an index into the second one and is stricly increasing. 
@@ -827,7 +827,7 @@ impl IndexMerger { fieldnorm_reader: Option, doc_id_mapping: &SegmentDocIdMapping, ) -> crate::Result> { - debug_time!("write_postings_for_field"); + debug_time!("write-postings-for-field"); let mut positions_buffer: Vec = Vec::with_capacity(1_000); let mut delta_computer = DeltaComputer::new(); @@ -1023,7 +1023,8 @@ impl IndexMerger { store_writer: &mut StoreWriter, doc_id_mapping: &SegmentDocIdMapping, ) -> crate::Result<()> { - debug_time!("write_storable_fields"); + debug_time!("write-storable-fields"); + debug!("write-storable-field"); let store_readers: Vec<_> = self .readers @@ -1036,6 +1037,7 @@ impl IndexMerger { .map(|(i, store)| store.iter_raw(self.readers[i].alive_bitset())) .collect(); if !doc_id_mapping.is_trivial() { + debug!("non-trivial-doc-id-mapping"); for (old_doc_id, reader_ordinal) in doc_id_mapping.iter() { let doc_bytes_it = &mut document_iterators[*reader_ordinal as usize]; if let Some(doc_bytes_res) = doc_bytes_it.next() { @@ -1050,6 +1052,7 @@ impl IndexMerger { } } } else { + debug!("trivial-doc-id-mapping"); for reader in &self.readers { let store_reader = reader.get_store_reader()?; if reader.has_deletes() @@ -1099,10 +1102,11 @@ impl IndexMerger { } else { self.get_doc_id_from_concatenated_data()? 
}; - + debug!("write-fieldnorms"); if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() { self.write_fieldnorms(fieldnorms_serializer, &doc_id_mapping)?; } + debug!("write-postings"); let fieldnorm_data = serializer .segment() .open_read(SegmentComponent::FieldNorms)?; @@ -1112,12 +1116,15 @@ impl IndexMerger { fieldnorm_readers, &doc_id_mapping, )?; + debug!("write-fastfields"); self.write_fast_fields( serializer.get_fast_field_serializer(), term_ord_mappings, &doc_id_mapping, )?; + debug!("write-storagefields"); self.write_storable_fields(serializer.get_store_writer(), &doc_id_mapping)?; + debug!("close-serializer"); serializer.close()?; Ok(self.max_doc) } diff --git a/src/indexer/segment_writer.rs b/src/indexer/segment_writer.rs index 76fdfaf95..2a52abdc5 100644 --- a/src/indexer/segment_writer.rs +++ b/src/indexer/segment_writer.rs @@ -338,6 +338,7 @@ fn remap_and_write( mut serializer: SegmentSerializer, doc_id_map: Option<&DocIdMapping>, ) -> crate::Result<()> { + debug!("remap-and-write"); if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() { fieldnorms_writer.serialize(fieldnorms_serializer, doc_id_map)?; } @@ -353,12 +354,14 @@ fn remap_and_write( schema, serializer.get_postings_serializer(), )?; + debug!("fastfield-serialize"); fast_field_writers.serialize( serializer.get_fast_field_serializer(), &term_ord_map, doc_id_map, )?; + debug!("resort-docstore"); // finalize temp docstore and create version, which reflects the doc_id_map if let Some(doc_id_map) = doc_id_map { let store_write = serializer @@ -381,6 +384,7 @@ fn remap_and_write( } } + debug!("serializer-close"); serializer.close()?; Ok(()) @@ -585,8 +589,8 @@ mod tests { let mut doc = Document::default(); let json_val: serde_json::Map = serde_json::from_str(r#"{"mykey": "repeated token token"}"#).unwrap(); - doc.add_json_object(json_field, json_val.clone()); - let index = Index::create_in_ram(schema.clone()); + doc.add_json_object(json_field, 
json_val); + let index = Index::create_in_ram(schema); let mut writer = index.writer_for_tests().unwrap(); writer.add_document(doc).unwrap(); writer.commit().unwrap(); @@ -631,7 +635,7 @@ mod tests { let json_val: serde_json::Map = serde_json::from_str(r#"{"mykey": "two tokens"}"#).unwrap(); let doc = doc!(json_field=>json_val); - let index = Index::create_in_ram(schema.clone()); + let index = Index::create_in_ram(schema); let mut writer = index.writer_for_tests().unwrap(); writer.add_document(doc).unwrap(); writer.commit().unwrap(); @@ -679,7 +683,7 @@ mod tests { ) .unwrap(); let doc = doc!(json_field=>json_val); - let index = Index::create_in_ram(schema.clone()); + let index = Index::create_in_ram(schema); let mut writer = index.writer_for_tests().unwrap(); writer.add_document(doc).unwrap(); writer.commit().unwrap(); diff --git a/src/lib.rs b/src/lib.rs index fc7dacc0d..aa70cbdc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,6 +134,10 @@ pub use crate::error::TantivyError; /// and instead, refer to this as `crate::Result`. pub type Result = std::result::Result; +/// Result for an Async io operation. 
+#[cfg(feature = "quickwit")] +pub type AsyncIoResult = std::result::Result; + /// Tantivy DateTime pub type DateTime = chrono::DateTime; diff --git a/src/postings/postings_writer.rs b/src/postings/postings_writer.rs index 020931c39..219313e4b 100644 --- a/src/postings/postings_writer.rs +++ b/src/postings/postings_writer.rs @@ -60,12 +60,10 @@ pub(crate) fn serialize_postings( Vec::with_capacity(ctx.term_index.len()); term_offsets.extend(ctx.term_index.iter()); term_offsets.sort_unstable_by_key(|(k, _, _)| k.clone()); - let mut unordered_term_mappings: HashMap> = HashMap::new(); let field_offsets = make_field_partition(&term_offsets); - for (field, byte_offsets) in field_offsets { let field_entry = schema.get_field_entry(field); match *field_entry.field_type() { diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index dcbe60c0b..5a8e507fd 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -38,7 +38,7 @@ impl PhraseWeight { Ok(FieldNormReader::constant(reader.max_doc(), 1)) } - fn phrase_scorer( + pub(crate) fn phrase_scorer( &self, reader: &SegmentReader, boost: Score, diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 9824487e9..4e742bc44 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -93,6 +93,10 @@ impl TermWeight { } } + pub fn term(&self) -> &Term { + &self.term + } + pub(crate) fn specialized_scorer( &self, reader: &SegmentReader, diff --git a/src/store/reader.rs b/src/store/reader.rs index 9ad6e307e..f3dadda2a 100644 --- a/src/store/reader.rs +++ b/src/store/reader.rs @@ -4,11 +4,12 @@ use std::sync::{Arc, Mutex}; use common::{BinarySerializable, HasLen, VInt}; use lru::LruCache; +use ownedbytes::OwnedBytes; use super::footer::DocStoreFooter; use super::index::SkipIndex; use super::Compressor; -use crate::directory::{FileSlice, OwnedBytes}; +use crate::directory::FileSlice; 
use crate::error::DataCorruption; use crate::fastfield::AliveBitSet; use crate::schema::Document; @@ -239,6 +240,60 @@ impl StoreReader { } } +#[cfg(feature = "quickwit")] +impl StoreReader { + async fn read_block_async(&self, checkpoint: &Checkpoint) -> crate::AsyncIoResult { + if let Some(block) = self.cache.lock().unwrap().get(&checkpoint.byte_range.start) { + self.cache_hits.fetch_add(1, Ordering::SeqCst); + return Ok(block.clone()); + } + + self.cache_misses.fetch_add(1, Ordering::SeqCst); + + let compressed_block = self + .data + .slice(checkpoint.byte_range.clone()) + .read_bytes_async() + .await?; + let mut decompressed_block = vec![]; + self.compressor + .decompress(compressed_block.as_slice(), &mut decompressed_block)?; + + let block = OwnedBytes::new(decompressed_block); + self.cache + .lock() + .unwrap() + .put(checkpoint.byte_range.start, block.clone()); + + Ok(block) + } + + /// Fetches a document asynchronously. + async fn get_document_bytes_async(&self, doc_id: DocId) -> crate::Result { + let checkpoint = self.block_checkpoint(doc_id).ok_or_else(|| { + crate::TantivyError::InvalidArgument(format!("Failed to lookup Doc #{}.", doc_id)) + })?; + let block = self.read_block_async(&checkpoint).await?; + let mut cursor = &block[..]; + let cursor_len_before = cursor.len(); + for _ in checkpoint.doc_range.start..doc_id { + let doc_length = VInt::deserialize(&mut cursor)?.val() as usize; + cursor = &cursor[doc_length..]; + } + let doc_length = VInt::deserialize(&mut cursor)?.val() as usize; + let start_pos = cursor_len_before - cursor.len(); + let end_pos = cursor_len_before - cursor.len() + doc_length; + Ok(block.slice(start_pos..end_pos)) + } + + /// Reads raw bytes of a given document. Returns `RawDocument`, which contains the block of a + /// document and its start and end position within the block. 
+ pub(crate) async fn get_async(&self, doc_id: DocId) -> crate::Result { + let mut doc_bytes = self.get_document_bytes_async(doc_id).await?; + Ok(Document::deserialize(&mut doc_bytes)?) + } +} + #[cfg(test)] mod tests { use std::path::Path; diff --git a/src/termdict/merger.rs b/src/termdict/fst_termdict/merger.rs similarity index 91% rename from src/termdict/merger.rs rename to src/termdict/fst_termdict/merger.rs index 167b5d97d..037641747 100644 --- a/src/termdict/merger.rs +++ b/src/termdict/fst_termdict/merger.rs @@ -51,18 +51,19 @@ impl<'a> TermMerger<'a> { /// Returns `true` if there is indeed another term /// `false` if there is none. pub fn advance(&mut self) -> bool { - if let Some((k, values)) = self.union.next() { - self.current_key.clear(); - self.current_key.extend_from_slice(k); - self.current_segment_and_term_ordinals.clear(); - self.current_segment_and_term_ordinals - .extend_from_slice(values); - self.current_segment_and_term_ordinals - .sort_by_key(|iv| iv.index); - true + let (key, values) = if let Some((key, values)) = self.union.next() { + (key, values) } else { - false - } + return false; + }; + self.current_key.clear(); + self.current_key.extend_from_slice(key); + self.current_segment_and_term_ordinals.clear(); + self.current_segment_and_term_ordinals + .extend_from_slice(values); + self.current_segment_and_term_ordinals + .sort_by_key(|iv| iv.index); + true } /// Returns the current term. diff --git a/src/termdict/fst_termdict/mod.rs b/src/termdict/fst_termdict/mod.rs index 2ed4db970..a809360db 100644 --- a/src/termdict/fst_termdict/mod.rs +++ b/src/termdict/fst_termdict/mod.rs @@ -18,9 +18,11 @@ //! //! A second datastructure makes it possible to access a //! [`TermInfo`](../postings/struct.TermInfo.html). 
+mod merger; mod streamer; mod term_info_store; mod termdict; +pub use self::merger::TermMerger; pub use self::streamer::{TermStreamer, TermStreamerBuilder}; pub use self::termdict::{TermDictionary, TermDictionaryBuilder}; diff --git a/src/termdict/mod.rs b/src/termdict/mod.rs index a1fe98a02..324b6a3a6 100644 --- a/src/termdict/mod.rs +++ b/src/termdict/mod.rs @@ -19,16 +19,41 @@ //! A second datastructure makes it possible to access a //! [`TermInfo`](../postings/struct.TermInfo.html). +#[cfg(not(feature = "quickwit"))] mod fst_termdict; +#[cfg(not(feature = "quickwit"))] use fst_termdict as termdict; -mod merger; +#[cfg(feature = "quickwit")] +mod sstable_termdict; +#[cfg(feature = "quickwit")] +use sstable_termdict as termdict; +use tantivy_fst::automaton::AlwaysMatch; -pub use self::merger::TermMerger; -pub use self::termdict::{TermDictionary, TermDictionaryBuilder, TermStreamer}; +#[cfg(test)] +mod tests; /// Position of the term in the sorted list of terms. pub type TermOrdinal = u64; -#[cfg(test)] -mod tests; +/// The term dictionary contains all of the terms in +/// `tantivy index` in a sorted manner. +pub type TermDictionary = self::termdict::TermDictionary; + +/// Builder for the new term dictionary. +/// +/// Inserting must be done in the order of the `keys`. +pub type TermDictionaryBuilder = self::termdict::TermDictionaryBuilder; + +/// Given a list of sorted term streams, +/// returns an iterator over sorted unique terms. +/// +/// The item yield is actually a pair with +/// - the term +/// - a slice with the ordinal of the segments containing +/// the terms. +pub type TermMerger<'a> = self::termdict::TermMerger<'a>; + +/// `TermStreamer` acts as a cursor over a range of terms of a segment. +/// Terms are guaranteed to be sorted. 
+pub type TermStreamer<'a, A = AlwaysMatch> = self::termdict::TermStreamer<'a, A>; diff --git a/src/termdict/sstable_termdict/merger.rs b/src/termdict/sstable_termdict/merger.rs new file mode 100644 index 000000000..6c98498ec --- /dev/null +++ b/src/termdict/sstable_termdict/merger.rs @@ -0,0 +1,120 @@ +use std::cmp::Ordering; +use std::collections::BinaryHeap; + +use crate::postings::TermInfo; +use crate::termdict::{TermOrdinal, TermStreamer}; + +pub struct HeapItem<'a> { + pub streamer: TermStreamer<'a>, + pub segment_ord: usize, +} + +impl<'a> PartialEq for HeapItem<'a> { + fn eq(&self, other: &Self) -> bool { + self.segment_ord == other.segment_ord + } +} + +impl<'a> Eq for HeapItem<'a> {} + +impl<'a> PartialOrd for HeapItem<'a> { + fn partial_cmp(&self, other: &HeapItem<'a>) -> Option { + Some(self.cmp(other)) + } +} + +impl<'a> Ord for HeapItem<'a> { + fn cmp(&self, other: &HeapItem<'a>) -> Ordering { + (&other.streamer.key(), &other.segment_ord).cmp(&(&self.streamer.key(), &self.segment_ord)) + } +} + +/// Given a list of sorted term streams, +/// returns an iterator over sorted unique terms. +/// +/// The item yield is actually a pair with +/// - the term +/// - a slice with the ordinal of the segments containing +/// the terms. 
+pub struct TermMerger<'a> { + heap: BinaryHeap>, + current_streamers: Vec>, +} + +impl<'a> TermMerger<'a> { + /// Stream of merged term dictionary + pub fn new(streams: Vec>) -> TermMerger<'a> { + TermMerger { + heap: BinaryHeap::new(), + current_streamers: streams + .into_iter() + .enumerate() + .map(|(ord, streamer)| HeapItem { + streamer, + segment_ord: ord, + }) + .collect(), + } + } + + pub(crate) fn matching_segments<'b: 'a>( + &'b self, + ) -> impl 'b + Iterator { + self.current_streamers + .iter() + .map(|heap_item| (heap_item.segment_ord, heap_item.streamer.term_ord())) + } + + fn advance_segments(&mut self) { + let streamers = &mut self.current_streamers; + let heap = &mut self.heap; + for mut heap_item in streamers.drain(..) { + if heap_item.streamer.advance() { + heap.push(heap_item); + } + } + } + + /// Advance the term iterator to the next term. + /// Returns true if there is indeed another term + /// False if there is none. + pub fn advance(&mut self) -> bool { + self.advance_segments(); + if let Some(head) = self.heap.pop() { + self.current_streamers.push(head); + while let Some(next_streamer) = self.heap.peek() { + if self.current_streamers[0].streamer.key() != next_streamer.streamer.key() { + break; + } + let next_heap_it = self.heap.pop().unwrap(); // safe : we peeked beforehand + self.current_streamers.push(next_heap_it); + } + true + } else { + false + } + } + + /// Returns the current term. + /// + /// This method may be called + /// iff advance() has been called before + /// and "true" was returned. + pub fn key(&self) -> &[u8] { + self.current_streamers[0].streamer.key() + } + + /// Returns the sorted list of segment ordinals + /// that include the current term. + /// + /// This method may be called + /// iff advance() has been called before + /// and "true" was returned. 
+ pub fn current_segment_ords_and_term_infos<'b: 'a>( + &'b self, + ) -> impl 'b + Iterator { + self.current_streamers + .iter() + .map(|heap_item| (heap_item.segment_ord, heap_item.streamer.value().clone())) + } +} diff --git a/src/termdict/sstable_termdict/mod.rs b/src/termdict/sstable_termdict/mod.rs new file mode 100644 index 000000000..5b1174686 --- /dev/null +++ b/src/termdict/sstable_termdict/mod.rs @@ -0,0 +1,144 @@ +use std::io; + +mod merger; +mod sstable; +mod streamer; +mod termdict; + +use std::iter::ExactSizeIterator; + +use common::VInt; + +pub use self::merger::TermMerger; +use self::sstable::value::{ValueReader, ValueWriter}; +use self::sstable::{BlockReader, SSTable}; +pub use self::streamer::{TermStreamer, TermStreamerBuilder}; +pub use self::termdict::{TermDictionary, TermDictionaryBuilder}; +use crate::postings::TermInfo; + +pub struct TermSSTable; + +impl SSTable for TermSSTable { + type Value = TermInfo; + type Reader = TermInfoReader; + type Writer = TermInfoWriter; +} + +#[derive(Default)] +pub struct TermInfoReader { + term_infos: Vec, +} + +impl ValueReader for TermInfoReader { + type Value = TermInfo; + + fn value(&self, idx: usize) -> &TermInfo { + &self.term_infos[idx] + } + + fn read(&mut self, reader: &mut BlockReader) -> io::Result<()> { + self.term_infos.clear(); + let num_els = VInt::deserialize_u64(reader)?; + let mut postings_start = VInt::deserialize_u64(reader)? as usize; + let mut positions_start = VInt::deserialize_u64(reader)? as usize; + for _ in 0..num_els { + let doc_freq = VInt::deserialize_u64(reader)? 
as u32; + let postings_num_bytes = VInt::deserialize_u64(reader)?; + let positions_num_bytes = VInt::deserialize_u64(reader)?; + let postings_end = postings_start + postings_num_bytes as usize; + let positions_end = positions_start + positions_num_bytes as usize; + let term_info = TermInfo { + doc_freq, + postings_range: postings_start..postings_end, + positions_range: positions_start..positions_end, + }; + self.term_infos.push(term_info); + postings_start = postings_end; + positions_start = positions_end; + } + Ok(()) + } +} + +#[derive(Default)] +pub struct TermInfoWriter { + term_infos: Vec, +} + +impl ValueWriter for TermInfoWriter { + type Value = TermInfo; + + fn write(&mut self, term_info: &TermInfo) { + self.term_infos.push(term_info.clone()); + } + + fn write_block(&mut self, buffer: &mut Vec) { + VInt(self.term_infos.len() as u64).serialize_into_vec(buffer); + if self.term_infos.is_empty() { + return; + } + VInt(self.term_infos[0].postings_range.start as u64).serialize_into_vec(buffer); + VInt(self.term_infos[0].positions_range.start as u64).serialize_into_vec(buffer); + for term_info in &self.term_infos { + VInt(term_info.doc_freq as u64).serialize_into_vec(buffer); + VInt(term_info.postings_range.len() as u64).serialize_into_vec(buffer); + VInt(term_info.positions_range.len() as u64).serialize_into_vec(buffer); + } + self.term_infos.clear(); + } +} + +#[cfg(test)] +mod tests { + use std::io; + + use super::BlockReader; + use crate::directory::OwnedBytes; + use crate::postings::TermInfo; + use crate::termdict::sstable_termdict::sstable::value::{ValueReader, ValueWriter}; + use crate::termdict::sstable_termdict::TermInfoReader; + + #[test] + fn test_block_terminfos() -> io::Result<()> { + let mut term_info_writer = super::TermInfoWriter::default(); + term_info_writer.write(&TermInfo { + doc_freq: 120u32, + postings_range: 17..45, + positions_range: 10..122, + }); + term_info_writer.write(&TermInfo { + doc_freq: 10u32, + postings_range: 45..450, + 
positions_range: 122..1100, + }); + term_info_writer.write(&TermInfo { + doc_freq: 17u32, + postings_range: 450..462, + positions_range: 1100..1302, + }); + let mut buffer = Vec::new(); + term_info_writer.write_block(&mut buffer); + let mut block_reader = make_block_reader(&buffer[..]); + let mut term_info_reader = TermInfoReader::default(); + term_info_reader.read(&mut block_reader)?; + assert_eq!( + term_info_reader.value(0), + &TermInfo { + doc_freq: 120u32, + postings_range: 17..45, + positions_range: 10..122 + } + ); + assert!(block_reader.buffer().is_empty()); + Ok(()) + } + + fn make_block_reader(data: &[u8]) -> BlockReader { + let mut buffer = (data.len() as u32).to_le_bytes().to_vec(); + buffer.extend_from_slice(data); + let owned_bytes = OwnedBytes::new(buffer); + let mut block_reader = BlockReader::new(Box::new(owned_bytes)); + block_reader.read_block().unwrap(); + block_reader + } +} diff --git a/src/termdict/sstable_termdict/sstable/block_reader.rs b/src/termdict/sstable_termdict/sstable/block_reader.rs new file mode 100644 index 000000000..58ccf457a --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/block_reader.rs @@ -0,0 +1,81 @@ +use std::io::{self, Read}; + +use byteorder::{LittleEndian, ReadBytesExt}; + +pub struct BlockReader<'a> { + buffer: Vec, + reader: Box, + offset: usize, +} + +impl<'a> BlockReader<'a> { + pub fn new(reader: Box) -> BlockReader<'a> { + BlockReader { + buffer: Vec::new(), + reader, + offset: 0, + } + } + + pub fn deserialize_u64(&mut self) -> u64 { + let (num_bytes, val) = super::vint::deserialize_read(self.buffer()); + self.advance(num_bytes); + val + } + + #[inline(always)] + pub fn buffer_from_to(&self, start: usize, end: usize) -> &[u8] { + &self.buffer[start..end] + } + + pub fn read_block(&mut self) -> io::Result { + self.offset = 0; + let block_len_res = self.reader.read_u32::(); + if let Err(err) = &block_len_res { + if err.kind() == io::ErrorKind::UnexpectedEof { + return Ok(false); + } + } + let block_len = 
block_len_res?; + if block_len == 0u32 { + self.buffer.clear(); + return Ok(false); + } + self.buffer.resize(block_len as usize, 0u8); + self.reader.read_exact(&mut self.buffer[..])?; + Ok(true) + } + + pub fn offset(&self) -> usize { + self.offset + } + + pub fn advance(&mut self, num_bytes: usize) { + self.offset += num_bytes; + } + + pub fn buffer(&self) -> &[u8] { + &self.buffer[self.offset..] + } +} + +impl<'a> io::Read for BlockReader<'a> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let len = self.buffer().read(buf)?; + self.advance(len); + Ok(len) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + let len = self.buffer.len(); + buf.extend_from_slice(self.buffer()); + self.advance(len); + Ok(len) + } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.buffer().read_exact(buf)?; + self.advance(buf.len()); + Ok(()) + } +} diff --git a/src/termdict/sstable_termdict/sstable/delta.rs b/src/termdict/sstable_termdict/sstable/delta.rs new file mode 100644 index 000000000..3551891cd --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/delta.rs @@ -0,0 +1,183 @@ +use std::io::{self, BufWriter, Write}; +use std::ops::Range; + +use common::CountingWriter; + +use super::value::ValueWriter; +use super::{value, vint, BlockReader}; + +const FOUR_BIT_LIMITS: usize = 1 << 4; +const VINT_MODE: u8 = 1u8; +const BLOCK_LEN: usize = 32_000; + +pub struct DeltaWriter +where W: io::Write +{ + block: Vec, + write: CountingWriter>, + value_writer: TValueWriter, +} + +impl DeltaWriter +where + W: io::Write, + TValueWriter: ValueWriter, +{ + pub fn new(wrt: W) -> Self { + DeltaWriter { + block: Vec::with_capacity(BLOCK_LEN * 2), + write: CountingWriter::wrap(BufWriter::new(wrt)), + value_writer: TValueWriter::default(), + } + } +} + +impl DeltaWriter +where + W: io::Write, + TValueWriter: value::ValueWriter, +{ + pub fn flush_block(&mut self) -> io::Result>> { + if self.block.is_empty() { + return Ok(None); + } + let start_offset = 
self.write.written_bytes() as usize; + // TODO avoid buffer allocation + let mut buffer = Vec::new(); + self.value_writer.write_block(&mut buffer); + let block_len = buffer.len() + self.block.len(); + self.write.write_all(&(block_len as u32).to_le_bytes())?; + self.write.write_all(&buffer[..])?; + self.write.write_all(&self.block[..])?; + let end_offset = self.write.written_bytes() as usize; + self.block.clear(); + Ok(Some(start_offset..end_offset)) + } + + fn encode_keep_add(&mut self, keep_len: usize, add_len: usize) { + if keep_len < FOUR_BIT_LIMITS && add_len < FOUR_BIT_LIMITS { + let b = (keep_len | add_len << 4) as u8; + self.block.extend_from_slice(&[b]) + } else { + let mut buf = [VINT_MODE; 20]; + let mut len = 1 + vint::serialize(keep_len as u64, &mut buf[1..]); + len += vint::serialize(add_len as u64, &mut buf[len..]); + self.block.extend_from_slice(&buf[..len]) + } + } + + pub(crate) fn write_suffix(&mut self, common_prefix_len: usize, suffix: &[u8]) { + let keep_len = common_prefix_len; + let add_len = suffix.len(); + self.encode_keep_add(keep_len, add_len); + self.block.extend_from_slice(suffix); + } + + pub(crate) fn write_value(&mut self, value: &TValueWriter::Value) { + self.value_writer.write(value); + } + + pub fn flush_block_if_required(&mut self) -> io::Result>> { + if self.block.len() > BLOCK_LEN { + return self.flush_block(); + } + Ok(None) + } + + pub fn finalize(self) -> CountingWriter> { + self.write + } +} + +pub struct DeltaReader<'a, TValueReader> { + common_prefix_len: usize, + suffix_start: usize, + suffix_end: usize, + value_reader: TValueReader, + block_reader: BlockReader<'a>, + idx: usize, +} + +impl<'a, TValueReader> DeltaReader<'a, TValueReader> +where TValueReader: value::ValueReader +{ + pub fn new(reader: R) -> Self { + DeltaReader { + idx: 0, + common_prefix_len: 0, + suffix_start: 0, + suffix_end: 0, + value_reader: TValueReader::default(), + block_reader: BlockReader::new(Box::new(reader)), + } + } + + fn 
deserialize_vint(&mut self) -> u64 { + self.block_reader.deserialize_u64() + } + + fn read_keep_add(&mut self) -> Option<(usize, usize)> { + let b = { + let buf = &self.block_reader.buffer(); + if buf.is_empty() { + return None; + } + buf[0] + }; + self.block_reader.advance(1); + match b { + VINT_MODE => { + let keep = self.deserialize_vint() as usize; + let add = self.deserialize_vint() as usize; + Some((keep, add)) + } + b => { + let keep = (b & 0b1111) as usize; + let add = (b >> 4) as usize; + Some((keep, add)) + } + } + } + + fn read_delta_key(&mut self) -> bool { + if let Some((keep, add)) = self.read_keep_add() { + self.common_prefix_len = keep; + self.suffix_start = self.block_reader.offset(); + self.suffix_end = self.suffix_start + add; + self.block_reader.advance(add); + true + } else { + false + } + } + + pub fn advance(&mut self) -> io::Result { + if self.block_reader.buffer().is_empty() { + if !self.block_reader.read_block()? { + return Ok(false); + } + self.value_reader.read(&mut self.block_reader)?; + self.idx = 0; + } else { + self.idx += 1; + } + if !self.read_delta_key() { + return Ok(false); + } + Ok(true) + } + + pub fn common_prefix_len(&self) -> usize { + self.common_prefix_len + } + + pub fn suffix(&self) -> &[u8] { + &self + .block_reader + .buffer_from_to(self.suffix_start, self.suffix_end) + } + + pub fn value(&self) -> &TValueReader::Value { + self.value_reader.value(self.idx) + } +} diff --git a/src/termdict/sstable_termdict/sstable/merge/heap_merge.rs b/src/termdict/sstable_termdict/sstable/merge/heap_merge.rs new file mode 100644 index 000000000..50af21962 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/merge/heap_merge.rs @@ -0,0 +1,72 @@ +use std::cmp::Ordering; +use std::collections::binary_heap::PeekMut; +use std::collections::BinaryHeap; +use std::io; + +use super::{SingleValueMerger, ValueMerger}; +use crate::termdict::sstable_termdict::sstable::{Reader, SSTable, Writer}; + +struct HeapItem>(B); + +impl> Ord for 
HeapItem { + fn cmp(&self, other: &Self) -> Ordering { + other.0.as_ref().cmp(self.0.as_ref()) + } +} +impl> PartialOrd for HeapItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(other.0.as_ref().cmp(self.0.as_ref())) + } +} + +impl> Eq for HeapItem {} +impl> PartialEq for HeapItem { + fn eq(&self, other: &Self) -> bool { + self.0.as_ref() == other.0.as_ref() + } +} + +#[allow(dead_code)] +pub fn merge_sstable>( + readers: Vec>, + mut writer: Writer, + mut merger: M, +) -> io::Result<()> { + let mut heap: BinaryHeap>> = + BinaryHeap::with_capacity(readers.len()); + for mut reader in readers { + if reader.advance()? { + heap.push(HeapItem(reader)); + } + } + loop { + let len = heap.len(); + let mut value_merger; + if let Some(mut head) = heap.peek_mut() { + writer.write_key(head.0.key()); + value_merger = merger.new_value(head.0.value()); + if !head.0.advance()? { + PeekMut::pop(head); + } + } else { + break; + } + for _ in 0..len - 1 { + if let Some(mut head) = heap.peek_mut() { + if head.0.key() == writer.current_key() { + value_merger.add(head.0.value()); + if !head.0.advance()? 
{ + PeekMut::pop(head); + } + continue; + } + } + break; + } + let value = value_merger.finish(); + writer.write_value(&value)?; + writer.flush_block_if_required()?; + } + writer.finalize()?; + Ok(()) +} diff --git a/src/termdict/sstable_termdict/sstable/merge/mod.rs b/src/termdict/sstable_termdict/sstable/merge/mod.rs new file mode 100644 index 000000000..d60440962 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/merge/mod.rs @@ -0,0 +1,179 @@ +mod heap_merge; + +pub use self::heap_merge::merge_sstable; + +pub trait SingleValueMerger { + fn add(&mut self, v: &V); + fn finish(self) -> V; +} + +pub trait ValueMerger { + type TSingleValueMerger: SingleValueMerger; + fn new_value(&mut self, v: &V) -> Self::TSingleValueMerger; +} + +#[derive(Default)] +pub struct KeepFirst; + +pub struct FirstVal(V); + +impl ValueMerger for KeepFirst { + type TSingleValueMerger = FirstVal; + + fn new_value(&mut self, v: &V) -> FirstVal { + FirstVal(v.clone()) + } +} + +impl SingleValueMerger for FirstVal { + fn add(&mut self, _: &V) {} + + fn finish(self) -> V { + self.0 + } +} + +pub struct VoidMerge; +impl ValueMerger<()> for VoidMerge { + type TSingleValueMerger = (); + + fn new_value(&mut self, _: &()) {} +} + +pub struct U64Merge; +impl ValueMerger for U64Merge { + type TSingleValueMerger = u64; + + fn new_value(&mut self, val: &u64) -> u64 { + *val + } +} + +impl SingleValueMerger for u64 { + fn add(&mut self, val: &u64) { + *self += *val; + } + + fn finish(self) -> u64 { + self + } +} + +impl SingleValueMerger<()> for () { + fn add(&mut self, _: &()) {} + + fn finish(self) {} +} + +#[cfg(test)] +mod tests { + + use std::collections::{BTreeMap, BTreeSet}; + use std::str; + + use super::super::{SSTable, SSTableMonotonicU64, VoidSSTable}; + use super::{U64Merge, VoidMerge}; + + fn write_sstable(keys: &[&'static str]) -> Vec { + let mut buffer: Vec = vec![]; + { + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + for &key in keys { + 
assert!(sstable_writer.write(key.as_bytes(), &()).is_ok()); + } + assert!(sstable_writer.finalize().is_ok()); + } + dbg!(&buffer); + buffer + } + + fn write_sstable_u64(keys: &[(&'static str, u64)]) -> Vec { + let mut buffer: Vec = vec![]; + { + let mut sstable_writer = SSTableMonotonicU64::writer(&mut buffer); + for (key, val) in keys { + assert!(sstable_writer.write(key.as_bytes(), val).is_ok()); + } + assert!(sstable_writer.finalize().is_ok()); + } + buffer + } + + fn merge_test_aux(arrs: &[&[&'static str]]) { + let sstables = arrs.iter().cloned().map(write_sstable).collect::>(); + let sstables_ref: Vec<&[u8]> = sstables.iter().map(|s| s.as_ref()).collect(); + let mut merged = BTreeSet::new(); + for &arr in arrs.iter() { + for &s in arr { + merged.insert(s.to_string()); + } + } + let mut w = Vec::new(); + assert!(VoidSSTable::merge(sstables_ref, &mut w, VoidMerge).is_ok()); + let mut reader = VoidSSTable::reader(&w[..]); + for k in merged { + assert!(reader.advance().unwrap()); + assert_eq!(reader.key(), k.as_bytes()); + } + assert!(!reader.advance().unwrap()); + } + + fn merge_test_u64_monotonic_aux(arrs: &[&[(&'static str, u64)]]) { + let sstables = arrs + .iter() + .cloned() + .map(write_sstable_u64) + .collect::>(); + let sstables_ref: Vec<&[u8]> = sstables.iter().map(|s| s.as_ref()).collect(); + let mut merged = BTreeMap::new(); + for &arr in arrs.iter() { + for (key, val) in arr { + let entry = merged.entry(key.to_string()).or_insert(0u64); + *entry += val; + } + } + let mut w = Vec::new(); + assert!(SSTableMonotonicU64::merge(sstables_ref, &mut w, U64Merge).is_ok()); + let mut reader = SSTableMonotonicU64::reader(&w[..]); + for (k, v) in merged { + assert!(reader.advance().unwrap()); + assert_eq!(reader.key(), k.as_bytes()); + assert_eq!(reader.value(), &v); + } + assert!(!reader.advance().unwrap()); + } + + #[test] + fn test_merge_simple_reproduce() { + let sstable_data = write_sstable(&["a"]); + let mut reader = VoidSSTable::reader(&sstable_data[..]); + 
assert!(reader.advance().unwrap()); + assert_eq!(reader.key(), b"a"); + assert!(!reader.advance().unwrap()); + } + + #[test] + fn test_merge() { + merge_test_aux(&[]); + merge_test_aux(&[&["a"]]); + merge_test_aux(&[&["a", "b"], &["ab"]]); // a, ab, b + merge_test_aux(&[&["a", "b"], &["a", "b"]]); + merge_test_aux(&[ + &["happy", "hello", "payer", "tax"], + &["habitat", "hello", "zoo"], + &[], + &["a"], + ]); + merge_test_aux(&[&["a"]]); + merge_test_aux(&[&["a", "b"], &["ab"]]); + merge_test_aux(&[&["a", "b"], &["a", "b"]]); + } + + #[test] + fn test_merge_u64() { + merge_test_u64_monotonic_aux(&[]); + merge_test_u64_monotonic_aux(&[&[("a", 1u64)]]); + merge_test_u64_monotonic_aux(&[&[("a", 1u64), ("b", 3u64)], &[("ab", 2u64)]]); // a, ab, b + merge_test_u64_monotonic_aux(&[&[("a", 1u64), ("b", 2u64)], &[("a", 16u64), ("b", 23u64)]]); + } +} diff --git a/src/termdict/sstable_termdict/sstable/mod.rs b/src/termdict/sstable_termdict/sstable/mod.rs new file mode 100644 index 000000000..71c387cae --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/mod.rs @@ -0,0 +1,352 @@ +use std::io::{self, Write}; +use std::usize; + +use merge::ValueMerger; + +mod delta; +pub mod merge; +pub mod value; + +pub(crate) mod sstable_index; + +pub(crate) use self::sstable_index::{SSTableIndex, SSTableIndexBuilder}; +pub(crate) mod vint; + +mod block_reader; +pub use self::block_reader::BlockReader; +pub use self::delta::DeltaReader; +use self::delta::DeltaWriter; +pub use self::merge::VoidMerge; +use self::value::{U64MonotonicReader, U64MonotonicWriter, ValueReader, ValueWriter}; + +const DEFAULT_KEY_CAPACITY: usize = 50; + +pub(crate) fn common_prefix_len(left: &[u8], right: &[u8]) -> usize { + left.iter() + .cloned() + .zip(right.iter().cloned()) + .take_while(|(left, right)| left == right) + .count() +} + +pub trait SSTable: Sized { + type Value; + type Reader: ValueReader; + type Writer: ValueWriter; + + fn delta_writer(write: W) -> DeltaWriter { + DeltaWriter::new(write) + } + 
+ fn writer(write: W) -> Writer { + Writer { + previous_key: Vec::with_capacity(DEFAULT_KEY_CAPACITY), + num_terms: 0u64, + index_builder: SSTableIndexBuilder::default(), + delta_writer: Self::delta_writer(write), + first_ordinal_of_the_block: 0u64, + } + } + + fn delta_reader<'a, R: io::Read + 'a>(reader: R) -> DeltaReader<'a, Self::Reader> { + DeltaReader::new(reader) + } + + fn reader<'a, R: io::Read + 'a>(reader: R) -> Reader<'a, Self::Reader> { + Reader { + key: Vec::with_capacity(DEFAULT_KEY_CAPACITY), + delta_reader: Self::delta_reader(reader), + } + } + + fn merge>( + io_readers: Vec, + w: W, + merger: M, + ) -> io::Result<()> { + let readers: Vec<_> = io_readers.into_iter().map(Self::reader).collect(); + let writer = Self::writer(w); + merge::merge_sstable::(readers, writer, merger) + } +} + +#[allow(dead_code)] +pub struct VoidSSTable; + +impl SSTable for VoidSSTable { + type Value = (); + type Reader = value::VoidReader; + type Writer = value::VoidWriter; +} + +#[allow(dead_code)] +pub struct SSTableMonotonicU64; + +impl SSTable for SSTableMonotonicU64 { + type Value = u64; + + type Reader = U64MonotonicReader; + + type Writer = U64MonotonicWriter; +} + +pub struct Reader<'a, TValueReader> { + key: Vec, + delta_reader: DeltaReader<'a, TValueReader>, +} + +impl<'a, TValueReader> Reader<'a, TValueReader> +where TValueReader: ValueReader +{ + pub fn advance(&mut self) -> io::Result { + if !self.delta_reader.advance()? 
{ + return Ok(false); + } + let common_prefix_len = self.delta_reader.common_prefix_len(); + let suffix = self.delta_reader.suffix(); + let new_len = self.delta_reader.common_prefix_len() + suffix.len(); + self.key.resize(new_len, 0u8); + self.key[common_prefix_len..].copy_from_slice(suffix); + Ok(true) + } + + pub fn key(&self) -> &[u8] { + &self.key + } + + pub fn value(&self) -> &TValueReader::Value { + self.delta_reader.value() + } +} + +impl<'a, TValueReader> AsRef<[u8]> for Reader<'a, TValueReader> { + fn as_ref(&self) -> &[u8] { + &self.key + } +} + +pub struct Writer +where W: io::Write +{ + previous_key: Vec, + index_builder: SSTableIndexBuilder, + delta_writer: DeltaWriter, + num_terms: u64, + first_ordinal_of_the_block: u64, +} + +impl Writer +where + W: io::Write, + TValueWriter: value::ValueWriter, +{ + pub(crate) fn current_key(&self) -> &[u8] { + &self.previous_key[..] + } + + pub fn write_key(&mut self, key: &[u8]) { + let keep_len = common_prefix_len(&self.previous_key, key); + let add_len = key.len() - keep_len; + let increasing_keys = add_len > 0 && (self.previous_key.len() == keep_len) + || self.previous_key.is_empty() + || self.previous_key[keep_len] < key[keep_len]; + assert!( + increasing_keys, + "Keys should be increasing. ({:?} > {:?})", + self.previous_key, key + ); + self.previous_key.resize(key.len(), 0u8); + self.previous_key[keep_len..].copy_from_slice(&key[keep_len..]); + self.delta_writer.write_suffix(keep_len, &key[keep_len..]); + } + + #[allow(dead_code)] + pub fn write(&mut self, key: &[u8], value: &TValueWriter::Value) -> io::Result<()> { + self.write_key(key); + self.write_value(value)?; + Ok(()) + } + + pub fn write_value(&mut self, value: &TValueWriter::Value) -> io::Result<()> { + self.delta_writer.write_value(value); + self.num_terms += 1u64; + self.flush_block_if_required() + } + + pub fn flush_block_if_required(&mut self) -> io::Result<()> { + if let Some(byte_range) = self.delta_writer.flush_block_if_required()? 
{ + self.index_builder.add_block( + &self.previous_key[..], + byte_range, + self.first_ordinal_of_the_block, + ); + self.first_ordinal_of_the_block = self.num_terms; + self.previous_key.clear(); + } + Ok(()) + } + + pub fn finalize(mut self) -> io::Result { + if let Some(byte_range) = self.delta_writer.flush_block()? { + self.index_builder.add_block( + &self.previous_key[..], + byte_range, + self.first_ordinal_of_the_block, + ); + self.first_ordinal_of_the_block = self.num_terms; + } + let mut wrt = self.delta_writer.finalize(); + wrt.write_all(&0u32.to_le_bytes())?; + + let offset = wrt.written_bytes(); + + self.index_builder.serialize(&mut wrt)?; + wrt.write_all(&offset.to_le_bytes())?; + wrt.write_all(&self.num_terms.to_le_bytes())?; + let wrt = wrt.finish(); + Ok(wrt.into_inner()?) + } +} +#[cfg(test)] +mod test { + use std::io; + + use super::{common_prefix_len, SSTable, SSTableMonotonicU64, VoidMerge, VoidSSTable}; + + fn aux_test_common_prefix_len(left: &str, right: &str, expect_len: usize) { + assert_eq!( + common_prefix_len(left.as_bytes(), right.as_bytes()), + expect_len + ); + assert_eq!( + common_prefix_len(right.as_bytes(), left.as_bytes()), + expect_len + ); + } + + #[test] + fn test_common_prefix_len() { + aux_test_common_prefix_len("a", "ab", 1); + aux_test_common_prefix_len("", "ab", 0); + aux_test_common_prefix_len("ab", "abc", 2); + aux_test_common_prefix_len("abde", "abce", 2); + } + + #[test] + fn test_long_key_diff() { + let long_key = (0..1_024).map(|x| (x % 255) as u8).collect::>(); + let long_key2 = (1..300).map(|x| (x % 255) as u8).collect::>(); + let mut buffer = vec![]; + { + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + assert!(sstable_writer.write(&long_key[..], &()).is_ok()); + assert!(sstable_writer.write(&[0, 3, 4], &()).is_ok()); + assert!(sstable_writer.write(&long_key2[..], &()).is_ok()); + assert!(sstable_writer.finalize().is_ok()); + } + let mut sstable_reader = VoidSSTable::reader(&buffer[..]); + 
assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &long_key[..]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[0, 3, 4]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &long_key2[..]); + assert!(!sstable_reader.advance().unwrap()); + } + + #[test] + fn test_simple_sstable() { + let mut buffer = vec![]; + { + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + assert!(sstable_writer.write(&[17u8], &()).is_ok()); + assert!(sstable_writer.write(&[17u8, 18u8, 19u8], &()).is_ok()); + assert!(sstable_writer.write(&[17u8, 20u8], &()).is_ok()); + assert!(sstable_writer.finalize().is_ok()); + } + assert_eq!( + &buffer, + &[ + // block len + 7u8, 0u8, 0u8, 0u8, // keep 0 push 1 | "" + 16u8, 17u8, // keep 1 push 2 | 18 19 + 33u8, 18u8, 19u8, // keep 1 push 1 | 20 + 17u8, 20u8, 0u8, 0u8, 0u8, 0u8, // no more blocks + // index + 161, 102, 98, 108, 111, 99, 107, 115, 129, 162, 104, 108, 97, 115, 116, 95, 107, + 101, 121, 130, 17, 20, 106, 98, 108, 111, 99, 107, 95, 97, 100, 100, 114, 162, 106, + 98, 121, 116, 101, 95, 114, 97, 110, 103, 101, 162, 101, 115, 116, 97, 114, 116, 0, + 99, 101, 110, 100, 11, 109, 102, 105, 114, 115, 116, 95, 111, 114, 100, 105, 110, + 97, 108, 0, 15, 0, 0, 0, 0, 0, 0, 0, // offset for the index + 3u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8 // num terms + ] + ); + let mut sstable_reader = VoidSSTable::reader(&buffer[..]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[17u8]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[17u8, 18u8, 19u8]); + assert!(sstable_reader.advance().unwrap()); + assert_eq!(sstable_reader.key(), &[17u8, 20u8]); + assert!(!sstable_reader.advance().unwrap()); + } + + #[test] + #[should_panic] + fn test_simple_sstable_non_increasing_key() { + let mut buffer = vec![]; + let mut sstable_writer = VoidSSTable::writer(&mut buffer); + 
assert!(sstable_writer.write(&[17u8], &()).is_ok()); + assert!(sstable_writer.write(&[16u8], &()).is_ok()); + } + + #[test] + fn test_merge_abcd_abe() { + let mut buffer = Vec::new(); + { + let mut writer = VoidSSTable::writer(&mut buffer); + writer.write(b"abcd", &()).unwrap(); + writer.write(b"abe", &()).unwrap(); + writer.finalize().unwrap(); + } + let mut output = Vec::new(); + assert!(VoidSSTable::merge(vec![&buffer[..], &buffer[..]], &mut output, VoidMerge).is_ok()); + assert_eq!(&output[..], &buffer[..]); + } + + #[test] + fn test_sstable() { + let mut buffer = Vec::new(); + { + let mut writer = VoidSSTable::writer(&mut buffer); + writer.write(b"abcd", &()).unwrap(); + writer.write(b"abe", &()).unwrap(); + writer.finalize().unwrap(); + } + let mut output = Vec::new(); + assert!(VoidSSTable::merge(vec![&buffer[..], &buffer[..]], &mut output, VoidMerge).is_ok()); + assert_eq!(&output[..], &buffer[..]); + } + + #[test] + fn test_sstable_u64() -> io::Result<()> { + let mut buffer = Vec::new(); + let mut writer = SSTableMonotonicU64::writer(&mut buffer); + writer.write(b"abcd", &1u64)?; + writer.write(b"abe", &4u64)?; + writer.write(b"gogo", &4324234234234234u64)?; + writer.finalize()?; + let mut reader = SSTableMonotonicU64::reader(&buffer[..]); + assert!(reader.advance()?); + assert_eq!(reader.key(), b"abcd"); + assert_eq!(reader.value(), &1u64); + assert!(reader.advance()?); + assert_eq!(reader.key(), b"abe"); + assert_eq!(reader.value(), &4u64); + assert!(reader.advance()?); + assert_eq!(reader.key(), b"gogo"); + assert_eq!(reader.value(), &4324234234234234u64); + assert!(!reader.advance()?); + Ok(()) + } +} diff --git a/src/termdict/sstable_termdict/sstable/sstable_index.rs b/src/termdict/sstable_termdict/sstable/sstable_index.rs new file mode 100644 index 000000000..e1fceceee --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/sstable_index.rs @@ -0,0 +1,100 @@ +use std::io; +use std::ops::Range; + +use serde::{Deserialize, Serialize}; + +use 
crate::error::DataCorruption; + +#[derive(Default, Debug, Serialize, Deserialize)] +pub struct SSTableIndex { + blocks: Vec, +} + +impl SSTableIndex { + pub(crate) fn load(data: &[u8]) -> Result { + serde_cbor::de::from_slice(data) + .map_err(|_| DataCorruption::comment_only("SSTable index is corrupted")) + } + + pub fn search(&self, key: &[u8]) -> Option { + self.blocks + .iter() + .find(|block| &block.last_key[..] >= key) + .map(|block| block.block_addr.clone()) + } +} + +#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +pub struct BlockAddr { + pub byte_range: Range, + pub first_ordinal: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +struct BlockMeta { + pub last_key: Vec, + pub block_addr: BlockAddr, +} + +#[derive(Default)] +pub struct SSTableIndexBuilder { + index: SSTableIndex, +} + +impl SSTableIndexBuilder { + pub fn add_block(&mut self, last_key: &[u8], byte_range: Range, first_ordinal: u64) { + self.index.blocks.push(BlockMeta { + last_key: last_key.to_vec(), + block_addr: BlockAddr { + byte_range, + first_ordinal, + }, + }) + } + + pub fn serialize(&self, wrt: &mut dyn io::Write) -> io::Result<()> { + serde_cbor::ser::to_writer(wrt, &self.index).unwrap(); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::{BlockAddr, SSTableIndex, SSTableIndexBuilder}; + + #[test] + fn test_sstable_index() { + let mut sstable_builder = SSTableIndexBuilder::default(); + sstable_builder.add_block(b"aaa", 10..20, 0u64); + sstable_builder.add_block(b"bbbbbbb", 20..30, 564); + sstable_builder.add_block(b"ccc", 30..40, 10u64); + sstable_builder.add_block(b"dddd", 40..50, 15u64); + let mut buffer: Vec = Vec::new(); + sstable_builder.serialize(&mut buffer).unwrap(); + let sstable_index = SSTableIndex::load(&buffer[..]).unwrap(); + assert_eq!( + sstable_index.search(b"bbbde"), + Some(BlockAddr { + first_ordinal: 10u64, + byte_range: 30..40 + }) + ); + } + + #[test] + fn test_sstable_with_corrupted_data() { + let mut sstable_builder = 
SSTableIndexBuilder::default(); + sstable_builder.add_block(b"aaa", 10..20, 0u64); + sstable_builder.add_block(b"bbbbbbb", 20..30, 564); + sstable_builder.add_block(b"ccc", 30..40, 10u64); + sstable_builder.add_block(b"dddd", 40..50, 15u64); + let mut buffer: Vec = Vec::new(); + sstable_builder.serialize(&mut buffer).unwrap(); + buffer[1] = 9u8; + let data_corruption_err = SSTableIndex::load(&buffer[..]).err().unwrap(); + assert_eq!( + format!("{data_corruption_err:?}"), + "Data corruption: SSTable index is corrupted." + ); + } +} diff --git a/src/termdict/sstable_termdict/sstable/value.rs b/src/termdict/sstable_termdict/sstable/value.rs new file mode 100644 index 000000000..969dae2f2 --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/value.rs @@ -0,0 +1,95 @@ +use std::io; + +use super::{vint, BlockReader}; + +pub trait ValueReader: Default { + type Value; + + fn value(&self, idx: usize) -> &Self::Value; + + fn read(&mut self, reader: &mut BlockReader) -> io::Result<()>; +} + +pub trait ValueWriter: Default { + type Value; + + fn write(&mut self, val: &Self::Value); + + fn write_block(&mut self, writer: &mut Vec); +} + +#[derive(Default)] +pub struct VoidReader; + +impl ValueReader for VoidReader { + type Value = (); + + fn value(&self, _idx: usize) -> &() { + &() + } + + fn read(&mut self, _reader: &mut BlockReader) -> io::Result<()> { + Ok(()) + } +} + +#[derive(Default)] +pub struct VoidWriter; + +impl ValueWriter for VoidWriter { + type Value = (); + + fn write(&mut self, _val: &()) {} + + fn write_block(&mut self, _writer: &mut Vec) {} +} + +#[derive(Default)] +pub struct U64MonotonicWriter { + vals: Vec, +} + +impl ValueWriter for U64MonotonicWriter { + type Value = u64; + + fn write(&mut self, val: &Self::Value) { + self.vals.push(*val); + } + + fn write_block(&mut self, writer: &mut Vec) { + let mut prev_val = 0u64; + vint::serialize_into_vec(self.vals.len() as u64, writer); + for &val in &self.vals { + let delta = val - prev_val; + 
vint::serialize_into_vec(delta, writer); + prev_val = val; + } + self.vals.clear(); + } +} + +#[derive(Default)] +pub struct U64MonotonicReader { + vals: Vec, +} + +impl ValueReader for U64MonotonicReader { + type Value = u64; + + fn value(&self, idx: usize) -> &Self::Value { + &self.vals[idx] + } + + fn read(&mut self, reader: &mut BlockReader) -> io::Result<()> { + let len = reader.deserialize_u64() as usize; + self.vals.clear(); + let mut prev_val = 0u64; + for _ in 0..len { + let delta = reader.deserialize_u64() as u64; + let val = prev_val + delta; + self.vals.push(val); + prev_val = val; + } + Ok(()) + } +} diff --git a/src/termdict/sstable_termdict/sstable/vint.rs b/src/termdict/sstable_termdict/sstable/vint.rs new file mode 100644 index 000000000..e15988d8b --- /dev/null +++ b/src/termdict/sstable_termdict/sstable/vint.rs @@ -0,0 +1,67 @@ +const CONTINUE_BIT: u8 = 128u8; + +pub fn serialize(mut val: u64, buffer: &mut [u8]) -> usize { + for (i, b) in buffer.iter_mut().enumerate() { + let next_byte: u8 = (val & 127u64) as u8; + val >>= 7; + if val == 0u64 { + *b = next_byte; + return i + 1; + } else { + *b = next_byte | CONTINUE_BIT; + } + } + 10 //< actually unreachable +} + +pub fn serialize_into_vec(val: u64, buffer: &mut Vec) { + let mut buf = [0u8; 10]; + let num_bytes = serialize(val, &mut buf[..]); + buffer.extend_from_slice(&buf[..num_bytes]); +} + +// super slow but we don't care +pub fn deserialize_read(buf: &[u8]) -> (usize, u64) { + let mut result = 0u64; + let mut shift = 0u64; + let mut consumed = 0; + + for &b in buf { + consumed += 1; + result |= u64::from(b % 128u8) << shift; + if b < CONTINUE_BIT { + break; + } + shift += 7; + } + (consumed, result) +} + +#[cfg(test)] +mod tests { + use std::u64; + + use super::{deserialize_read, serialize}; + + fn aux_test_int(val: u64, expect_len: usize) { + let mut buffer = [0u8; 14]; + assert_eq!(serialize(val, &mut buffer[..]), expect_len); + assert_eq!(deserialize_read(&buffer), (expect_len, val)); + } 
+ + #[test] + fn test_vint() { + aux_test_int(0u64, 1); + aux_test_int(17u64, 1); + aux_test_int(127u64, 1); + aux_test_int(128u64, 2); + aux_test_int(123423418u64, 4); + for i in 1..63 { + let power_of_two = 1u64 << i; + aux_test_int(power_of_two + 1, (i / 7) + 1); + aux_test_int(power_of_two, (i / 7) + 1); + aux_test_int(power_of_two - 1, ((i - 1) / 7) + 1); + } + aux_test_int(u64::MAX, 10); + } +} diff --git a/src/termdict/sstable_termdict/streamer.rs b/src/termdict/sstable_termdict/streamer.rs new file mode 100644 index 000000000..7830a4385 --- /dev/null +++ b/src/termdict/sstable_termdict/streamer.rs @@ -0,0 +1,251 @@ +use std::io; +use std::ops::Bound; + +use tantivy_fst::automaton::AlwaysMatch; +use tantivy_fst::Automaton; + +use super::TermDictionary; +use crate::postings::TermInfo; +use crate::termdict::sstable_termdict::TermInfoReader; +use crate::termdict::TermOrdinal; + +/// `TermStreamerBuilder` is a helper object used to define +/// a range of terms that should be streamed. 
+pub struct TermStreamerBuilder<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + term_dict: &'a TermDictionary, + automaton: A, + lower: Bound>, + upper: Bound>, +} + +impl<'a, A> TermStreamerBuilder<'a, A> +where + A: Automaton, + A::State: Clone, +{ + pub(crate) fn new(term_dict: &'a TermDictionary, automaton: A) -> Self { + TermStreamerBuilder { + term_dict, + automaton, + lower: Bound::Unbounded, + upper: Bound::Unbounded, + } + } + + /// Limit the range to terms greater or equal to the bound + pub fn ge>(mut self, bound: T) -> Self { + self.lower = Bound::Included(bound.as_ref().to_owned()); + self + } + + /// Limit the range to terms strictly greater than the bound + pub fn gt>(mut self, bound: T) -> Self { + self.lower = Bound::Excluded(bound.as_ref().to_owned()); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn le>(mut self, bound: T) -> Self { + self.upper = Bound::Included(bound.as_ref().to_owned()); + self + } + + /// Limit the range to terms lesser or equal to the bound + pub fn lt>(mut self, bound: T) -> Self { + self.upper = Bound::Excluded(bound.as_ref().to_owned()); + self + } + + /// Creates the stream corresponding to the range + /// of terms defined using the `TermStreamerBuilder`. + pub fn into_stream(self) -> io::Result> { + // TODO Optimize by skipping to the right first block. + let start_state = self.automaton.start(); + let delta_reader = self.term_dict.sstable_delta_reader()?; + Ok(TermStreamer { + automaton: self.automaton, + states: vec![start_state], + delta_reader, + key: Vec::new(), + term_ord: None, + lower_bound: self.lower, + upper_bound: self.upper, + }) + } +} + +/// `TermStreamer` acts as a cursor over a range of terms of a segment. +/// Terms are guaranteed to be sorted. 
+pub struct TermStreamer<'a, A = AlwaysMatch> +where + A: Automaton, + A::State: Clone, +{ + automaton: A, + states: Vec, + delta_reader: super::sstable::DeltaReader<'a, TermInfoReader>, + key: Vec, + term_ord: Option, + lower_bound: Bound>, + upper_bound: Bound>, +} + +impl<'a, A> TermStreamer<'a, A> +where + A: Automaton, + A::State: Clone, +{ + /// Advance position the stream on the next item. + /// Before the first call to `.advance()`, the stream + /// is an unitialized state. + pub fn advance(&mut self) -> bool { + while self.delta_reader.advance().unwrap() { + self.term_ord = Some( + self.term_ord + .map(|term_ord| term_ord + 1u64) + .unwrap_or(0u64), + ); + let common_prefix_len = self.delta_reader.common_prefix_len(); + self.states.truncate(common_prefix_len + 1); + self.key.truncate(common_prefix_len); + let mut state: A::State = self.states.last().unwrap().clone(); + for &b in self.delta_reader.suffix() { + state = self.automaton.accept(&state, b); + self.states.push(state.clone()); + } + self.key.extend_from_slice(self.delta_reader.suffix()); + let match_lower_bound = match &self.lower_bound { + Bound::Unbounded => true, + Bound::Included(lower_bound_key) => lower_bound_key[..] <= self.key[..], + Bound::Excluded(lower_bound_key) => lower_bound_key[..] < self.key[..], + }; + if !match_lower_bound { + continue; + } + // We match the lower key once. All subsequent keys will pass that bar. + self.lower_bound = Bound::Unbounded; + let match_upper_bound = match &self.upper_bound { + Bound::Unbounded => true, + Bound::Included(upper_bound_key) => upper_bound_key[..] >= self.key[..], + Bound::Excluded(upper_bound_key) => upper_bound_key[..] > self.key[..], + }; + if !match_upper_bound { + return false; + } + if self.automaton.is_match(&state) { + return true; + } + } + false + } + + /// Returns the `TermOrdinal` of the given term. + /// + /// May panic if the called as `.advance()` as never + /// been called before. 
+    pub fn term_ord(&self) -> TermOrdinal {
+        self.term_ord.unwrap_or(0u64)
+    }
+
+    /// Accesses the current key.
+    ///
+    /// `.key()` should return the key that was returned
+    /// by the `.next()` method.
+    ///
+    /// If the end of the stream has been reached, and `.next()`
+    /// has been called and returned `None`, `.key()` remains
+    /// the value of the last key encountered.
+    ///
+    /// Before any call to `.next()`, `.key()` returns an empty array.
+    pub fn key(&self) -> &[u8] {
+        &self.key
+    }
+
+    /// Accesses the current value.
+    ///
+    /// Calling `.value()` after the end of the stream will return the
+    /// last `.value()` encountered.
+    ///
+    /// # Note
+    ///
+    /// Calling `.value()` before the first call to `.advance()` returns
+    /// `V::default()`.
+    pub fn value(&self) -> &TermInfo {
+        self.delta_reader.value()
+    }
+
+    /// Returns the next `(key, value)` pair.
+    #[cfg_attr(feature = "cargo-clippy", allow(clippy::should_implement_trait))]
+    pub fn next(&mut self) -> Option<(&[u8], &TermInfo)> {
+        if self.advance() {
+            Some((self.key(), self.value()))
+        } else {
+            None
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::TermDictionary;
+    use crate::directory::OwnedBytes;
+    use crate::postings::TermInfo;
+
+    fn make_term_info(i: usize) -> TermInfo {
+        TermInfo {
+            doc_freq: 1000u32 + i as u32,
+            postings_range: (i + 10) * (i * 10)..((i + 1) + 10) * ((i + 1) * 10),
+            positions_range: i * 500..(i + 1) * 500,
+        }
+    }
+
+    fn create_test_term_dictionary() -> crate::Result {
+        let mut term_dict_builder = super::super::TermDictionaryBuilder::create(Vec::new())?;
+        term_dict_builder.insert(b"abaisance", &make_term_info(0))?;
+        term_dict_builder.insert(b"abalation", &make_term_info(1))?;
+        term_dict_builder.insert(b"abalienate", &make_term_info(2))?;
+        term_dict_builder.insert(b"abandon", &make_term_info(3))?;
+        let buffer = term_dict_builder.finish()?;
+        let owned_bytes = OwnedBytes::new(buffer);
+        TermDictionary::from_bytes(owned_bytes)
+    }
+
+    #[test]
+    fn 
test_sstable_stream() -> crate::Result<()> { + let term_dict = create_test_term_dictionary()?; + let mut term_streamer = term_dict.stream()?; + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abaisance"); + assert_eq!(term_streamer.value().doc_freq, 1000u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalation"); + assert_eq!(term_streamer.value().doc_freq, 1001u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalienate"); + assert_eq!(term_streamer.value().doc_freq, 1002u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abandon"); + assert_eq!(term_streamer.value().doc_freq, 1003u32); + assert!(!term_streamer.advance()); + Ok(()) + } + + #[test] + fn test_sstable_search() -> crate::Result<()> { + let term_dict = create_test_term_dictionary()?; + let ptn = tantivy_fst::Regex::new("ab.*t.*").unwrap(); + let mut term_streamer = term_dict.search(ptn).into_stream()?; + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalation"); + assert_eq!(term_streamer.value().doc_freq, 1001u32); + assert!(term_streamer.advance()); + assert_eq!(term_streamer.key(), b"abalienate"); + assert_eq!(term_streamer.value().doc_freq, 1002u32); + assert!(!term_streamer.advance()); + Ok(()) + } +} diff --git a/src/termdict/sstable_termdict/termdict.rs b/src/termdict/sstable_termdict/termdict.rs new file mode 100644 index 000000000..b4ff1b66f --- /dev/null +++ b/src/termdict/sstable_termdict/termdict.rs @@ -0,0 +1,254 @@ +use std::io; + +use common::BinarySerializable; +use once_cell::sync::Lazy; +use tantivy_fst::automaton::AlwaysMatch; +use tantivy_fst::Automaton; + +use crate::directory::{FileSlice, OwnedBytes}; +use crate::postings::TermInfo; +use crate::termdict::sstable_termdict::sstable::sstable_index::BlockAddr; +use crate::termdict::sstable_termdict::sstable::{ + DeltaReader, Reader, SSTable, SSTableIndex, Writer, +}; +use 
crate::termdict::sstable_termdict::{
+    TermInfoReader, TermInfoWriter, TermSSTable, TermStreamer, TermStreamerBuilder,
+};
+use crate::termdict::TermOrdinal;
+use crate::AsyncIoResult;
+
+pub struct TermInfoSSTable;
+impl SSTable for TermInfoSSTable {
+    type Value = TermInfo;
+    type Reader = TermInfoReader;
+    type Writer = TermInfoWriter;
+}
+pub struct TermDictionaryBuilder {
+    sstable_writer: Writer,
+}
+
+impl TermDictionaryBuilder {
+    /// Creates a new `TermDictionaryBuilder`
+    pub fn create(w: W) -> io::Result {
+        let sstable_writer = TermSSTable::writer(w);
+        Ok(TermDictionaryBuilder { sstable_writer })
+    }
+
+    /// Inserts a `(key, value)` pair in the term dictionary.
+    ///
+    /// *Keys have to be inserted in order.*
+    pub fn insert>(&mut self, key_ref: K, value: &TermInfo) -> io::Result<()> {
+        let key = key_ref.as_ref();
+        self.insert_key(key)?;
+        self.insert_value(value)?;
+        Ok(())
+    }
+
+    /// # Warning
+    /// Horribly dangerous internal API
+    ///
+    /// If used, it must be used by systematically alternating calls
+    /// to insert_key and insert_value.
+    ///
+    /// Prefer using `.insert(key, value)`
+    #[allow(clippy::unnecessary_wraps)]
+    pub(crate) fn insert_key(&mut self, key: &[u8]) -> io::Result<()> {
+        self.sstable_writer.write_key(key);
+        Ok(())
+    }
+
+    /// # Warning
+    ///
+    /// Horribly dangerous internal API. See `.insert_key(...)`.
+    pub(crate) fn insert_value(&mut self, term_info: &TermInfo) -> io::Result<()> {
+        self.sstable_writer.write_value(term_info)
+    }
+
+    /// Finalize writing the builder, and returns the underlying
+    /// `Write` object.
+ pub fn finish(self) -> io::Result { + self.sstable_writer.finalize() + } +} + +static EMPTY_TERM_DICT_FILE: Lazy = Lazy::new(|| { + let term_dictionary_data: Vec = TermDictionaryBuilder::create(Vec::::new()) + .expect("Creating a TermDictionaryBuilder in a Vec should never fail") + .finish() + .expect("Writing in a Vec should never fail"); + FileSlice::from(term_dictionary_data) +}); + +/// The term dictionary contains all of the terms in +/// `tantivy index` in a sorted manner. +/// +/// The `Fst` crate is used to associate terms to their +/// respective `TermOrdinal`. The `TermInfoStore` then makes it +/// possible to fetch the associated `TermInfo`. +pub struct TermDictionary { + sstable_slice: FileSlice, + sstable_index: SSTableIndex, + num_terms: u64, +} + +impl TermDictionary { + pub(crate) fn sstable_reader(&self) -> io::Result> { + let data = self.sstable_slice.read_bytes()?; + Ok(TermInfoSSTable::reader(data)) + } + + pub(crate) fn sstable_reader_block( + &self, + block_addr: BlockAddr, + ) -> io::Result> { + let data = self.sstable_slice.read_bytes_slice(block_addr.byte_range)?; + Ok(TermInfoSSTable::reader(data)) + } + + pub(crate) async fn sstable_reader_block_async( + &self, + block_addr: BlockAddr, + ) -> AsyncIoResult> { + let data = self + .sstable_slice + .read_bytes_slice_async(block_addr.byte_range) + .await?; + Ok(TermInfoSSTable::reader(data)) + } + + pub(crate) fn sstable_delta_reader(&self) -> io::Result> { + let data = self.sstable_slice.read_bytes()?; + Ok(TermInfoSSTable::delta_reader(data)) + } + + /// Opens a `TermDictionary`. 
+ pub fn open(term_dictionary_file: FileSlice) -> crate::Result { + let (main_slice, footer_len_slice) = term_dictionary_file.split_from_end(16); + let mut footer_len_bytes: OwnedBytes = footer_len_slice.read_bytes()?; + let index_offset = u64::deserialize(&mut footer_len_bytes)?; + let num_terms = u64::deserialize(&mut footer_len_bytes)?; + let (sstable_slice, index_slice) = main_slice.split(index_offset as usize); + let sstable_index_bytes = index_slice.read_bytes()?; + let sstable_index = SSTableIndex::load(sstable_index_bytes.as_slice())?; + Ok(TermDictionary { + sstable_slice, + sstable_index, + num_terms, + }) + } + + pub fn from_bytes(owned_bytes: OwnedBytes) -> crate::Result { + TermDictionary::open(FileSlice::new(Box::new(owned_bytes))) + } + + /// Creates an empty term dictionary which contains no terms. + pub fn empty() -> Self { + TermDictionary::open(EMPTY_TERM_DICT_FILE.clone()).unwrap() + } + + /// Returns the number of terms in the dictionary. + /// Term ordinals range from 0 to `num_terms() - 1`. + pub fn num_terms(&self) -> usize { + self.num_terms as usize + } + + /// Returns the ordinal associated to a given term. + pub fn term_ord>(&self, key: K) -> io::Result> { + let mut term_ord = 0u64; + let key_bytes = key.as_ref(); + let mut sstable_reader = self.sstable_reader()?; + while sstable_reader.advance().unwrap_or(false) { + if sstable_reader.key() == key_bytes { + return Ok(Some(term_ord)); + } + term_ord += 1; + } + Ok(None) + } + + /// Returns the term associated to a given term ordinal. + /// + /// Term ordinals are defined as the position of the term in + /// the sorted list of terms. + /// + /// Returns true iff the term has been found. + /// + /// Regardless of whether the term is found or not, + /// the buffer may be modified. 
+    pub fn ord_to_term(&self, ord: TermOrdinal, bytes: &mut Vec) -> io::Result {
+        let mut sstable_reader = self.sstable_reader()?;
+        bytes.clear();
+        for _ in 0..(ord + 1) {
+            if !sstable_reader.advance().unwrap_or(false) {
+                return Ok(false);
+            }
+        }
+        bytes.extend_from_slice(sstable_reader.key());
+        Ok(true)
+    }
+
+    /// Returns the `TermInfo` associated with the given `TermOrdinal`.
+    pub fn term_info_from_ord(&self, term_ord: TermOrdinal) -> io::Result {
+        let mut sstable_reader = self.sstable_reader()?;
+        for _ in 0..(term_ord + 1) {
+            if !sstable_reader.advance().unwrap_or(false) {
+                return Ok(TermInfo::default());
+            }
+        }
+        Ok(sstable_reader.value().clone())
+    }
+
+    /// Looks up the value corresponding to the key.
+    pub fn get>(&self, key: K) -> io::Result> {
+        if let Some(block_addr) = self.sstable_index.search(key.as_ref()) {
+            let mut sstable_reader = self.sstable_reader_block(block_addr)?;
+            let key_bytes = key.as_ref();
+            while sstable_reader.advance().unwrap_or(false) {
+                if sstable_reader.key() == key_bytes {
+                    let term_info = sstable_reader.value().clone();
+                    return Ok(Some(term_info));
+                }
+            }
+        }
+        Ok(None)
+    }
+
+    /// Looks up the value corresponding to the key.
+    pub async fn get_async>(&self, key: K) -> AsyncIoResult> {
+        if let Some(block_addr) = self.sstable_index.search(key.as_ref()) {
+            let mut sstable_reader = self.sstable_reader_block_async(block_addr).await?;
+            let key_bytes = key.as_ref();
+            while sstable_reader.advance().unwrap_or(false) {
+                if sstable_reader.key() == key_bytes {
+                    let term_info = sstable_reader.value().clone();
+                    return Ok(Some(term_info));
+                }
+            }
+        }
+        Ok(None)
+    }
+
+    // Returns a range builder, to stream all of the terms
+    // within an interval.
+    pub fn range(&self) -> TermStreamerBuilder<'_> {
+        TermStreamerBuilder::new(self, AlwaysMatch)
+    }
+
+    // A stream of all the sorted terms. 
[See also `.stream_field()`](#method.stream_field) + pub fn stream(&self) -> io::Result> { + self.range().into_stream() + } + + // Returns a search builder, to stream all of the terms + // within the Automaton + pub fn search<'a, A: Automaton + 'a>(&'a self, automaton: A) -> TermStreamerBuilder<'a, A> + where A::State: Clone { + TermStreamerBuilder::::new(self, automaton) + } + + #[doc(hidden)] + pub async fn warm_up_dictionary(&self) -> AsyncIoResult<()> { + self.sstable_slice.read_bytes_async().await?; + Ok(()) + } +} diff --git a/src/termdict/tests.rs b/src/termdict/tests.rs index 2a649ca9a..ea26de841 100644 --- a/src/termdict/tests.rs +++ b/src/termdict/tests.rs @@ -302,6 +302,7 @@ fn test_stream_range_boundaries_forward() -> crate::Result<()> { Ok(()) } +#[cfg(not(feature = "quickwit"))] #[test] fn test_stream_range_boundaries_backward() -> crate::Result<()> { let term_dictionary = stream_range_test_dict()?;