diff --git a/CHANGELOG.md b/CHANGELOG.md index 0aedae3ae..94053798d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,20 @@ Tantivy 0.13.0 - Bugfix in `FuzzyTermQuery` not matching terms by prefix when it should (@Peachball) - Relaxed constraints on the custom/tweak score functions. At the segment level, they can be mut, and they are not required to be Sync + Send. - `MMapDirectory::open` does not return a `Result` anymore. +- Change in the DocSet and Scorer API. (@fulmicoton). +A freshly created DocSet point directly to their first doc. A sentinel value called TERMINATED marks the end of a DocSet. +`.advance()` returns the new DocId. `Scorer::skip(target)` has been replaced by `Scorer::seek(target)` and returns the resulting DocId. +As a result, iterating through DocSet now looks as follows +```rust +let mut doc = docset.doc(); +while doc != TERMINATED { + // ... + doc = docset.advance(); +} +``` +The change made it possible to greatly simplify a lot of the docset's code. +- Misc internal optimization and introduction of the `Scorer::for_each_pruning` function. (@fulmicoton) + Tantivy 0.12.0 ====================== diff --git a/examples/iterating_docs_and_positions.rs b/examples/iterating_docs_and_positions.rs index 0be84ec05..7883e47b8 100644 --- a/examples/iterating_docs_and_positions.rs +++ b/examples/iterating_docs_and_positions.rs @@ -10,7 +10,7 @@ // --- // Importing tantivy... use tantivy::schema::*; -use tantivy::{doc, DocId, DocSet, Index, Postings}; +use tantivy::{doc, DocSet, Index, Postings, TERMINATED}; fn main() -> tantivy::Result<()> { // We first create a schema for the sake of the @@ -62,12 +62,11 @@ fn main() -> tantivy::Result<()> { { // this buffer will be used to request for positions let mut positions: Vec = Vec::with_capacity(100); - while segment_postings.advance() { - // the number of time the term appears in the document. - let doc_id: DocId = segment_postings.doc(); //< do not try to access this before calling advance once. - + let mut doc_id = segment_postings.doc(); + while doc_id != TERMINATED { // This MAY contains deleted documents as well. if segment_reader.is_deleted(doc_id) { + doc_id = segment_postings.advance(); continue; } @@ -86,6 +85,7 @@ fn main() -> tantivy::Result<()> { // Doc 2: TermFreq 1: [0] // ``` println!("Doc {}: TermFreq {}: {:?}", doc_id, term_freq, positions); + doc_id = segment_postings.advance(); } } } diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs index 800285145..b8e7f42a5 100644 --- a/src/collector/facet_collector.rs +++ b/src/collector/facet_collector.rs @@ -1,6 +1,5 @@ use crate::collector::Collector; use crate::collector::SegmentCollector; -use crate::docset::SkipResult; use crate::fastfield::FacetReader; use crate::schema::Facet; use crate::schema::Field; @@ -188,6 +187,11 @@ pub struct FacetSegmentCollector { collapse_facet_ords: Vec, } +enum SkipResult { + Found, + NotFound, +} + fn skip<'a, I: Iterator>( target: &[u8], collapse_it: &mut Peekable, @@ -197,14 +201,14 @@ fn skip<'a, I: Iterator>( Some(facet_bytes) => match facet_bytes.encoded_str().as_bytes().cmp(target) { Ordering::Less => {} Ordering::Greater => { - return SkipResult::OverStep; + return SkipResult::NotFound; } Ordering::Equal => { - return SkipResult::Reached; + return SkipResult::Found; } }, None => { - return SkipResult::End; + return SkipResult::NotFound; } } collapse_it.next(); @@ -281,7 +285,7 @@ impl Collector for FacetCollector { // is positionned on a term that has not been processed yet. let skip_result = skip(facet_streamer.key(), &mut collapse_facet_it); match skip_result { - SkipResult::Reached => { + SkipResult::Found => { // we reach a facet we decided to collapse. let collapse_depth = facet_depth(facet_streamer.key()); let mut collapsed_id = 0; @@ -301,7 +305,7 @@ impl Collector for FacetCollector { } break; } - SkipResult::End | SkipResult::OverStep => { + SkipResult::NotFound => { collapse_mapping.push(0); if !facet_streamer.advance() { break; diff --git a/src/collector/mod.rs b/src/collector/mod.rs index f32bd73a2..cd56cb4e5 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -109,6 +109,7 @@ pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker}; mod facet_collector; pub use self::facet_collector::FacetCollector; +use crate::query::Scorer; /// `Fruit` is the type for the result of our collection. /// e.g. `usize` for the `Count` collector. @@ -154,6 +155,28 @@ pub trait Collector: Sync { /// Combines the fruit associated to the collection of each segments /// into one fruit. fn merge_fruits(&self, segment_fruits: Vec) -> crate::Result; + + /// Created a segment collector and + fn collect_segment( + &self, + scorer: &mut dyn Scorer, + segment_ord: u32, + segment_reader: &SegmentReader, + ) -> crate::Result<::Fruit> { + let mut segment_collector = self.for_segment(segment_ord as u32, segment_reader)?; + if let Some(delete_bitset) = segment_reader.delete_bitset() { + scorer.for_each(&mut |doc, score| { + if delete_bitset.is_alive(doc) { + segment_collector.collect(doc, score); + } + }); + } else { + scorer.for_each(&mut |doc, score| { + segment_collector.collect(doc, score); + }) + } + Ok(segment_collector.harvest()) + } } /// The `SegmentCollector` is the trait in charge of defining the diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index f3b5742e3..e930d9024 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -18,9 +18,9 @@ use std::collections::BinaryHeap; /// Two elements are equal if their feature is equal, and regardless of whether `doc` /// is equal. This should be perfectly fine for this usage, but let's make sure this /// struct is never public. -struct ComparableDoc { - feature: T, - doc: D, +pub(crate) struct ComparableDoc { + pub feature: T, + pub doc: D, } impl PartialOrd for ComparableDoc { @@ -56,7 +56,7 @@ impl PartialEq for ComparableDoc { impl Eq for ComparableDoc {} pub(crate) struct TopCollector { - limit: usize, + pub limit: usize, _marker: PhantomData, } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 4a8cb48e0..260103e98 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -1,18 +1,21 @@ use super::Collector; use crate::collector::custom_score_top_collector::CustomScoreTopCollector; -use crate::collector::top_collector::TopCollector; use crate::collector::top_collector::TopSegmentCollector; +use crate::collector::top_collector::{ComparableDoc, TopCollector}; use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector; use crate::collector::{ CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector, }; +use crate::docset::TERMINATED; use crate::fastfield::FastFieldReader; +use crate::query::Scorer; use crate::schema::Field; use crate::DocAddress; use crate::DocId; use crate::Score; use crate::SegmentLocalId; use crate::SegmentReader; +use std::collections::BinaryHeap; use std::fmt; /// The `TopDocs` collector keeps track of the top `K` documents @@ -423,6 +426,58 @@ impl Collector for TopDocs { ) -> crate::Result { self.0.merge_fruits(child_fruits) } + + fn collect_segment( + &self, + scorer: &mut dyn Scorer, + segment_ord: u32, + segment_reader: &SegmentReader, + ) -> crate::Result<::Fruit> { + let mut heap: BinaryHeap> = + BinaryHeap::with_capacity(self.0.limit); + // first we fill the heap with the first `limit` elements. + let mut doc = scorer.doc(); + while doc != TERMINATED && heap.len() < self.0.limit { + if !segment_reader.is_deleted(doc) { + let score = scorer.score(); + heap.push(ComparableDoc { + feature: score, + doc, + }); + } + doc = scorer.advance(); + } + + let threshold = heap.peek().map(|el| el.feature).unwrap_or(f32::MIN); + + if let Some(delete_bitset) = segment_reader.delete_bitset() { + scorer.for_each_pruning(threshold, &mut |doc, score| { + if delete_bitset.is_alive(doc) { + *heap.peek_mut().unwrap() = ComparableDoc { + feature: score, + doc, + }; + } + heap.peek().map(|el| el.feature).unwrap_or(f32::MIN) + }); + } else { + scorer.for_each_pruning(threshold, &mut |doc, score| { + *heap.peek_mut().unwrap() = ComparableDoc { + feature: score, + doc, + }; + heap.peek().map(|el| el.feature).unwrap_or(f32::MIN) + }); + } + + let fruit = heap + .into_sorted_vec() + .into_iter() + .map(|cid| (cid.feature, DocAddress(segment_ord, cid.doc))) + .collect(); + + Ok(fruit) + } } /// Segment Collector associated to `TopDocs`. @@ -432,7 +487,7 @@ impl SegmentCollector for TopScoreSegmentCollector { type Fruit = Vec<(Score, DocAddress)>; fn collect(&mut self, doc: DocId, score: Score) { - self.0.collect(doc, score) + self.0.collect(doc, score); } fn harvest(self) -> Vec<(Score, DocAddress)> { diff --git a/src/common/bitset.rs b/src/common/bitset.rs index 527aa8d4a..0a8d6f4de 100644 --- a/src/common/bitset.rs +++ b/src/common/bitset.rs @@ -33,6 +33,10 @@ impl TinySet { TinySet(0u64) } + pub fn clear(&mut self) { + self.0 = 0u64; + } + /// Returns the complement of the set in `[0, 64[`. fn complement(self) -> TinySet { TinySet(!self.0) @@ -43,6 +47,11 @@ impl TinySet { !self.intersect(TinySet::singleton(el)).is_empty() } + /// Returns the number of elements in the TinySet. + pub fn len(self) -> u32 { + self.0.count_ones() + } + /// Returns the intersection of `self` and `other` pub fn intersect(self, other: TinySet) -> TinySet { TinySet(self.0 & other.0) @@ -109,22 +118,12 @@ impl TinySet { pub fn range_greater_or_equal(from_included: u32) -> TinySet { TinySet::range_lower(from_included).complement() } - - pub fn clear(&mut self) { - self.0 = 0u64; - } - - pub fn len(self) -> u32 { - self.0.count_ones() - } } #[derive(Clone)] pub struct BitSet { tinysets: Box<[TinySet]>, - len: usize, //< Technically it should be u32, but we - // count multiple inserts. - // `usize` guards us from overflow. + len: usize, max_value: u32, } @@ -204,7 +203,7 @@ mod tests { use super::BitSet; use super::TinySet; - use crate::docset::DocSet; + use crate::docset::{DocSet, TERMINATED}; use crate::query::BitSetDocSet; use crate::tests; use crate::tests::generate_nonunique_unsorted; @@ -278,11 +277,13 @@ mod tests { } assert_eq!(btreeset.len(), bitset.len()); let mut bitset_docset = BitSetDocSet::from(bitset); + let mut remaining = true; for el in btreeset.into_iter() { - bitset_docset.advance(); + assert!(remaining); assert_eq!(bitset_docset.doc(), el); + remaining = bitset_docset.advance() != TERMINATED; } - assert!(!bitset_docset.advance()); + assert!(!remaining); } #[test] diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 8aa808a8e..f9fb9da4d 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -1,11 +1,8 @@ use crate::collector::Collector; -use crate::collector::SegmentCollector; use crate::core::Executor; use crate::core::InvertedIndexReader; use crate::core::SegmentReader; use crate::query::Query; -use crate::query::Scorer; -use crate::query::Weight; use crate::schema::Document; use crate::schema::Schema; use crate::schema::{Field, Term}; @@ -17,26 +14,6 @@ use crate::Index; use std::fmt; use std::sync::Arc; -fn collect_segment( - collector: &C, - weight: &dyn Weight, - segment_ord: u32, - segment_reader: &SegmentReader, -) -> crate::Result { - let mut scorer = weight.scorer(segment_reader, 1.0f32)?; - let mut segment_collector = collector.for_segment(segment_ord as u32, segment_reader)?; - if let Some(delete_bitset) = segment_reader.delete_bitset() { - scorer.for_each(&mut |doc, score| { - if delete_bitset.is_alive(doc) { - segment_collector.collect(doc, score); - } - }); - } else { - scorer.for_each(&mut |doc, score| segment_collector.collect(doc, score)); - } - Ok(segment_collector.harvest()) -} - /// Holds a list of `SegmentReader`s ready for search. /// /// It guarantees that the `Segment` will not be removed before @@ -163,12 +140,8 @@ impl Searcher { let segment_readers = self.segment_readers(); let fruits = executor.map( |(segment_ord, segment_reader)| { - collect_segment( - collector, - weight.as_ref(), - segment_ord as u32, - segment_reader, - ) + let mut scorer = weight.scorer(segment_reader, 1.0f32)?; + collector.collect_segment(scorer.as_mut(), segment_ord as u32, segment_reader) }, segment_readers.iter().enumerate(), )?; diff --git a/src/docset.rs b/src/docset.rs index f72e10225..cab25ebe1 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -1,58 +1,47 @@ -use crate::common::BitSet; use crate::fastfield::DeleteBitSet; use crate::DocId; use std::borrow::Borrow; use std::borrow::BorrowMut; -use std::cmp::Ordering; -/// Expresses the outcome of a call to `DocSet`'s `.skip_next(...)`. -#[derive(PartialEq, Eq, Debug)] -pub enum SkipResult { - /// target was in the docset - Reached, - /// target was not in the docset, skipping stopped as a greater element was found - OverStep, - /// the docset was entirely consumed without finding the target, nor any - /// element greater than the target. - End, -} +/// Sentinel value returned when a DocSet has been entirely consumed. +/// +/// This is not u32::MAX as one would have expected, due to the lack of SSE2 instructions +/// to compare [u32; 4]. +pub const TERMINATED: DocId = i32::MAX as u32; /// Represents an iterable set of sorted doc ids. pub trait DocSet { /// Goes to the next element. - /// `.advance(...)` needs to be called a first time to point to the correct - /// element. - fn advance(&mut self) -> bool; + /// + /// The DocId of the next element is returned. + /// In other words we should always have : + /// ```ignore + /// let doc = docset.advance(); + /// assert_eq!(doc, docset.doc()); + /// ``` + /// + /// If we reached the end of the DocSet, TERMINATED should be returned. + /// + /// Calling `.advance()` on a terminated DocSet should be supported, and TERMINATED should + /// be returned. + /// TODO Test existing docsets. + fn advance(&mut self) -> DocId; - /// After skipping, position the iterator in such a way that `.doc()` - /// will return a value greater than or equal to target. + /// Advances the DocSet forward until reaching the target, or going to the + /// lowest DocId greater than the target. /// - /// SkipResult expresses whether the `target value` was reached, overstepped, - /// or if the `DocSet` was entirely consumed without finding any value - /// greater or equal to the `target`. + /// If the end of the DocSet is reached, TERMINATED is returned. /// - /// WARNING: Calling skip always advances the docset. - /// More specifically, if the docset is already positionned on the target - /// skipping will advance to the next position and return SkipResult::Overstep. + /// Calling `.seek(target)` on a terminated DocSet is legal. Implementation + /// of DocSet should support it. /// - /// If `.skip_next()` oversteps, then the docset must be positionned correctly - /// on an existing document. In other words, `.doc()` should return the first document - /// greater than `DocId`. - fn skip_next(&mut self, target: DocId) -> SkipResult { - if !self.advance() { - return SkipResult::End; - } - loop { - match self.doc().cmp(&target) { - Ordering::Less => { - if !self.advance() { - return SkipResult::End; - } - } - Ordering::Equal => return SkipResult::Reached, - Ordering::Greater => return SkipResult::OverStep, - } + /// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a DocSet. + fn seek(&mut self, target: DocId) -> DocId { + let mut doc = self.doc(); + while doc < target { + doc = self.advance(); } + doc } /// Fills a given mutable buffer with the next doc ids from the @@ -71,38 +60,38 @@ pub trait DocSet { /// use case where batching. The normal way to /// go through the `DocId`'s is to call `.advance()`. fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize { + if self.doc() == TERMINATED { + return 0; + } for (i, buffer_val) in buffer.iter_mut().enumerate() { - if self.advance() { - *buffer_val = self.doc(); - } else { - return i; + *buffer_val = self.doc(); + if self.advance() == TERMINATED { + return i + 1; } } buffer.len() } /// Returns the current document + /// Right after creating a new DocSet, the docset points to the first document. + /// + /// If the DocSet is empty, .doc() should return `TERMINATED`. fn doc(&self) -> DocId; /// Returns a best-effort hint of the /// length of the docset. fn size_hint(&self) -> u32; - /// Appends all docs to a `bitset`. - fn append_to_bitset(&mut self, bitset: &mut BitSet) { - while self.advance() { - bitset.insert(self.doc()); - } - } - /// Returns the number documents matching. /// Calling this method consumes the `DocSet`. fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 { let mut count = 0u32; - while self.advance() { - if !delete_bitset.is_deleted(self.doc()) { + let mut doc = self.doc(); + while doc != TERMINATED { + if !delete_bitset.is_deleted(doc) { count += 1u32; } + doc = self.advance(); } count } @@ -114,22 +103,42 @@ pub trait DocSet { /// given by `count()`. fn count_including_deleted(&mut self) -> u32 { let mut count = 0u32; - while self.advance() { + let mut doc = self.doc(); + while doc != TERMINATED { count += 1u32; + doc = self.advance(); } count } } +impl<'a> DocSet for &'a mut dyn DocSet { + fn advance(&mut self) -> u32 { + (**self).advance() + } + + fn seek(&mut self, target: DocId) -> DocId { + (**self).seek(target) + } + + fn doc(&self) -> u32 { + (**self).doc() + } + + fn size_hint(&self) -> u32 { + (**self).size_hint() + } +} + impl DocSet for Box { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.advance() } - fn skip_next(&mut self, target: DocId) -> SkipResult { + fn seek(&mut self, target: DocId) -> DocId { let unboxed: &mut TDocSet = self.borrow_mut(); - unboxed.skip_next(target) + unboxed.seek(target) } fn doc(&self) -> DocId { @@ -151,9 +160,4 @@ impl DocSet for Box { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.count_including_deleted() } - - fn append_to_bitset(&mut self, bitset: &mut BitSet) { - let unboxed: &mut TDocSet = self.borrow_mut(); - unboxed.append_to_bitset(bitset); - } } diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs index a1e5eaa13..a0eb36787 100644 --- a/src/indexer/index_writer.rs +++ b/src/indexer/index_writer.rs @@ -10,7 +10,7 @@ use crate::core::SegmentMeta; use crate::core::SegmentReader; use crate::directory::TerminatingWrite; use crate::directory::{DirectoryLock, GarbageCollectionResult}; -use crate::docset::DocSet; +use crate::docset::{DocSet, TERMINATED}; use crate::error::TantivyError; use crate::fastfield::write_delete_bitset; use crate::indexer::delete_queue::{DeleteCursor, DeleteQueue}; @@ -112,15 +112,15 @@ fn compute_deleted_bitset( if let Some(mut docset) = inverted_index.read_postings(&delete_op.term, IndexRecordOption::Basic) { - while docset.advance() { - let deleted_doc = docset.doc(); + let mut deleted_doc = docset.doc(); + while deleted_doc != TERMINATED { if deleted_doc < limit_doc { delete_bitset.insert(deleted_doc); might_have_changed = true; } + deleted_doc = docset.advance(); } } - delete_cursor.advance(); } Ok(might_have_changed) diff --git a/src/indexer/merger.rs b/src/indexer/merger.rs index 3279d185f..0bdea7047 100644 --- a/src/indexer/merger.rs +++ b/src/indexer/merger.rs @@ -2,7 +2,7 @@ use crate::common::MAX_DOC_LIMIT; use crate::core::Segment; use crate::core::SegmentReader; use crate::core::SerializableSegment; -use crate::docset::DocSet; +use crate::docset::{DocSet, TERMINATED}; use crate::fastfield::BytesFastFieldReader; use crate::fastfield::DeleteBitSet; use crate::fastfield::FastFieldReader; @@ -574,10 +574,12 @@ impl IndexMerger { let inverted_index = segment_reader.inverted_index(indexed_field); let mut segment_postings = inverted_index .read_postings_from_terminfo(term_info, segment_postings_option); - while segment_postings.advance() { - if !segment_reader.is_deleted(segment_postings.doc()) { + let mut doc = segment_postings.doc(); + while doc != TERMINATED { + if !segment_reader.is_deleted(doc) { return Some((segment_ord, segment_postings)); } + doc = segment_postings.advance(); } None }) @@ -604,17 +606,9 @@ impl IndexMerger { // postings serializer. for (segment_ord, mut segment_postings) in segment_postings { let old_to_new_doc_id = &merged_doc_id_map[segment_ord]; - loop { - let doc = segment_postings.doc(); - - // `.advance()` has been called once before the loop. - // - // It was required to make sure we only consider segments - // that effectively contain at least one non-deleted document - // and remove terms that do not have documents associated. - // - // For this reason, we cannot use a `while segment_postings.advance()` loop. + let mut doc = segment_postings.doc(); + while doc != TERMINATED { // deleted doc are skipped as they do not have a `remapped_doc_id`. if let Some(remapped_doc_id) = old_to_new_doc_id[doc as usize] { // we make sure to only write the term iff @@ -629,9 +623,8 @@ impl IndexMerger { delta_positions, )?; } - if !segment_postings.advance() { - break; - } + + doc = segment_postings.advance(); } } diff --git a/src/lib.rs b/src/lib.rs index 55d71bd5f..7367050ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -156,7 +156,7 @@ mod snippet; pub use self::snippet::{Snippet, SnippetGenerator}; mod docset; -pub use self::docset::{DocSet, SkipResult}; +pub use self::docset::{DocSet, TERMINATED}; pub use crate::common::{f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64}; pub use crate::core::{Executor, SegmentComponent}; pub use crate::core::{Index, IndexMeta, Searcher, Segment, SegmentId, SegmentMeta}; @@ -285,7 +285,7 @@ mod tests { use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::core::SegmentReader; - use crate::docset::DocSet; + use crate::docset::{DocSet, TERMINATED}; use crate::query::BooleanQuery; use crate::schema::*; use crate::DocAddress; @@ -381,19 +381,12 @@ mod tests { index_writer.commit().unwrap(); } { - { - let doc = doc!(text_field=>"a"); - index_writer.add_document(doc); - } - { - let doc = doc!(text_field=>"a a"); - index_writer.add_document(doc); - } + index_writer.add_document(doc!(text_field=>"a")); + index_writer.add_document(doc!(text_field=>"a a")); index_writer.commit().unwrap(); } { - let doc = doc!(text_field=>"c"); - index_writer.add_document(doc); + index_writer.add_document(doc!(text_field=>"c")); index_writer.commit().unwrap(); } { @@ -472,10 +465,12 @@ mod tests { } fn advance_undeleted(docset: &mut dyn DocSet, reader: &SegmentReader) -> bool { - while docset.advance() { - if !reader.is_deleted(docset.doc()) { + let mut doc = docset.advance(); + while doc != TERMINATED { + if !reader.is_deleted(doc) { return true; } + doc = docset.advance(); } false } @@ -641,9 +636,8 @@ mod tests { .inverted_index(term.field()) .read_postings(&term, IndexRecordOption::Basic) .unwrap(); - assert!(postings.advance()); assert_eq!(postings.doc(), 0); - assert!(!postings.advance()); + assert_eq!(postings.advance(), TERMINATED); } #[test] @@ -665,9 +659,8 @@ mod tests { .inverted_index(term.field()) .read_postings(&term, IndexRecordOption::Basic) .unwrap(); - assert!(postings.advance()); assert_eq!(postings.doc(), 0); - assert!(!postings.advance()); + assert_eq!(postings.advance(), TERMINATED); } #[test] @@ -689,9 +682,8 @@ mod tests { .inverted_index(term.field()) .read_postings(&term, IndexRecordOption::Basic) .unwrap(); - assert!(postings.advance()); assert_eq!(postings.doc(), 0); - assert!(!postings.advance()); + assert_eq!(postings.advance(), TERMINATED); } #[test] @@ -760,10 +752,8 @@ mod tests { { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); - { - let doc = doc!(text_field=>"af af af bc bc"); - index_writer.add_document(doc); - } + let doc = doc!(text_field=>"af af af bc bc"); + index_writer.add_document(doc); index_writer.commit().unwrap(); } { @@ -779,10 +769,9 @@ mod tests { let mut postings = inverted_index .read_postings(&term_af, IndexRecordOption::WithFreqsAndPositions) .unwrap(); - assert!(postings.advance()); assert_eq!(postings.doc(), 0); assert_eq!(postings.term_freq(), 3); - assert!(!postings.advance()); + assert_eq!(postings.advance(), TERMINATED); } } diff --git a/src/postings/block_search.rs b/src/postings/block_search.rs index b5b3fca4e..010c9dd70 100644 --- a/src/postings/block_search.rs +++ b/src/postings/block_search.rs @@ -129,23 +129,23 @@ impl BlockSearcher { /// /// If SSE2 instructions are available in the `(platform, running CPU)`, /// then we use a different implementation that does an exhaustive linear search over - /// the full block whenever the block is full (`len == 128`). It is surprisingly faster, most likely because of the lack - /// of branch. + /// the block regardless of whether the block is full or not. + /// + /// Indeed, if the block is not full, the remaining items are TERMINATED. + /// It is surprisingly faster, most likely because of the lack of branch misprediction. pub(crate) fn search_in_block( self, block_docs: &AlignedBuffer, - len: usize, start: usize, target: u32, ) -> usize { #[cfg(target_arch = "x86_64")] { - use crate::postings::compression::COMPRESSION_BLOCK_SIZE; - if self == BlockSearcher::SSE2 && len == COMPRESSION_BLOCK_SIZE { + if self == BlockSearcher::SSE2 { return sse2::linear_search_sse2_128(block_docs, target); } } - start + galloping(&block_docs.0[start..len], target) + start + galloping(&block_docs.0[start..], target) } } @@ -166,6 +166,7 @@ mod tests { use super::exponential_search; use super::linear_search; use super::BlockSearcher; + use crate::docset::TERMINATED; use crate::postings::compression::{AlignedBuffer, COMPRESSION_BLOCK_SIZE}; #[test] @@ -196,16 +197,11 @@ mod tests { fn util_test_search_in_block(block_searcher: BlockSearcher, block: &[u32], target: u32) { let cursor = search_in_block_trivial_but_slow(block, target); assert!(block.len() < COMPRESSION_BLOCK_SIZE); - let mut output_buffer = [u32::max_value(); COMPRESSION_BLOCK_SIZE]; + let mut output_buffer = [TERMINATED; COMPRESSION_BLOCK_SIZE]; output_buffer[..block.len()].copy_from_slice(block); for i in 0..cursor { assert_eq!( - block_searcher.search_in_block( - &AlignedBuffer(output_buffer), - block.len(), - i, - target - ), + block_searcher.search_in_block(&AlignedBuffer(output_buffer), i, target), cursor ); } diff --git a/src/postings/compression/mod.rs b/src/postings/compression/mod.rs index 342ba80cc..00648d3f0 100644 --- a/src/postings/compression/mod.rs +++ b/src/postings/compression/mod.rs @@ -1,4 +1,5 @@ use crate::common::FixedSize; +use crate::docset::TERMINATED; use bitpacking::{BitPacker, BitPacker4x}; pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN; @@ -90,14 +91,18 @@ impl BlockDecoder { } #[inline] - pub(crate) fn output_aligned(&self) -> (&AlignedBuffer, usize) { - (&self.output, self.output_len) + pub(crate) fn output_aligned(&self) -> &AlignedBuffer { + &self.output } #[inline] pub fn output(&self, idx: usize) -> u32 { self.output.0[idx] } + + pub fn clear(&mut self) { + self.output.0.iter_mut().for_each(|el| *el = TERMINATED); + } } pub trait VIntEncoder { @@ -134,9 +139,9 @@ pub trait VIntDecoder { /// For instance, if delta encoded are `1, 3, 9`, and the /// `offset` is 5, then the output will be: /// `5 + 1 = 6, 6 + 3= 9, 9 + 9 = 18` - fn uncompress_vint_sorted<'a>( + fn uncompress_vint_sorted( &mut self, - compressed_data: &'a [u8], + compressed_data: &[u8], offset: u32, num_els: usize, ) -> usize; @@ -146,7 +151,7 @@ pub trait VIntDecoder { /// /// The method takes a number of int to decompress, and returns /// the amount of bytes that were read to decompress them. - fn uncompress_vint_unsorted<'a>(&mut self, compressed_data: &'a [u8], num_els: usize) -> usize; + fn uncompress_vint_unsorted(&mut self, compressed_data: &[u8], num_els: usize) -> usize; } impl VIntEncoder for BlockEncoder { @@ -160,9 +165,9 @@ impl VIntEncoder for BlockEncoder { } impl VIntDecoder for BlockDecoder { - fn uncompress_vint_sorted<'a>( + fn uncompress_vint_sorted( &mut self, - compressed_data: &'a [u8], + compressed_data: &[u8], offset: u32, num_els: usize, ) -> usize { diff --git a/src/postings/compression/vint.rs b/src/postings/compression/vint.rs index 87a672e64..3de43749f 100644 --- a/src/postings/compression/vint.rs +++ b/src/postings/compression/vint.rs @@ -42,7 +42,7 @@ pub(crate) fn compress_unsorted<'a>(input: &[u32], output: &'a mut [u8]) -> &'a } #[inline(always)] -pub fn uncompress_sorted<'a>(compressed_data: &'a [u8], output: &mut [u32], offset: u32) -> usize { +pub fn uncompress_sorted(compressed_data: &[u8], output: &mut [u32], offset: u32) -> usize { let mut read_byte = 0; let mut result = offset; for output_mut in output.iter_mut() { diff --git a/src/postings/mod.rs b/src/postings/mod.rs index b66beb413..a01cadfc7 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -51,7 +51,7 @@ pub mod tests { use crate::core::Index; use crate::core::SegmentComponent; use crate::core::SegmentReader; - use crate::docset::{DocSet, SkipResult}; + use crate::docset::{DocSet, TERMINATED}; use crate::fieldnorm::FieldNormReader; use crate::indexer::operation::AddOperation; use crate::indexer::SegmentWriter; @@ -115,29 +115,12 @@ pub mod tests { let mut postings = inverted_index .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .unwrap(); - postings.advance(); + assert_eq!(postings.doc(), 0); postings.positions(&mut positions); assert_eq!(&[0, 1, 2], &positions[..]); postings.positions(&mut positions); assert_eq!(&[0, 1, 2], &positions[..]); - postings.advance(); - postings.positions(&mut positions); - assert_eq!(&[0, 5], &positions[..]); - } - { - let mut postings = inverted_index - .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) - .unwrap(); - postings.advance(); - postings.advance(); - postings.positions(&mut positions); - assert_eq!(&[0, 5], &positions[..]); - } - { - let mut postings = inverted_index - .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) - .unwrap(); - assert_eq!(postings.skip_next(1), SkipResult::Reached); + assert_eq!(postings.advance(), 1); assert_eq!(postings.doc(), 1); postings.positions(&mut positions); assert_eq!(&[0, 5], &positions[..]); @@ -146,7 +129,25 @@ pub mod tests { let mut postings = inverted_index .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .unwrap(); - assert_eq!(postings.skip_next(1002), SkipResult::Reached); + assert_eq!(postings.doc(), 0); + assert_eq!(postings.advance(), 1); + postings.positions(&mut positions); + assert_eq!(&[0, 5], &positions[..]); + } + { + let mut postings = inverted_index + .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) + .unwrap(); + assert_eq!(postings.seek(1), 1); + assert_eq!(postings.doc(), 1); + postings.positions(&mut positions); + assert_eq!(&[0, 5], &positions[..]); + } + { + let mut postings = inverted_index + .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) + .unwrap(); + assert_eq!(postings.seek(1002), 1002); assert_eq!(postings.doc(), 1002); postings.positions(&mut positions); assert_eq!(&[0, 5], &positions[..]); @@ -155,8 +156,8 @@ pub mod tests { let mut postings = inverted_index .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .unwrap(); - assert_eq!(postings.skip_next(100), SkipResult::Reached); - assert_eq!(postings.skip_next(1002), SkipResult::Reached); + assert_eq!(postings.seek(100), 100); + assert_eq!(postings.seek(1002), 1002); assert_eq!(postings.doc(), 1002); postings.positions(&mut positions); assert_eq!(&[0, 5], &positions[..]); @@ -281,22 +282,21 @@ pub mod tests { .read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions) .unwrap(); assert_eq!(postings_a.len(), 1000); - assert!(postings_a.advance()); assert_eq!(postings_a.doc(), 0); assert_eq!(postings_a.term_freq(), 6); postings_a.positions(&mut positions); assert_eq!(&positions[..], [0, 2, 4, 6, 7, 13]); - assert!(postings_a.advance()); + assert_eq!(postings_a.advance(), 1u32); assert_eq!(postings_a.doc(), 1u32); assert_eq!(postings_a.term_freq(), 1); for i in 2u32..1000u32 { - assert!(postings_a.advance()); + assert_eq!(postings_a.advance(), i); assert_eq!(postings_a.term_freq(), 1); postings_a.positions(&mut positions); assert_eq!(&positions[..], [i]); assert_eq!(postings_a.doc(), i); } - assert!(!postings_a.advance()); + assert_eq!(postings_a.advance(), TERMINATED); } { let term_e = Term::from_field_text(text_field, "e"); @@ -306,7 +306,6 @@ pub mod tests { .unwrap(); assert_eq!(postings_e.len(), 1000 - 2); for i in 2u32..1000u32 { - assert!(postings_e.advance()); assert_eq!(postings_e.term_freq(), i); postings_e.positions(&mut positions); assert_eq!(positions.len(), i as usize); @@ -314,8 +313,9 @@ pub mod tests { assert_eq!(positions[j], (j as u32)); } assert_eq!(postings_e.doc(), i); + postings_e.advance(); } - assert!(!postings_e.advance()); + assert_eq!(postings_e.doc(), TERMINATED); } } } @@ -329,16 +329,8 @@ pub mod tests { let index = Index::create_in_ram(schema); { let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); - { - let mut doc = Document::default(); - doc.add_text(text_field, "g b b d c g c"); - index_writer.add_document(doc); - } - { - let mut doc = Document::default(); - doc.add_text(text_field, "g a b b a d c g c"); - index_writer.add_document(doc); - } + index_writer.add_document(doc!(text_field => "g b b d c g c")); + index_writer.add_document(doc!(text_field => "g a b b a d c g c")); assert!(index_writer.commit().is_ok()); } let term_a = Term::from_field_text(text_field, "a"); @@ -348,7 +340,6 @@ pub mod tests { .inverted_index(text_field) .read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions) .unwrap(); - assert!(postings.advance()); assert_eq!(postings.doc(), 1u32); postings.positions(&mut positions); assert_eq!(&positions[..], &[1u32, 4]); @@ -370,11 +361,8 @@ pub mod tests { let index = Index::create_in_ram(schema); { let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); - for i in 0..num_docs { - let mut doc = Document::default(); - doc.add_u64(value_field, 2); - doc.add_u64(value_field, (i % 2) as u64); - + for i in 0u64..num_docs as u64 { + let doc = doc!(value_field => 2u64, value_field => i % 2u64); index_writer.add_document(doc); } assert!(index_writer.commit().is_ok()); @@ -391,11 +379,10 @@ pub mod tests { .inverted_index(term_2.field()) .read_postings(&term_2, IndexRecordOption::Basic) .unwrap(); - - assert_eq!(segment_postings.skip_next(i), SkipResult::Reached); + assert_eq!(segment_postings.seek(i), i); assert_eq!(segment_postings.doc(), i); - assert_eq!(segment_postings.skip_next(j), SkipResult::Reached); + assert_eq!(segment_postings.seek(j), j); assert_eq!(segment_postings.doc(), j); } } @@ -407,17 +394,16 @@ pub mod tests { .unwrap(); // check that `skip_next` advances the iterator - assert!(segment_postings.advance()); assert_eq!(segment_postings.doc(), 0); - assert_eq!(segment_postings.skip_next(1), SkipResult::Reached); + assert_eq!(segment_postings.seek(1), 1); assert_eq!(segment_postings.doc(), 1); - assert_eq!(segment_postings.skip_next(1), SkipResult::OverStep); - assert_eq!(segment_postings.doc(), 2); + assert_eq!(segment_postings.seek(1), 1); + assert_eq!(segment_postings.doc(), 1); // check that going beyond the end is handled - assert_eq!(segment_postings.skip_next(num_docs), SkipResult::End); + assert_eq!(segment_postings.seek(num_docs), TERMINATED); } // check that filtering works @@ -428,7 +414,7 @@ pub mod tests { .unwrap(); for i in 0..num_docs / 2 { - assert_eq!(segment_postings.skip_next(i * 2), SkipResult::Reached); + assert_eq!(segment_postings.seek(i * 2), i * 2); assert_eq!(segment_postings.doc(), i * 2); } @@ -438,7 +424,7 @@ pub mod tests { .unwrap(); for i in 0..num_docs / 2 - 1 { - assert_eq!(segment_postings.skip_next(i * 2 + 1), SkipResult::OverStep); + assert!(segment_postings.seek(i * 2 + 1) > (i * 1) * 2); assert_eq!(segment_postings.doc(), (i + 1) * 2); } } @@ -450,6 +436,7 @@ pub mod tests { assert!(index_writer.commit().is_ok()); } let searcher = index.reader().unwrap().searcher(); + assert_eq!(searcher.segment_readers().len(), 1); let segment_reader = searcher.segment_reader(0); // make sure seeking still works @@ -460,11 +447,11 @@ pub mod tests { .unwrap(); if i % 2 == 0 { - assert_eq!(segment_postings.skip_next(i), SkipResult::Reached); + assert_eq!(segment_postings.seek(i), i); assert_eq!(segment_postings.doc(), i); assert!(segment_reader.is_deleted(i)); } else { - assert_eq!(segment_postings.skip_next(i), SkipResult::Reached); + assert_eq!(segment_postings.seek(i), i); assert_eq!(segment_postings.doc(), i); } } @@ -479,12 +466,16 @@ pub mod tests { let mut last = 2; // start from 5 to avoid seeking to 3 twice let mut cur = 3; loop { - match segment_postings.skip_next(cur) { - SkipResult::End => break, - SkipResult::Reached => assert_eq!(segment_postings.doc(), cur), - SkipResult::OverStep => assert_eq!(segment_postings.doc(), cur + 1), + let seek = segment_postings.seek(cur); + if seek == TERMINATED { + break; + } + assert_eq!(seek, segment_postings.doc()); + if seek == cur { + assert_eq!(segment_postings.doc(), cur); + } else { + assert_eq!(segment_postings.doc(), cur + 1); } - let next = cur + last; last = cur; cur = next; @@ -570,7 +561,7 @@ pub mod tests { } impl DocSet for UnoptimizedDocSet { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { self.0.advance() } @@ -596,30 +587,22 @@ pub mod tests { for target in targets { let mut postings_opt = postings_factory(); let mut postings_unopt = UnoptimizedDocSet::wrap(postings_factory()); - let skip_result_opt = postings_opt.skip_next(target); - let skip_result_unopt = postings_unopt.skip_next(target); + let skip_result_opt = postings_opt.seek(target); + let skip_result_unopt = postings_unopt.seek(target); assert_eq!( skip_result_unopt, skip_result_opt, "Failed while skipping to {}", target ); - match skip_result_opt { - SkipResult::Reached => assert_eq!(postings_opt.doc(), target), - SkipResult::OverStep => assert!(postings_opt.doc() > target), - SkipResult::End => { - return; - } + assert!(skip_result_opt >= target); + assert_eq!(skip_result_opt, postings_opt.doc()); + if skip_result_opt == TERMINATED { + return; } - while postings_opt.advance() { - assert!(postings_unopt.advance()); - assert_eq!( - postings_opt.doc(), - postings_unopt.doc(), - "Failed while skipping to {}", - target - ); + while postings_opt.doc() != TERMINATED { + assert_eq!(postings_opt.doc(), postings_unopt.doc()); + assert_eq!(postings_opt.advance(), postings_unopt.advance()); } - assert!(!postings_unopt.advance()); } } } @@ -628,7 +611,7 @@ pub mod tests { mod bench { use super::tests::*; - use crate::docset::SkipResult; + use crate::docset::TERMINATED; use crate::query::Intersection; use crate::schema::IndexRecordOption; use crate::tests; @@ -646,7 +629,7 @@ mod bench { .inverted_index(TERM_A.field()) .read_postings(&*TERM_A, IndexRecordOption::Basic) .unwrap(); - while segment_postings.advance() {} + while segment_postings.advance() != TERMINATED {} }); } @@ -678,7 +661,7 @@ mod bench { segment_postings_c, segment_postings_d, ]); - while intersection.advance() {} + while intersection.advance() != TERMINATED {} }); } @@ -694,11 +677,10 @@ mod bench { .unwrap(); let mut existing_docs = Vec::new(); - segment_postings.advance(); for doc in &docs { if *doc >= segment_postings.doc() { existing_docs.push(*doc); - if segment_postings.skip_next(*doc) == SkipResult::End { + if segment_postings.seek(*doc) == TERMINATED { break; } } @@ -710,7 +692,7 @@ mod bench { .read_postings(&*TERM_A, IndexRecordOption::Basic) .unwrap(); for doc in &existing_docs { - if segment_postings.skip_next(*doc) == SkipResult::End { + if segment_postings.seek(*doc) == TERMINATED { break; } } @@ -749,8 +731,9 @@ mod bench { .read_postings(&*TERM_A, IndexRecordOption::Basic) .unwrap(); let mut s = 0u32; - while segment_postings.advance() { + while segment_postings.doc() != TERMINATED { s += (segment_postings.doc() & n) % 1024; + segment_postings.advance() } s }); diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index f70e5f429..b1c1ed75c 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -1,7 +1,6 @@ -use crate::common::BitSet; use crate::common::HasLen; use crate::common::{BinarySerializable, VInt}; -use crate::docset::{DocSet, SkipResult}; +use crate::docset::{DocSet, TERMINATED}; use crate::positions::PositionReader; use crate::postings::compression::{compressed_block_size, AlignedBuffer}; use crate::postings::compression::{BlockDecoder, VIntDecoder, COMPRESSION_BLOCK_SIZE}; @@ -14,7 +13,6 @@ use crate::postings::USE_SKIP_INFO_LIMIT; use crate::schema::IndexRecordOption; use crate::DocId; use owned_read::OwnedRead; -use std::cmp::Ordering; use tantivy_fst::Streamer; struct PositionComputer { @@ -68,12 +66,14 @@ impl SegmentPostings { /// Returns an empty segment postings object pub fn empty() -> Self { let empty_block_cursor = BlockSegmentPostings::empty(); - SegmentPostings { + let mut segment_postings = SegmentPostings { block_cursor: empty_block_cursor, cur: COMPRESSION_BLOCK_SIZE, position_computer: None, block_searcher: BlockSearcher::default(), - } + }; + segment_postings.advance(); + segment_postings } /// Creates a segment postings object with the given documents @@ -116,12 +116,14 @@ impl SegmentPostings { segment_block_postings: BlockSegmentPostings, positions_stream_opt: Option, ) -> SegmentPostings { - SegmentPostings { + let mut postings = SegmentPostings { block_cursor: segment_block_postings, cur: COMPRESSION_BLOCK_SIZE, // cursor within the block position_computer: positions_stream_opt.map(PositionComputer::new), block_searcher: BlockSearcher::default(), - } + }; + postings.advance(); + postings } } @@ -129,7 +131,7 @@ impl DocSet for SegmentPostings { // goes to the next element. // next needs to be called a first time to point to the correct element. #[inline] - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { if self.position_computer.is_some() && self.cur < COMPRESSION_BLOCK_SIZE { let term_freq = self.term_freq() as usize; if let Some(position_computer) = self.position_computer.as_mut() { @@ -138,29 +140,19 @@ impl DocSet for SegmentPostings { } self.cur += 1; if self.cur >= self.block_cursor.block_len() { - self.cur = 0; - if !self.block_cursor.advance() { - self.cur = COMPRESSION_BLOCK_SIZE; - return false; + if self.block_cursor.advance() { + self.cur = 0; + } else { + self.cur = COMPRESSION_BLOCK_SIZE - 1; + return TERMINATED; } } - true + self.doc() } - fn skip_next(&mut self, target: DocId) -> SkipResult { - if !self.advance() { - return SkipResult::End; - } - match self.doc().cmp(&target) { - Ordering::Equal => { - return SkipResult::Reached; - } - Ordering::Greater => { - return SkipResult::OverStep; - } - _ => { - // ... - } + fn seek(&mut self, target: DocId) -> DocId { + if self.doc() >= target { + return self.doc(); } // In the following, thanks to the call to advance above, @@ -170,44 +162,44 @@ impl DocSet for SegmentPostings { // skip blocks until one that might contain the target // check if we need to go to the next block let mut sum_freqs_skipped: u32 = 0; - if !self + if self .block_cursor .docs() .last() - .map(|doc| *doc >= target) - .unwrap_or(false) - // there should always be at least a document in the block - // since advance returned. + .map(|&doc| doc < target) + .unwrap_or(true) { - // we are not in the right block. - // - // First compute all of the freqs skipped from the current block. + // We are not in the right block. if self.position_computer.is_some() { - sum_freqs_skipped = self.block_cursor.freqs()[self.cur..].iter().sum(); + // First compute all of the freqs skipped from the current block. + sum_freqs_skipped = self.block_cursor.freqs()[self.cur..].iter().sum::(); match self.block_cursor.skip_to(target) { BlockSegmentPostingsSkipResult::Success(block_skip_freqs) => { sum_freqs_skipped += block_skip_freqs; } BlockSegmentPostingsSkipResult::Terminated => { - return SkipResult::End; + self.block_cursor.doc_decoder.clear(); + self.cur = 0; + return TERMINATED; } } } else if self.block_cursor.skip_to(target) == BlockSegmentPostingsSkipResult::Terminated { // no positions needed. no need to sum freqs. - return SkipResult::End; + self.block_cursor.doc_decoder.clear(); + self.cur = 0; + return TERMINATED; } self.cur = 0; } + // At this point we are on the block, that might contain our document. + let cur = self.cur; - // we're in the right block now, start with an exponential search - let (output, len) = self.block_cursor.docs_aligned(); - let new_cur = self - .block_searcher - .search_in_block(&output, len, cur, target); + let output = self.block_cursor.docs_aligned(); + let new_cur = self.block_searcher.search_in_block(&output, cur, target); if let Some(position_computer) = self.position_computer.as_mut() { sum_freqs_skipped += self.block_cursor.freqs()[cur..new_cur].iter().sum::(); position_computer.add_skip(sum_freqs_skipped as usize); @@ -217,11 +209,7 @@ impl DocSet for SegmentPostings { // `doc` is now the first element >= `target` let doc = output.0[new_cur]; debug_assert!(doc >= target); - if doc == target { - SkipResult::Reached - } else { - SkipResult::OverStep - } + doc } /// Return the current document's `DocId`. @@ -231,18 +219,14 @@ impl DocSet for SegmentPostings { /// Will panics if called without having called advance before. #[inline] fn doc(&self) -> DocId { - let docs = self.block_cursor.docs(); - debug_assert!( - self.cur < docs.len(), - "Have you forgotten to call `.advance()` at least once before calling `.doc()` ." - ); - docs[self.cur] + self.block_cursor.doc(self.cur) } fn size_hint(&self) -> u32 { self.len() as u32 } + /* fn append_to_bitset(&mut self, bitset: &mut BitSet) { // finish the current block if self.advance() { @@ -257,6 +241,7 @@ impl DocSet for SegmentPostings { } } } + */ } impl HasLen for SegmentPostings { @@ -324,15 +309,14 @@ fn split_into_skips_and_postings( doc_freq: u32, mut data: OwnedRead, ) -> (Option, OwnedRead) { - if doc_freq >= USE_SKIP_INFO_LIMIT { - let skip_len = VInt::deserialize(&mut data).expect("Data corrupted").0 as usize; - let mut postings_data = data.clone(); - postings_data.advance(skip_len); - data.clip(skip_len); - (Some(data), postings_data) - } else { - (None, data) + if doc_freq < USE_SKIP_INFO_LIMIT { + return (None, data); } + let skip_len = VInt::deserialize(&mut data).expect("Data corrupted").0 as usize; + let mut postings_data = data.clone(); + postings_data.advance(skip_len); + data.clip(skip_len); + (Some(data), postings_data) } #[derive(Debug, Eq, PartialEq)] @@ -414,7 +398,7 @@ impl BlockSegmentPostings { self.doc_decoder.output_array() } - pub(crate) fn docs_aligned(&self) -> (&AlignedBuffer, usize) { + pub(crate) fn docs_aligned(&self) -> &AlignedBuffer { self.doc_decoder.output_aligned() } @@ -495,35 +479,36 @@ impl BlockSegmentPostings { } } - // we are now on the last, incomplete, variable encoded block. - if self.num_vint_docs > 0 { - let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted( - self.remaining_data.as_ref(), - self.doc_offset, - self.num_vint_docs, - ); - self.remaining_data.advance(num_compressed_bytes); - match self.freq_reading_option { - FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {} - FreqReadingOption::ReadFreq => { - self.freq_decoder - .uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs); - } - } - self.num_vint_docs = 0; - return self - .docs() - .last() - .map(|last_doc| { - if *last_doc >= target_doc { - BlockSegmentPostingsSkipResult::Success(skip_freqs) - } else { - BlockSegmentPostingsSkipResult::Terminated - } - }) - .unwrap_or(BlockSegmentPostingsSkipResult::Terminated); + self.doc_decoder.clear(); + + if self.num_vint_docs == 0 { + return BlockSegmentPostingsSkipResult::Terminated; } - BlockSegmentPostingsSkipResult::Terminated + // we are now on the last, incomplete, variable encoded block. + let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted( + self.remaining_data.as_ref(), + self.doc_offset, + self.num_vint_docs, + ); + self.remaining_data.advance(num_compressed_bytes); + match self.freq_reading_option { + FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {} + FreqReadingOption::ReadFreq => { + self.freq_decoder + .uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs); + } + } + self.num_vint_docs = 0; + self.docs() + .last() + .map(|last_doc| { + if *last_doc >= target_doc { + BlockSegmentPostingsSkipResult::Success(skip_freqs) + } else { + BlockSegmentPostingsSkipResult::Terminated + } + }) + .unwrap_or(BlockSegmentPostingsSkipResult::Terminated) } /// Advance to the next block. @@ -554,26 +539,27 @@ impl BlockSegmentPostings { } // it will be used as the next offset. self.doc_offset = self.doc_decoder.output(COMPRESSION_BLOCK_SIZE - 1); - true - } else if self.num_vint_docs > 0 { - let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted( - self.remaining_data.as_ref(), - self.doc_offset, - self.num_vint_docs, - ); - self.remaining_data.advance(num_compressed_bytes); - match self.freq_reading_option { - FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {} - FreqReadingOption::ReadFreq => { - self.freq_decoder - .uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs); - } - } - self.num_vint_docs = 0; - true - } else { - false + return true; } + self.doc_decoder.clear(); + if self.num_vint_docs == 0 { + return false; + } + let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted( + self.remaining_data.as_ref(), + self.doc_offset, + self.num_vint_docs, + ); + self.remaining_data.advance(num_compressed_bytes); + match self.freq_reading_option { + FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {} + FreqReadingOption::ReadFreq => { + self.freq_decoder + .uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs); + } + } + self.num_vint_docs = 0; + true } /// Returns an empty segment postings object @@ -613,34 +599,34 @@ mod tests { use super::SegmentPostings; use crate::common::HasLen; use crate::core::Index; - use crate::docset::DocSet; + use crate::docset::{DocSet, TERMINATED}; use crate::postings::postings::Postings; use crate::schema::IndexRecordOption; use crate::schema::Schema; use crate::schema::Term; use crate::schema::INDEXED; use crate::DocId; - use crate::SkipResult; use tantivy_fst::Streamer; #[test] fn test_empty_segment_postings() { let mut postings = SegmentPostings::empty(); - assert!(!postings.advance()); - assert!(!postings.advance()); + assert_eq!(postings.advance(), TERMINATED); + assert_eq!(postings.advance(), TERMINATED); assert_eq!(postings.len(), 0); } #[test] - #[should_panic(expected = "Have you forgotten to call `.advance()`")] - fn test_panic_if_doc_called_before_advance() { - SegmentPostings::empty().doc(); + fn test_empty_postings_doc_returns_terminated() { + let mut postings = SegmentPostings::empty(); + assert_eq!(postings.doc(), TERMINATED); + assert_eq!(postings.advance(), TERMINATED); } #[test] - #[should_panic(expected = "Have you forgotten to call `.advance()`")] - fn test_panic_if_freq_called_before_advance() { - SegmentPostings::empty().term_freq(); + fn test_empty_postings_doc_term_freq_returns_0() { + let postings = SegmentPostings::empty(); + assert_eq!(postings.term_freq(), 1); } #[test] @@ -674,25 +660,27 @@ mod tests { { let block_segments = build_block_postings(&doc_ids); let mut docset = SegmentPostings::from_block_postings(block_segments, None); - assert_eq!(docset.skip_next(128), SkipResult::OverStep); + assert_eq!(docset.seek(128), 129); assert_eq!(docset.doc(), 129); - assert!(docset.advance()); + assert_eq!(docset.advance(), 130); assert_eq!(docset.doc(), 130); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } { let block_segments = build_block_postings(&doc_ids); let mut docset = SegmentPostings::from_block_postings(block_segments, None); - assert_eq!(docset.skip_next(129), SkipResult::Reached); + assert_eq!(docset.seek(129), 129); assert_eq!(docset.doc(), 129); - assert!(docset.advance()); + assert_eq!(docset.advance(), 130); assert_eq!(docset.doc(), 130); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } { let block_segments = build_block_postings(&doc_ids); let mut docset = SegmentPostings::from_block_postings(block_segments, None); - assert_eq!(docset.skip_next(131), SkipResult::End); + assert_eq!(docset.doc(), 0); + assert_eq!(docset.seek(131), TERMINATED); + assert_eq!(docset.doc(), TERMINATED); } } diff --git a/src/query/all_query.rs b/src/query/all_query.rs index fb9380dd8..d0b45105e 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -1,6 +1,6 @@ use crate::core::Searcher; use crate::core::SegmentReader; -use crate::docset::DocSet; +use crate::docset::{DocSet, TERMINATED}; use crate::query::boost_query::BoostScorer; use crate::query::explanation::does_not_match; use crate::query::{Explanation, Query, Scorer, Weight}; @@ -25,7 +25,6 @@ pub struct AllWeight; impl Weight for AllWeight { fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result> { let all_scorer = AllScorer { - state: State::NotStarted, doc: 0u32, max_doc: reader.max_doc(), }; @@ -40,39 +39,20 @@ impl Weight for AllWeight { } } -enum State { - NotStarted, - Started, - Finished, -} - /// Scorer associated to the `AllQuery` query. pub struct AllScorer { - state: State, doc: DocId, max_doc: DocId, } impl DocSet for AllScorer { - fn advance(&mut self) -> bool { - match self.state { - State::NotStarted => { - self.state = State::Started; - self.doc = 0; - } - State::Started => { - self.doc += 1u32; - } - State::Finished => { - return false; - } - } - if self.doc < self.max_doc { - true - } else { - self.state = State::Finished; - false + fn advance(&mut self) -> DocId { + if self.doc + 1 >= self.max_doc { + self.doc = TERMINATED; + return TERMINATED; } + self.doc += 1; + self.doc } fn doc(&self) -> DocId { @@ -93,6 +73,7 @@ impl Scorer for AllScorer { #[cfg(test)] mod tests { use super::AllQuery; + use crate::docset::TERMINATED; use crate::query::Query; use crate::schema::{Schema, TEXT}; use crate::Index; @@ -120,18 +101,16 @@ mod tests { { let reader = searcher.segment_reader(0); let mut scorer = weight.scorer(reader, 1.0f32).unwrap(); - assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); - assert!(scorer.advance()); + assert_eq!(scorer.advance(), 1u32); assert_eq!(scorer.doc(), 1u32); - assert!(!scorer.advance()); + assert_eq!(scorer.advance(), TERMINATED); } { let reader = searcher.segment_reader(1); let mut scorer = weight.scorer(reader, 1.0f32).unwrap(); - assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); - assert!(!scorer.advance()); + assert_eq!(scorer.advance(), TERMINATED); } } @@ -144,13 +123,11 @@ mod tests { let reader = searcher.segment_reader(0); { let mut scorer = weight.scorer(reader, 2.0f32).unwrap(); - assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.score(), 2.0f32); } { let mut scorer = weight.scorer(reader, 1.5f32).unwrap(); - assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.score(), 1.5f32); } diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index 5f315190f..bf5a0f8df 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -6,8 +6,8 @@ use crate::query::{Scorer, Weight}; use crate::schema::{Field, IndexRecordOption}; use crate::termdict::{TermDictionary, TermStreamer}; use crate::DocId; +use crate::Result; use crate::TantivyError; -use crate::{Result, SkipResult}; use std::sync::Arc; use tantivy_fst::Automaton; @@ -64,7 +64,7 @@ where fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { let mut scorer = self.scorer(reader, 1.0f32)?; - if scorer.skip_next(doc) == SkipResult::Reached { + if scorer.seek(doc) == doc { Ok(Explanation::new("AutomatonScorer", 1.0f32)) } else { Err(TantivyError::InvalidArgument( @@ -77,6 +77,7 @@ where #[cfg(test)] mod tests { use super::AutomatonWeight; + use crate::docset::TERMINATED; use crate::query::Weight; use crate::schema::{Schema, STRING}; use crate::Index; @@ -141,13 +142,12 @@ mod tests { let mut scorer = automaton_weight .scorer(searcher.segment_reader(0u32), 1.0f32) .unwrap(); - assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.score(), 1.0f32); - assert!(scorer.advance()); + assert_eq!(scorer.advance(), 2u32); assert_eq!(scorer.doc(), 2u32); assert_eq!(scorer.score(), 1.0f32); - assert!(!scorer.advance()); + assert_eq!(scorer.advance(), TERMINATED); } #[test] @@ -160,7 +160,6 @@ mod tests { let mut scorer = automaton_weight .scorer(searcher.segment_reader(0u32), 1.32f32) .unwrap(); - assert!(scorer.advance()); assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.score(), 1.32f32); } diff --git a/src/query/bitset/mod.rs b/src/query/bitset/mod.rs index b2e3fe84e..cdfea5775 100644 --- a/src/query/bitset/mod.rs +++ b/src/query/bitset/mod.rs @@ -1,7 +1,6 @@ use crate::common::{BitSet, TinySet}; -use crate::docset::{DocSet, SkipResult}; +use crate::docset::{DocSet, TERMINATED}; use crate::DocId; -use std::cmp::Ordering; /// A `BitSetDocSet` makes it possible to iterate through a bitset as if it was a `DocSet`. /// @@ -33,75 +32,50 @@ impl From for BitSetDocSet { } else { docs.tinyset(0) }; - BitSetDocSet { + let mut docset = BitSetDocSet { docs, cursor_bucket: 0, cursor_tinybitset: first_tiny_bitset, doc: 0u32, - } + }; + docset.advance(); + docset } } impl DocSet for BitSetDocSet { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { if let Some(lower) = self.cursor_tinybitset.pop_lowest() { self.doc = (self.cursor_bucket as u32 * 64u32) | lower; - return true; + return self.doc; } if let Some(cursor_bucket) = self.docs.first_non_empty_bucket(self.cursor_bucket + 1) { self.go_to_bucket(cursor_bucket); let lower = self.cursor_tinybitset.pop_lowest().unwrap(); self.doc = (cursor_bucket * 64u32) | lower; - true + self.doc } else { - false + self.doc = TERMINATED; + TERMINATED } } - fn skip_next(&mut self, target: DocId) -> SkipResult { - // skip is required to advance. - if !self.advance() { - return SkipResult::End; - } + fn seek(&mut self, target: DocId) -> DocId { let target_bucket = target / 64u32; // Mask for all of the bits greater or equal // to our target document. - match target_bucket.cmp(&self.cursor_bucket) { - Ordering::Greater => { - self.go_to_bucket(target_bucket); - let greater_filter: TinySet = TinySet::range_greater_or_equal(target); - self.cursor_tinybitset = self.cursor_tinybitset.intersect(greater_filter); - if !self.advance() { - SkipResult::End - } else if self.doc() == target { - SkipResult::Reached - } else { - debug_assert!(self.doc() > target); - SkipResult::OverStep - } - } - Ordering::Equal => loop { - match self.doc().cmp(&target) { - Ordering::Less => { - if !self.advance() { - return SkipResult::End; - } - } - Ordering::Equal => { - return SkipResult::Reached; - } - Ordering::Greater => { - debug_assert!(self.doc() > target); - return SkipResult::OverStep; - } - } - }, - Ordering::Less => { - debug_assert!(self.doc() > target); - SkipResult::OverStep - } + if target_bucket > self.cursor_bucket { + self.go_to_bucket(target_bucket); + let greater_filter: TinySet = TinySet::range_greater_or_equal(target); + self.cursor_tinybitset = self.cursor_tinybitset.intersect(greater_filter); + self.advance(); } + let mut doc = self.doc(); + while doc < target { + doc = self.advance(); + } + doc } /// Returns the current document @@ -122,7 +96,7 @@ impl DocSet for BitSetDocSet { mod tests { use super::BitSetDocSet; use crate::common::BitSet; - use crate::docset::{DocSet, SkipResult}; + use crate::docset::{DocSet, TERMINATED}; use crate::DocId; fn create_docbitset(docs: &[DocId], max_doc: DocId) -> BitSetDocSet { @@ -133,19 +107,24 @@ mod tests { BitSetDocSet::from(docset) } + #[test] + fn test_empty() { + let bitset = BitSet::with_max_value(1000); + let mut empty = BitSetDocSet::from(bitset); + assert_eq!(empty.advance(), TERMINATED) + } + fn test_go_through_sequential(docs: &[DocId]) { let mut docset = create_docbitset(docs, 1_000u32); for &doc in docs { - assert!(docset.advance()); assert_eq!(doc, docset.doc()); + docset.advance(); } - assert!(!docset.advance()); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } #[test] fn test_docbitset_sequential() { - test_go_through_sequential(&[]); test_go_through_sequential(&[1, 2, 3]); test_go_through_sequential(&[1, 2, 3, 4, 5, 63, 64, 65]); test_go_through_sequential(&[63, 64, 65]); @@ -156,64 +135,64 @@ mod tests { fn test_docbitset_skip() { { let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000); - assert_eq!(docset.skip_next(7), SkipResult::Reached); + assert_eq!(docset.seek(7), 7); assert_eq!(docset.doc(), 7); - assert!(docset.advance(), 7); + assert_eq!(docset.advance(), 5112); assert_eq!(docset.doc(), 5112); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } { let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000); - assert_eq!(docset.skip_next(3), SkipResult::OverStep); + assert_eq!(docset.seek(3), 5); assert_eq!(docset.doc(), 5); - assert!(docset.advance()); + assert_eq!(docset.advance(), 6); } { let mut docset = create_docbitset(&[5112], 10_000); - assert_eq!(docset.skip_next(5112), SkipResult::Reached); + assert_eq!(docset.seek(5112), 5112); assert_eq!(docset.doc(), 5112); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } { let mut docset = create_docbitset(&[5112], 10_000); - assert_eq!(docset.skip_next(5113), SkipResult::End); - assert!(!docset.advance()); + assert_eq!(docset.seek(5113), TERMINATED); + assert_eq!(docset.advance(), TERMINATED); } { let mut docset = create_docbitset(&[5112], 10_000); - assert_eq!(docset.skip_next(5111), SkipResult::OverStep); + assert_eq!(docset.seek(5111), 5112); assert_eq!(docset.doc(), 5112); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } { let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000); - assert_eq!(docset.skip_next(5112), SkipResult::Reached); + assert_eq!(docset.seek(5112), 5112); assert_eq!(docset.doc(), 5112); - assert!(docset.advance()); + assert_eq!(docset.advance(), 5500); assert_eq!(docset.doc(), 5500); - assert!(docset.advance()); + assert_eq!(docset.advance(), 6666); assert_eq!(docset.doc(), 6666); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } { let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000); - assert_eq!(docset.skip_next(5111), SkipResult::OverStep); + assert_eq!(docset.seek(5111), 5112); assert_eq!(docset.doc(), 5112); - assert!(docset.advance()); + assert_eq!(docset.advance(), 5500); assert_eq!(docset.doc(), 5500); - assert!(docset.advance()); + assert_eq!(docset.advance(), 6666); assert_eq!(docset.doc(), 6666); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } { let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5513, 6666], 10_000); - assert_eq!(docset.skip_next(5111), SkipResult::OverStep); + assert_eq!(docset.seek(5111), 5112); assert_eq!(docset.doc(), 5112); - assert!(docset.advance()); + assert_eq!(docset.advance(), 5513); assert_eq!(docset.doc(), 5513); - assert!(docset.advance()); + assert_eq!(docset.advance(), 6666); assert_eq!(docset.doc(), 6666); - assert!(!docset.advance()); + assert_eq!(docset.advance(), TERMINATED); } } } @@ -223,6 +202,7 @@ mod bench { use super::BitSet; use super::BitSetDocSet; + use crate::docset::TERMINATED; use crate::test; use crate::tests; use crate::DocSet; @@ -257,7 +237,7 @@ mod bench { } b.iter(|| { let mut docset = BitSetDocSet::from(bitset.clone()); - while docset.advance() {} + while docset.advance() != TERMINATED {} }); } } diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index f759db3aa..8b6a6c881 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -10,7 +10,7 @@ use crate::query::Scorer; use crate::query::Union; use crate::query::Weight; use crate::query::{intersect_scorers, Explanation}; -use crate::{DocId, SkipResult}; +use crate::DocId; use std::collections::HashMap; fn scorer_union(scorers: Vec>) -> Box @@ -133,7 +133,7 @@ impl Weight for BooleanWeight { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result { let mut scorer = self.scorer(reader, 1.0f32)?; - if scorer.skip_next(doc) != SkipResult::Reached { + if scorer.seek(doc) != doc { return Err(does_not_match(doc)); } if !self.scoring_enabled { diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 1c6a341ef..84a42c7e5 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -207,7 +207,6 @@ mod tests { let mut boolean_scorer = boolean_weight .scorer(searcher.segment_reader(0u32), 1.0f32) .unwrap(); - assert!(boolean_scorer.advance()); assert_eq!(boolean_scorer.doc(), 0u32); assert_nearly_equals(boolean_scorer.score(), 0.84163445f32); } @@ -215,7 +214,6 @@ mod tests { let mut boolean_scorer = boolean_weight .scorer(searcher.segment_reader(0u32), 2.0f32) .unwrap(); - assert!(boolean_scorer.advance()); assert_eq!(boolean_scorer.doc(), 0u32); assert_nearly_equals(boolean_scorer.score(), 1.6832689f32); } diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index eda5431a0..da22da072 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -1,8 +1,7 @@ -use crate::common::BitSet; use crate::fastfield::DeleteBitSet; use crate::query::explanation::does_not_match; use crate::query::{Explanation, Query, Scorer, Weight}; -use crate::{DocId, DocSet, Searcher, SegmentReader, SkipResult, Term}; +use crate::{DocId, DocSet, Searcher, SegmentReader, Term}; use std::collections::BTreeSet; use std::fmt; @@ -72,7 +71,7 @@ impl Weight for BoostWeight { fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result { let mut scorer = self.scorer(reader, 1.0f32)?; - if scorer.skip_next(doc) != SkipResult::Reached { + if scorer.seek(doc) != doc { return Err(does_not_match(doc)); } let mut explanation = @@ -99,12 +98,12 @@ impl BoostScorer { } impl DocSet for BoostScorer { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { self.underlying.advance() } - fn skip_next(&mut self, target: DocId) -> SkipResult { - self.underlying.skip_next(target) + fn seek(&mut self, target: DocId) -> DocId { + self.underlying.seek(target) } fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize { @@ -119,10 +118,6 @@ impl DocSet for BoostScorer { self.underlying.size_hint() } - fn append_to_bitset(&mut self, bitset: &mut BitSet) { - self.underlying.append_to_bitset(bitset) - } - fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 { self.underlying.count(delete_bitset) } diff --git a/src/query/empty_query.rs b/src/query/empty_query.rs index 76932070e..addec95dc 100644 --- a/src/query/empty_query.rs +++ b/src/query/empty_query.rs @@ -1,4 +1,5 @@ use super::Scorer; +use crate::docset::TERMINATED; use crate::query::explanation::does_not_match; use crate::query::Weight; use crate::query::{Explanation, Query}; @@ -48,15 +49,12 @@ impl Weight for EmptyWeight { pub struct EmptyScorer; impl DocSet for EmptyScorer { - fn advance(&mut self) -> bool { - false + fn advance(&mut self) -> DocId { + TERMINATED } fn doc(&self) -> DocId { - panic!( - "You may not call .doc() on a scorer \ - where the last call to advance() did not return true." - ); + TERMINATED } fn size_hint(&self) -> u32 { @@ -72,18 +70,15 @@ impl Scorer for EmptyScorer { #[cfg(test)] mod tests { + use crate::docset::TERMINATED; use crate::query::EmptyScorer; use crate::DocSet; #[test] fn test_empty_scorer() { let mut empty_scorer = EmptyScorer; - assert!(!empty_scorer.advance()); - } - - #[test] - #[should_panic] - fn test_empty_scorer_panic_on_doc_call() { - EmptyScorer.doc(); + assert_eq!(empty_scorer.doc(), TERMINATED); + assert_eq!(empty_scorer.advance(), TERMINATED); + assert_eq!(empty_scorer.doc(), TERMINATED); } } diff --git a/src/query/exclude.rs b/src/query/exclude.rs index 6bb6996b4..8dd35be95 100644 --- a/src/query/exclude.rs +++ b/src/query/exclude.rs @@ -1,41 +1,37 @@ -use crate::docset::{DocSet, SkipResult}; +use crate::docset::{DocSet, TERMINATED}; use crate::query::Scorer; use crate::DocId; use crate::Score; -#[derive(Clone, Copy, Debug)] -enum State { - ExcludeOne(DocId), - Finished, -} - /// Filters a given `DocSet` by removing the docs from a given `DocSet`. /// /// The excluding docset has no impact on scoring. pub struct Exclude { underlying_docset: TDocSet, excluding_docset: TDocSetExclude, - excluding_state: State, } impl Exclude where + TDocSet: DocSet, TDocSetExclude: DocSet, { /// Creates a new `ExcludeScorer` pub fn new( - underlying_docset: TDocSet, + mut underlying_docset: TDocSet, mut excluding_docset: TDocSetExclude, ) -> Exclude { - let state = if excluding_docset.advance() { - State::ExcludeOne(excluding_docset.doc()) - } else { - State::Finished - }; + while underlying_docset.doc() != TERMINATED { + let target = underlying_docset.doc(); + if excluding_docset.seek(target) != target { + // this document is not excluded. + break; + } + underlying_docset.advance(); + } Exclude { underlying_docset, excluding_docset, - excluding_state: state, } } } @@ -51,28 +47,7 @@ where /// increasing `doc`. fn accept(&mut self) -> bool { let doc = self.underlying_docset.doc(); - match self.excluding_state { - State::ExcludeOne(excluded_doc) => { - if doc == excluded_doc { - return false; - } - if excluded_doc > doc { - return true; - } - match self.excluding_docset.skip_next(doc) { - SkipResult::OverStep => { - self.excluding_state = State::ExcludeOne(self.excluding_docset.doc()); - true - } - SkipResult::End => { - self.excluding_state = State::Finished; - true - } - SkipResult::Reached => false, - } - } - State::Finished => true, - } + self.excluding_docset.seek(doc) != doc } } @@ -81,27 +56,24 @@ where TDocSet: DocSet, TDocSetExclude: DocSet, { - fn advance(&mut self) -> bool { - while self.underlying_docset.advance() { + fn advance(&mut self) -> DocId { + while self.underlying_docset.advance() != TERMINATED { if self.accept() { - return true; + return self.doc(); } } - false + TERMINATED } - fn skip_next(&mut self, target: DocId) -> SkipResult { - let underlying_skip_result = self.underlying_docset.skip_next(target); - if underlying_skip_result == SkipResult::End { - return SkipResult::End; + fn seek(&mut self, target: DocId) -> DocId { + let underlying_seek_result = self.underlying_docset.seek(target); + if underlying_seek_result == TERMINATED { + return TERMINATED; } if self.accept() { - underlying_skip_result - } else if self.advance() { - SkipResult::OverStep - } else { - SkipResult::End + return underlying_seek_result; } + self.advance() } fn doc(&self) -> DocId { @@ -141,8 +113,9 @@ mod tests { VecDocSet::from(vec![1, 2, 3, 10, 16, 24]), ); let mut els = vec![]; - while exclude_scorer.advance() { + while exclude_scorer.doc() != TERMINATED { els.push(exclude_scorer.doc()); + exclude_scorer.advance(); } assert_eq!(els, vec![5, 8, 15]); } diff --git a/src/query/intersection.rs b/src/query/intersection.rs index 9140f7fd5..d72ef5866 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -1,4 +1,4 @@ -use crate::docset::{DocSet, SkipResult}; +use crate::docset::{DocSet, TERMINATED}; use crate::query::term_query::TermScorer; use crate::query::EmptyScorer; use crate::query::Scorer; @@ -20,12 +20,14 @@ pub fn intersect_scorers(mut scorers: Vec>) -> Box { if scorers.len() == 1 { return scorers.pop().unwrap(); } + scorers.sort_by_key(|scorer| scorer.size_hint()); + let doc = go_to_first_doc(&mut scorers[..]); + if doc == TERMINATED { + return Box::new(EmptyScorer); + } // We know that we have at least 2 elements. - let num_docsets = scorers.len(); - scorers.sort_by(|left, right| right.size_hint().cmp(&left.size_hint())); - let left = scorers.pop().unwrap(); - let right = scorers.pop().unwrap(); - scorers.reverse(); + let left = scorers.remove(0); + let right = scorers.remove(0); let all_term_scorers = [&left, &right] .iter() .all(|&scorer| scorer.is::()); @@ -34,14 +36,12 @@ pub fn intersect_scorers(mut scorers: Vec>) -> Box { left: *(left.downcast::().map_err(|_| ()).unwrap()), right: *(right.downcast::().map_err(|_| ()).unwrap()), others: scorers, - num_docsets, }); } Box::new(Intersection { left, right, others: scorers, - num_docsets, }) } @@ -50,22 +50,34 @@ pub struct Intersection> left: TDocSet, right: TDocSet, others: Vec, - num_docsets: usize, +} + +fn go_to_first_doc(docsets: &mut [TDocSet]) -> DocId { + let mut candidate = 0; + 'outer: loop { + for docset in docsets.iter_mut() { + let seek_doc = docset.seek(candidate); + if seek_doc > candidate { + candidate = docset.doc(); + continue 'outer; + } + } + return candidate; + } } impl Intersection { pub(crate) fn new(mut docsets: Vec) -> Intersection { let num_docsets = docsets.len(); assert!(num_docsets >= 2); - docsets.sort_by(|left, right| right.size_hint().cmp(&left.size_hint())); - let left = docsets.pop().unwrap(); - let right = docsets.pop().unwrap(); - docsets.reverse(); + docsets.sort_by_key(|docset| docset.size_hint()); + go_to_first_doc(&mut docsets); + let left = docsets.remove(0); + let right = docsets.remove(0); Intersection { left, right, others: docsets, - num_docsets, } } } @@ -80,128 +92,49 @@ impl Intersection { } } -impl Intersection { - pub(crate) fn docset_mut(&mut self, ord: usize) -> &mut dyn DocSet { - match ord { - 0 => &mut self.left, - 1 => &mut self.right, - n => &mut self.others[n - 2], - } - } -} - impl DocSet for Intersection { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { let (left, right) = (&mut self.left, &mut self.right); - - if !left.advance() { - return false; - } - - let mut candidate = left.doc(); - let mut other_candidate_ord: usize = usize::max_value(); + let mut candidate = left.advance(); 'outer: loop { // In the first part we look for a document in the intersection // of the two rarest `DocSet` in the intersection. + loop { - match right.skip_next(candidate) { - SkipResult::Reached => { - break; - } - SkipResult::OverStep => { - candidate = right.doc(); - other_candidate_ord = usize::max_value(); - } - SkipResult::End => { - return false; - } - } - match left.skip_next(candidate) { - SkipResult::Reached => { - break; - } - SkipResult::OverStep => { - candidate = left.doc(); - other_candidate_ord = usize::max_value(); - } - SkipResult::End => { - return false; - } + let right_doc = right.seek(candidate); + candidate = left.seek(right_doc); + if candidate == right_doc { + break; } } + + debug_assert_eq!(left.doc(), right.doc()); // test the remaining scorers; - for (ord, docset) in self.others.iter_mut().enumerate() { - if ord == other_candidate_ord { - continue; - } + for docset in self.others.iter_mut() { // `candidate_ord` is already at the // right position. // // Calling `skip_next` would advance this docset // and miss it. - match docset.skip_next(candidate) { - SkipResult::Reached => {} - SkipResult::OverStep => { - // this is not in the intersection, - // let's update our candidate. - candidate = docset.doc(); - match left.skip_next(candidate) { - SkipResult::Reached => { - other_candidate_ord = ord; - } - SkipResult::OverStep => { - candidate = left.doc(); - other_candidate_ord = usize::max_value(); - } - SkipResult::End => { - return false; - } - } - continue 'outer; - } - SkipResult::End => { - return false; - } + let seek_doc = docset.seek(candidate); + if seek_doc > candidate { + candidate = left.seek(seek_doc); + continue 'outer; } } - return true; + + return candidate; } } - fn skip_next(&mut self, target: DocId) -> SkipResult { - // We optimize skipping by skipping every single member - // of the intersection to target. - let mut current_target: DocId = target; - let mut current_ord = self.num_docsets; - - 'outer: loop { - for ord in 0..self.num_docsets { - let docset = self.docset_mut(ord); - if ord == current_ord { - continue; - } - match docset.skip_next(current_target) { - SkipResult::End => { - return SkipResult::End; - } - SkipResult::OverStep => { - // update the target - // for the remaining members of the intersection. - current_target = docset.doc(); - current_ord = ord; - continue 'outer; - } - SkipResult::Reached => {} - } - } - if target == current_target { - return SkipResult::Reached; - } else { - assert!(current_target > target); - return SkipResult::OverStep; - } + fn seek(&mut self, target: DocId) -> DocId { + self.left.seek(target); + let mut docsets: Vec<&mut dyn DocSet> = vec![&mut self.left, &mut self.right]; + for docset in &mut self.others { + docsets.push(docset); } + go_to_first_doc(&mut docsets[..]) } fn doc(&self) -> DocId { @@ -228,7 +161,7 @@ where #[cfg(test)] mod tests { use super::Intersection; - use crate::docset::{DocSet, SkipResult}; + use crate::docset::{DocSet, TERMINATED}; use crate::postings::tests::test_skip_against_unoptimized; use crate::query::VecDocSet; @@ -238,20 +171,18 @@ mod tests { let left = VecDocSet::from(vec![1, 3, 9]); let right = VecDocSet::from(vec![3, 4, 9, 18]); let mut intersection = Intersection::new(vec![left, right]); - assert!(intersection.advance()); assert_eq!(intersection.doc(), 3); - assert!(intersection.advance()); + assert_eq!(intersection.advance(), 9); assert_eq!(intersection.doc(), 9); - assert!(!intersection.advance()); + assert_eq!(intersection.advance(), TERMINATED); } { let a = VecDocSet::from(vec![1, 3, 9]); let b = VecDocSet::from(vec![3, 4, 9, 18]); let c = VecDocSet::from(vec![1, 5, 9, 111]); let mut intersection = Intersection::new(vec![a, b, c]); - assert!(intersection.advance()); assert_eq!(intersection.doc(), 9); - assert!(!intersection.advance()); + assert_eq!(intersection.advance(), TERMINATED); } } @@ -260,8 +191,8 @@ mod tests { let left = VecDocSet::from(vec![0]); let right = VecDocSet::from(vec![0]); let mut intersection = Intersection::new(vec![left, right]); - assert!(intersection.advance()); assert_eq!(intersection.doc(), 0); + assert_eq!(intersection.advance(), TERMINATED); } #[test] @@ -269,7 +200,7 @@ mod tests { let left = VecDocSet::from(vec![0, 1, 2, 4]); let right = VecDocSet::from(vec![2, 5]); let mut intersection = Intersection::new(vec![left, right]); - assert_eq!(intersection.skip_next(2), SkipResult::Reached); + assert_eq!(intersection.seek(2), 2); assert_eq!(intersection.doc(), 2); } @@ -312,7 +243,7 @@ mod tests { let a = VecDocSet::from(vec![1, 3]); let b = VecDocSet::from(vec![1, 4]); let c = VecDocSet::from(vec![3, 9]); - let mut intersection = Intersection::new(vec![a, b, c]); - assert!(!intersection.advance()); + let intersection = Intersection::new(vec![a, b, c]); + assert_eq!(intersection.doc(), TERMINATED); } } diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 44bb0787b..95eeefd70 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -60,8 +60,8 @@ pub mod tests { .map(|docaddr| docaddr.1) .collect::>() }; - assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]); assert_eq!(test_query(vec!["a", "b"]), vec![1, 2, 3, 4]); + assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]); assert_eq!(test_query(vec!["b", "b"]), vec![0, 1]); assert!(test_query(vec!["g", "ewrwer"]).is_empty()); assert!(test_query(vec!["g", "a"]).is_empty()); diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 3a0902f91..99712c5b6 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -1,4 +1,4 @@ -use crate::docset::{DocSet, SkipResult}; +use crate::docset::{DocSet, TERMINATED}; use crate::fieldnorm::FieldNormReader; use crate::postings::Postings; use crate::query::bm25::BM25Weight; @@ -25,12 +25,12 @@ impl PostingsWithOffset { } impl DocSet for PostingsWithOffset { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { self.postings.advance() } - fn skip_next(&mut self, target: DocId) -> SkipResult { - self.postings.skip_next(target) + fn seek(&mut self, target: DocId) -> DocId { + self.postings.seek(target) } fn doc(&self) -> DocId { @@ -149,7 +149,7 @@ impl PhraseScorer { PostingsWithOffset::new(postings, (max_offset - offset) as u32) }) .collect::>(); - PhraseScorer { + let mut scorer = PhraseScorer { intersection_docset: Intersection::new(postings_with_offsets), num_terms: num_docsets, left: Vec::with_capacity(100), @@ -158,7 +158,11 @@ impl PhraseScorer { similarity_weight, fieldnorm_reader, score_needed, + }; + if scorer.doc() != TERMINATED && !scorer.phrase_match() { + scorer.advance(); } + scorer } pub fn phrase_count(&self) -> u32 { @@ -225,31 +229,21 @@ impl PhraseScorer { } impl DocSet for PhraseScorer { - fn advance(&mut self) -> bool { - while self.intersection_docset.advance() { - if self.phrase_match() { - return true; + fn advance(&mut self) -> DocId { + loop { + let doc = self.intersection_docset.advance(); + if doc == TERMINATED || self.phrase_match() { + return doc; } } - false } - fn skip_next(&mut self, target: DocId) -> SkipResult { - if self.intersection_docset.skip_next(target) == SkipResult::End { - return SkipResult::End; - } - if self.phrase_match() { - if self.doc() == target { - return SkipResult::Reached; - } else { - return SkipResult::OverStep; - } - } - if self.advance() { - SkipResult::OverStep - } else { - SkipResult::End + fn seek(&mut self, target: DocId) -> DocId { + let doc = self.intersection_docset.seek(target); + if doc == TERMINATED || self.phrase_match() { + return doc; } + self.advance() } fn doc(&self) -> DocId { diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index f82ca2288..3cdc98cc9 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -9,8 +9,8 @@ use crate::query::Weight; use crate::query::{EmptyScorer, Explanation}; use crate::schema::IndexRecordOption; use crate::schema::Term; +use crate::Result; use crate::{DocId, DocSet}; -use crate::{Result, SkipResult}; pub struct PhraseWeight { phrase_terms: Vec<(usize, Term)>, @@ -99,7 +99,7 @@ impl Weight for PhraseWeight { return Err(does_not_match(doc)); } let mut scorer = scorer_opt.unwrap(); - if scorer.skip_next(doc) != SkipResult::Reached { + if scorer.seek(doc) != doc { return Err(does_not_match(doc)); } let fieldnorm_reader = self.fieldnorm_reader(reader); @@ -114,6 +114,7 @@ impl Weight for PhraseWeight { #[cfg(test)] mod tests { use super::super::tests::create_index; + use crate::docset::TERMINATED; use crate::query::PhraseQuery; use crate::{DocSet, Term}; @@ -132,12 +133,11 @@ mod tests { .phrase_scorer(searcher.segment_reader(0u32), 1.0f32) .unwrap() .unwrap(); - assert!(phrase_scorer.advance()); assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.phrase_count(), 2); - assert!(phrase_scorer.advance()); + assert_eq!(phrase_scorer.advance(), 2); assert_eq!(phrase_scorer.doc(), 2); assert_eq!(phrase_scorer.phrase_count(), 1); - assert!(!phrase_scorer.advance()); + assert_eq!(phrase_scorer.advance(), TERMINATED); } } diff --git a/src/query/range_query.rs b/src/query/range_query.rs index 0440b7e0c..26f957af0 100644 --- a/src/query/range_query.rs +++ b/src/query/range_query.rs @@ -10,7 +10,7 @@ use crate::schema::Type; use crate::schema::{Field, IndexRecordOption, Term}; use crate::termdict::{TermDictionary, TermStreamer}; use crate::DocId; -use crate::{Result, SkipResult}; +use crate::Result; use std::collections::Bound; use std::ops::Range; @@ -312,7 +312,7 @@ impl Weight for RangeWeight { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { let mut scorer = self.scorer(reader, 1.0f32)?; - if scorer.skip_next(doc) != SkipResult::Reached { + if scorer.seek(doc) != doc { return Err(does_not_match(doc)); } Ok(Explanation::new("RangeQuery", 1.0f32)) diff --git a/src/query/reqopt_scorer.rs b/src/query/reqopt_scorer.rs index 16c15a198..2a02dc790 100644 --- a/src/query/reqopt_scorer.rs +++ b/src/query/reqopt_scorer.rs @@ -1,9 +1,8 @@ -use crate::docset::{DocSet, SkipResult}; +use crate::docset::DocSet; use crate::query::score_combiner::ScoreCombiner; use crate::query::Scorer; use crate::DocId; use crate::Score; -use std::cmp::Ordering; use std::marker::PhantomData; /// Given a required scorer and an optional scorer @@ -17,7 +16,6 @@ pub struct RequiredOptionalScorer { req_scorer: TReqScorer, opt_scorer: TOptScorer, score_cache: Option, - opt_finished: bool, _phantom: PhantomData, } @@ -29,14 +27,12 @@ where /// Creates a new `RequiredOptionalScorer`. pub fn new( req_scorer: TReqScorer, - mut opt_scorer: TOptScorer, + opt_scorer: TOptScorer, ) -> RequiredOptionalScorer { - let opt_finished = !opt_scorer.advance(); RequiredOptionalScorer { req_scorer, opt_scorer, score_cache: None, - opt_finished, _phantom: PhantomData, } } @@ -48,7 +44,7 @@ where TReqScorer: DocSet, TOptScorer: DocSet, { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { self.score_cache = None; self.req_scorer.advance() } @@ -76,22 +72,8 @@ where let doc = self.doc(); let mut score_combiner = TScoreCombiner::default(); score_combiner.update(&mut self.req_scorer); - if !self.opt_finished { - match self.opt_scorer.doc().cmp(&doc) { - Ordering::Greater => {} - Ordering::Equal => { - score_combiner.update(&mut self.opt_scorer); - } - Ordering::Less => match self.opt_scorer.skip_next(doc) { - SkipResult::Reached => { - score_combiner.update(&mut self.opt_scorer); - } - SkipResult::End => { - self.opt_finished = true; - } - SkipResult::OverStep => {} - }, - } + if self.opt_scorer.seek(doc) == doc { + score_combiner.update(&mut self.opt_scorer); } let score = score_combiner.score(); self.score_cache = Some(score); @@ -102,7 +84,7 @@ where #[cfg(test)] mod tests { use super::RequiredOptionalScorer; - use crate::docset::DocSet; + use crate::docset::{DocSet, TERMINATED}; use crate::postings::tests::test_skip_against_unoptimized; use crate::query::score_combiner::{DoNothingCombiner, SumCombiner}; use crate::query::ConstScorer; @@ -119,9 +101,7 @@ mod tests { ConstScorer::from(VecDocSet::from(vec![])), ); let mut docs = vec![]; - while reqoptscorer.advance() { - docs.push(reqoptscorer.doc()); - } + reqoptscorer.for_each(&mut |doc, _| docs.push(doc)); assert_eq!(docs, req); } @@ -133,46 +113,45 @@ mod tests { ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15]), 1.0f32), ); { - assert!(reqoptscorer.advance()); assert_eq!(reqoptscorer.doc(), 1); assert_eq!(reqoptscorer.score(), 2f32); } { - assert!(reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), 3); assert_eq!(reqoptscorer.doc(), 3); assert_eq!(reqoptscorer.score(), 1f32); } { - assert!(reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), 7); assert_eq!(reqoptscorer.doc(), 7); assert_eq!(reqoptscorer.score(), 2f32); } { - assert!(reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), 8); assert_eq!(reqoptscorer.doc(), 8); assert_eq!(reqoptscorer.score(), 1f32); } { - assert!(reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), 9); assert_eq!(reqoptscorer.doc(), 9); assert_eq!(reqoptscorer.score(), 1f32); } { - assert!(reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), 10); assert_eq!(reqoptscorer.doc(), 10); assert_eq!(reqoptscorer.score(), 1f32); } { - assert!(reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), 13); assert_eq!(reqoptscorer.doc(), 13); assert_eq!(reqoptscorer.score(), 1f32); } { - assert!(reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), 15); assert_eq!(reqoptscorer.doc(), 15); assert_eq!(reqoptscorer.score(), 2f32); } - assert!(!reqoptscorer.advance()); + assert_eq!(reqoptscorer.advance(), TERMINATED); } #[test] diff --git a/src/query/scorer.rs b/src/query/scorer.rs index 02a4fb021..923d5ff7a 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -1,5 +1,4 @@ -use crate::common::BitSet; -use crate::docset::{DocSet, SkipResult}; +use crate::docset::{DocSet, TERMINATED}; use crate::DocId; use crate::Score; use downcast_rs::impl_downcast; @@ -17,8 +16,35 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static { /// Iterates through all of the document matched by the DocSet /// `DocSet` and push the scored documents to the collector. fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) { - while self.advance() { - callback(self.doc(), self.score()); + let mut doc = self.doc(); + while doc != TERMINATED { + callback(doc, self.score()); + doc = self.advance(); + } + } + + /// Calls `callback` with all of the `(doc, score)` for which score + /// is exceeding a given threshold. + /// + /// This method is useful for the TopDocs collector. + /// For all docsets, the blanket implementation has the benefit + /// of prefiltering (doc, score) pairs, avoiding the + /// virtual dispatch cost. + /// + /// More importantly, it makes it possible for scorers to implement + /// important optimization (e.g. BlockWAND for union). + fn for_each_pruning( + &mut self, + mut threshold: f32, + callback: &mut dyn FnMut(DocId, Score) -> Score, + ) { + let mut doc = self.doc(); + while doc != TERMINATED { + let score = self.score(); + if score > threshold { + threshold = callback(doc, score); + } + doc = self.advance(); } } } @@ -61,12 +87,12 @@ impl From for ConstScorer { } impl DocSet for ConstScorer { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { self.docset.advance() } - fn skip_next(&mut self, target: DocId) -> SkipResult { - self.docset.skip_next(target) + fn seek(&mut self, target: DocId) -> DocId { + self.docset.seek(target) } fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize { @@ -80,10 +106,6 @@ impl DocSet for ConstScorer { fn size_hint(&self) -> u32 { self.docset.size_hint() } - - fn append_to_bitset(&mut self, bitset: &mut BitSet) { - self.docset.append_to_bitset(bitset); - } } impl Scorer for ConstScorer { diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index d38756b7e..0ea904cf1 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -26,10 +26,8 @@ mod tests { { // writing the segment let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); - { - let doc = doc!(text_field => "a"); - index_writer.add_document(doc); - } + let doc = doc!(text_field => "a"); + index_writer.add_document(doc); assert!(index_writer.commit().is_ok()); } let searcher = index.reader().unwrap().searcher(); @@ -40,7 +38,6 @@ mod tests { let term_weight = term_query.weight(&searcher, true).unwrap(); let segment_reader = searcher.segment_reader(0); let mut term_scorer = term_weight.scorer(segment_reader, 1.0f32).unwrap(); - assert!(term_scorer.advance()); assert_eq!(term_scorer.doc(), 0); assert_eq!(term_scorer.score(), 0.28768212); } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index e950edbac..6ccd20892 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -1,4 +1,4 @@ -use crate::docset::{DocSet, SkipResult}; +use crate::docset::DocSet; use crate::query::{Explanation, Scorer}; use crate::DocId; use crate::Score; @@ -45,12 +45,12 @@ impl TermScorer { } impl DocSet for TermScorer { - fn advance(&mut self) -> bool { + fn advance(&mut self) -> DocId { self.postings.advance() } - fn skip_next(&mut self, target: DocId) -> SkipResult { - self.postings.skip_next(target) + fn seek(&mut self, target: DocId) -> DocId { + self.postings.seek(target) } fn doc(&self) -> DocId { diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index e7d47847e..3d53827ee 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -8,8 +8,8 @@ use crate::query::Weight; use crate::query::{Explanation, Scorer}; use crate::schema::IndexRecordOption; use crate::DocId; +use crate::Result; use crate::Term; -use crate::{Result, SkipResult}; pub struct TermWeight { term: Term, @@ -25,7 +25,7 @@ impl Weight for TermWeight { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { let mut scorer = self.scorer_specialized(reader, 1.0f32)?; - if scorer.skip_next(doc) != SkipResult::Reached { + if scorer.seek(doc) != doc { return Err(does_not_match(doc)); } Ok(scorer.explain()) diff --git a/src/query/union.rs b/src/query/union.rs index 7e27ac877..ddaa1ba3d 100644 --- a/src/query/union.rs +++ b/src/query/union.rs @@ -1,10 +1,9 @@ use crate::common::TinySet; -use crate::docset::{DocSet, SkipResult}; +use crate::docset::{DocSet, TERMINATED}; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::Scorer; use crate::DocId; use crate::Score; -use std::cmp::Ordering; const HORIZON_NUM_TINYBITSETS: usize = 64; const HORIZON: u32 = 64u32 * HORIZON_NUM_TINYBITSETS as u32; @@ -47,17 +46,9 @@ where fn from(docsets: Vec) -> Union { let non_empty_docsets: Vec = docsets .into_iter() - .flat_map( - |mut docset| { - if docset.advance() { - Some(docset) - } else { - None - } - }, - ) + .filter(|docset| docset.doc() != TERMINATED) .collect(); - Union { + let mut union = Union { docsets: non_empty_docsets, bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]), scores: Box::new([TScoreCombiner::default(); HORIZON as usize]), @@ -65,7 +56,13 @@ where offset: 0, doc: 0, score: 0f32, + }; + if union.refill() { + union.advance(); + } else { + union.doc = TERMINATED; } + union } } @@ -86,7 +83,7 @@ fn refill( let delta = doc - min_doc; bitsets[(delta / 64) as usize].insert_mut(delta % 64u32); score_combiner[delta as usize].update(scorer); - if !scorer.advance() { + if scorer.advance() == TERMINATED { // remove the docset, it has been entirely consumed. return true; } @@ -99,6 +96,7 @@ impl Union bool { + fn advance(&mut self) -> DocId { if self.advance_buffered() { - return true; + return self.doc; } - if self.refill() { - self.advance(); - true - } else { - false + if !self.refill() { + self.doc = TERMINATED; + return TERMINATED; } + if !self.advance_buffered() { + return TERMINATED; + } + self.doc } - fn skip_next(&mut self, target: DocId) -> SkipResult { - if !self.advance() { - return SkipResult::End; - } - match self.doc.cmp(&target) { - Ordering::Equal => { - return SkipResult::Reached; - } - Ordering::Greater => { - return SkipResult::OverStep; - } - Ordering::Less => {} + fn seek(&mut self, target: DocId) -> DocId { + if self.doc >= target { + return self.doc; } let gap = target - self.offset; if gap < HORIZON { @@ -174,18 +165,11 @@ where // Advancing until we reach the end of the bucket // or we reach a doc greater or equal to the target. - while self.advance() { - match self.doc().cmp(&target) { - Ordering::Equal => { - return SkipResult::Reached; - } - Ordering::Greater => { - return SkipResult::OverStep; - } - Ordering::Less => {} - } + let mut doc = self.doc(); + while doc < target { + doc = self.advance(); } - SkipResult::End + doc } else { // clear the buffered info. for obsolete_tinyset in self.bitsets.iter_mut() { @@ -199,45 +183,30 @@ where // advance all docsets to a doc >= to the target. #[cfg_attr(feature = "cargo-clippy", allow(clippy::clippy::collapsible_if))] unordered_drain_filter(&mut self.docsets, |docset| { - if docset.doc() < target { - if docset.skip_next(target) == SkipResult::End { - return true; - } - } - false + docset.seek(target) == TERMINATED }); // at this point all of the docsets // are positionned on a doc >= to the target. - if self.refill() { - self.advance(); - if self.doc() == target { - SkipResult::Reached - } else { - debug_assert!(self.doc() > target); - SkipResult::OverStep - } - } else { - SkipResult::End + if !self.refill() { + self.doc = TERMINATED; + return TERMINATED; } + self.advance() } } - // TODO implement `count` efficiently. - - fn doc(&self) -> DocId { - self.doc - } - - fn size_hint(&self) -> u32 { - 0u32 - } + // TODO Also implement `count` with deletes efficiently. fn count_including_deleted(&mut self) -> u32 { + if self.doc == TERMINATED { + return 0; + } let mut count = self.bitsets[self.cursor..HORIZON_NUM_TINYBITSETS] .iter() .map(|bitset| bitset.len()) - .sum::(); + .sum::() + + 1; for bitset in self.bitsets.iter_mut() { bitset.clear(); } @@ -250,6 +219,14 @@ where self.cursor = HORIZON_NUM_TINYBITSETS; count } + + fn doc(&self) -> DocId { + self.doc + } + + fn size_hint(&self) -> u32 { + self.docsets.iter().map(|docset| docset.size_hint()).max().unwrap_or(0u32) + } } impl Scorer for Union @@ -267,7 +244,7 @@ mod tests { use super::Union; use super::HORIZON; - use crate::docset::{DocSet, SkipResult}; + use crate::docset::{DocSet, TERMINATED}; use crate::postings::tests::test_skip_against_unoptimized; use crate::query::score_combiner::DoNothingCombiner; use crate::query::ConstScorer; @@ -296,12 +273,12 @@ mod tests { }; let mut union: Union<_, DoNothingCombiner> = make_union(); let mut count = 0; - while union.advance() { - assert!(union_expected.advance()); + while union.doc() != TERMINATED { assert_eq!(union_expected.doc(), union.doc()); + assert_eq!(union_expected.advance(), union.advance()); count += 1; } - assert!(!union_expected.advance()); + assert_eq!(union_expected.advance(), TERMINATED); assert_eq!(count, make_union().count_including_deleted()); } @@ -329,9 +306,7 @@ mod tests { fn test_aux_union_skip(docs_list: &[Vec], skip_targets: Vec) { let mut btree_set = BTreeSet::new(); for docs in docs_list { - for &doc in docs.iter() { - btree_set.insert(doc); - } + btree_set.extend(docs.iter().cloned()); } let docset_factory = || { let res: Box = Box::new(Union::<_, DoNothingCombiner>::from( @@ -346,10 +321,10 @@ mod tests { }; let mut docset = docset_factory(); for el in btree_set { - assert!(docset.advance()); assert_eq!(el, docset.doc()); + docset.advance(); } - assert!(!docset.advance()); + assert_eq!(docset.doc(), TERMINATED); test_skip_against_unoptimized(docset_factory, skip_targets); } @@ -372,10 +347,10 @@ mod tests { ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])), ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])), ]); - assert!(docset.advance()); assert_eq!(docset.doc(), 0u32); - assert_eq!(docset.skip_next(0u32), SkipResult::OverStep); - assert_eq!(docset.doc(), 1u32) + assert_eq!(docset.seek(0u32), 0u32); + assert_eq!(docset.seek(0u32), 0u32); + assert_eq!(docset.doc(), 0u32) } #[test] diff --git a/src/query/vec_docset.rs b/src/query/vec_docset.rs index 33b6143e4..89f32bd7f 100644 --- a/src/query/vec_docset.rs +++ b/src/query/vec_docset.rs @@ -1,9 +1,8 @@ #![allow(dead_code)] use crate::common::HasLen; -use crate::docset::DocSet; +use crate::docset::{DocSet, TERMINATED}; use crate::DocId; -use std::num::Wrapping; /// Simulate a `Postings` objects from a `VecPostings`. /// `VecPostings` only exist for testing purposes. @@ -12,26 +11,30 @@ use std::num::Wrapping; /// No positions are returned. pub struct VecDocSet { doc_ids: Vec, - cursor: Wrapping, + cursor: usize, } impl From> for VecDocSet { fn from(doc_ids: Vec) -> VecDocSet { - VecDocSet { - doc_ids, - cursor: Wrapping(usize::max_value()), - } + VecDocSet { doc_ids, cursor: 0 } } } impl DocSet for VecDocSet { - fn advance(&mut self) -> bool { - self.cursor += Wrapping(1); - self.doc_ids.len() > self.cursor.0 + fn advance(&mut self) -> DocId { + self.cursor += 1; + if self.cursor >= self.doc_ids.len() { + self.cursor = self.doc_ids.len(); + return TERMINATED; + } + self.doc() } fn doc(&self) -> DocId { - self.doc_ids[self.cursor.0] + if self.cursor == self.doc_ids.len() { + return TERMINATED; + } + self.doc_ids[self.cursor] } fn size_hint(&self) -> u32 { @@ -49,22 +52,21 @@ impl HasLen for VecDocSet { pub mod tests { use super::*; - use crate::docset::{DocSet, SkipResult}; + use crate::docset::DocSet; use crate::DocId; #[test] pub fn test_vec_postings() { let doc_ids: Vec = (0u32..1024u32).map(|e| e * 3).collect(); let mut postings = VecDocSet::from(doc_ids); - assert!(postings.advance()); assert_eq!(postings.doc(), 0u32); - assert!(postings.advance()); + assert_eq!(postings.advance(), 3u32); assert_eq!(postings.doc(), 3u32); - assert_eq!(postings.skip_next(14u32), SkipResult::OverStep); + assert_eq!(postings.seek(14u32), 15u32); assert_eq!(postings.doc(), 15u32); - assert_eq!(postings.skip_next(300u32), SkipResult::Reached); + assert_eq!(postings.seek(300u32), 300u32); assert_eq!(postings.doc(), 300u32); - assert_eq!(postings.skip_next(6000u32), SkipResult::End); + assert_eq!(postings.seek(6000u32), TERMINATED); } #[test]