diff --git a/examples/iterating_docs_and_positions.rs b/examples/iterating_docs_and_positions.rs index 7883e47b8..d55186bf6 100644 --- a/examples/iterating_docs_and_positions.rs +++ b/examples/iterating_docs_and_positions.rs @@ -117,11 +117,16 @@ fn main() -> tantivy::Result<()> { if let Some(mut block_segment_postings) = inverted_index.read_block_postings(&term_the, IndexRecordOption::Basic) { - while block_segment_postings.advance() { + loop { + let docs = block_segment_postings.docs(); + if docs.is_empty() { + break; + } // Once again these docs MAY contains deleted documents as well. let docs = block_segment_postings.docs(); // Prints `Docs [0, 2].` println!("Docs {:?}", docs); + block_segment_postings.advance(); } } } diff --git a/src/postings/block_segment_postings.rs b/src/postings/block_segment_postings.rs index 17145e809..4dc2e89db 100644 --- a/src/postings/block_segment_postings.rs +++ b/src/postings/block_segment_postings.rs @@ -47,7 +47,6 @@ fn decode_vint_block( doc_offset: DocId, num_vint_docs: usize, ) { - doc_decoder.clear(); let num_consumed_bytes = doc_decoder.uncompress_vint_sorted(data, doc_offset, num_vint_docs); if let Some(freq_decoder) = freq_decoder_opt { freq_decoder.uncompress_vint_unsorted(&data[num_consumed_bytes..], num_vint_docs); @@ -99,7 +98,7 @@ impl BlockSegmentPostings { data: postings_data, skip_reader, }; - block_segment_postings.advance(); + block_segment_postings.load_block(); block_segment_postings } @@ -117,13 +116,13 @@ impl BlockSegmentPostings { let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, postings_data); self.data = ReadOnlySource::new(postings_data); self.loaded_offset = std::usize::MAX; - self.loaded_offset = std::usize::MAX; if let Some(skip_data) = skip_data_opt { self.skip_reader.reset(skip_data, doc_freq); } else { self.skip_reader.reset(ReadOnlySource::empty(), doc_freq); } self.doc_freq = doc_freq as usize; + self.load_block(); } /// Returns the document frequency associated to this block postings. @@ -215,6 +214,10 @@ impl BlockSegmentPostings { ); } BlockInfo::VInt(num_vint_docs) => { + self.doc_decoder.clear(); + if num_vint_docs == 0 { + return; + } decode_vint_block( &mut self.doc_decoder, if let FreqReadingOption::ReadFreq = self.freq_reading_option { @@ -233,12 +236,9 @@ impl BlockSegmentPostings { /// Advance to the next block. /// /// Returns false iff there was no remaining blocks. - pub fn advance(&mut self) -> bool { - if !self.skip_reader.advance() { - return false; - } + pub fn advance(&mut self) { + self.skip_reader.advance(); self.load_block(); - true } /// Returns an empty segment postings object @@ -294,7 +294,8 @@ mod tests { #[test] fn test_empty_block_segment_postings() { let mut postings = BlockSegmentPostings::empty(); - assert!(!postings.advance()); + postings.advance(); + assert!(postings.docs().is_empty()); assert_eq!(postings.doc_freq(), 0); } @@ -306,13 +307,14 @@ mod tests { assert_eq!(block_segments.doc_freq(), 100_000); loop { let block = block_segments.docs(); + if block.is_empty() { + break; + } for (i, doc) in block.iter().cloned().enumerate() { assert_eq!(offset + (i as u32), doc); } offset += block.len() as u32; - if block_segments.advance() { - break; - } + block_segments.advance(); } } @@ -421,7 +423,6 @@ mod tests { let term_info = inverted_index.get_term_info(&term).unwrap(); inverted_index.reset_block_postings_from_terminfo(&term_info, &mut block_segments); } - assert!(block_segments.advance()); assert_eq!(block_segments.docs(), &[1, 3, 5]); } } diff --git a/src/postings/compression/mod.rs b/src/postings/compression/mod.rs index cb3dc2bde..0ad46b006 100644 --- a/src/postings/compression/mod.rs +++ b/src/postings/compression/mod.rs @@ -109,6 +109,7 @@ impl BlockDecoder { } pub fn clear(&mut self) { + self.output_len = 0; self.output.0.iter_mut().for_each(|el| *el = TERMINATED); } } @@ -244,6 +245,19 @@ pub mod tests { } } + #[test] + fn test_clearing() { + let mut encoder = BlockEncoder::new(); + let vals = (0u32..128u32).map(|i| i * 3).collect::>(); + let (num_bits, compressed) = encoder.compress_block_sorted(&vals[..], 0u32); + let mut decoder = BlockDecoder::default(); + decoder.uncompress_block_sorted(compressed, 0u32, num_bits); + assert_eq!(decoder.output_len, 128); + assert_eq!(decoder.output_array(), &vals[..]); + decoder.clear(); + assert!(decoder.output_array().is_empty()); + } + #[test] fn test_encode_unsorted_block_with_junk() { let mut compressed: Vec = Vec::new(); diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index 5b4b64b08..8b4dc3702 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -100,14 +100,15 @@ impl DocSet for SegmentPostings { } fn seek(&mut self, target: DocId) -> DocId { - if self.doc() == target { - return target; + debug_assert!(self.doc() <= target); + if self.doc() >= target { + return self.doc(); } + self.block_cursor.seek(target); // At this point we are on the block, that might contain our document. let output = self.block_cursor.docs_aligned(); - self.cur = self.block_searcher.search_in_block(&output, target); // The last block is not full and padded with the value TERMINATED, @@ -123,6 +124,7 @@ impl DocSet for SegmentPostings { // After the search, the cursor should point to the first value of TERMINATED. let doc = output.0[self.cur]; debug_assert!(doc >= target); + debug_assert_eq!(doc, self.doc()); doc } diff --git a/src/postings/skip.rs b/src/postings/skip.rs index 6a6d66dc4..bc97ae21d 100644 --- a/src/postings/skip.rs +++ b/src/postings/skip.rs @@ -81,25 +81,41 @@ impl Default for BlockInfo { impl SkipReader { pub fn new(data: ReadOnlySource, doc_freq: u32, skip_info: IndexRecordOption) -> SkipReader { - SkipReader { - last_doc_in_block: 0u32, + let mut skip_reader = SkipReader { + last_doc_in_block: if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 { + 0 + } else { + TERMINATED + }, last_doc_in_previous_block: 0u32, owned_read: OwnedRead::new(data), skip_info, - block_info: BlockInfo::default(), + block_info: BlockInfo::VInt(doc_freq), byte_offset: 0, remaining_docs: doc_freq, position_offset: 0u64, + }; + if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 { + skip_reader.read_block_info(); } + skip_reader } pub fn reset(&mut self, data: ReadOnlySource, doc_freq: u32) { - self.last_doc_in_block = 0u32; + self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 { + 0 + } else { + TERMINATED + }; self.last_doc_in_previous_block = 0u32; self.owned_read = OwnedRead::new(data); - self.block_info = BlockInfo::default(); + self.block_info = BlockInfo::VInt(doc_freq); self.byte_offset = 0; self.remaining_docs = doc_freq; + self.position_offset = 0u64; + if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 { + self.read_block_info(); + } } #[cfg(test)] @@ -165,7 +181,7 @@ impl SkipReader { } } - pub fn advance(&mut self) -> bool { + pub fn advance(&mut self) { match self.block_info { BlockInfo::BitPacked { doc_num_bits, @@ -177,17 +193,17 @@ impl SkipReader { self.position_offset += tf_sum as u64; } BlockInfo::VInt(num_vint_docs) => { - self.remaining_docs -= num_vint_docs; + debug_assert_eq!(num_vint_docs, self.remaining_docs); + self.remaining_docs = 0; + self.byte_offset = std::usize::MAX; } } self.last_doc_in_previous_block = self.last_doc_in_block; if self.remaining_docs >= COMPRESSION_BLOCK_SIZE as u32 { self.read_block_info(); - true } else { self.last_doc_in_block = TERMINATED; self.block_info = BlockInfo::VInt(self.remaining_docs); - self.remaining_docs > 0 } } } @@ -217,7 +233,6 @@ mod tests { doc_freq, IndexRecordOption::WithFreqs, ); - assert!(skip_reader.advance()); assert_eq!(skip_reader.last_doc_in_block(), 1u32); assert_eq!( skip_reader.block_info(), @@ -227,7 +242,7 @@ mod tests { tf_sum: 0 } ); - assert!(skip_reader.advance()); + skip_reader.advance(); assert_eq!(skip_reader.last_doc_in_block(), 5u32); assert_eq!( skip_reader.block_info(), @@ -237,9 +252,12 @@ mod tests { tf_sum: 0 } ); - assert!(skip_reader.advance()); + skip_reader.advance(); assert_eq!(skip_reader.block_info(), BlockInfo::VInt(3u32)); - assert!(!skip_reader.advance()); + skip_reader.advance(); + assert_eq!(skip_reader.block_info(), BlockInfo::VInt(0u32)); + skip_reader.advance(); + assert_eq!(skip_reader.block_info(), BlockInfo::VInt(0u32)); } #[test] @@ -256,7 +274,6 @@ mod tests { doc_freq, IndexRecordOption::Basic, ); - assert!(skip_reader.advance()); assert_eq!(skip_reader.last_doc_in_block(), 1u32); assert_eq!( skip_reader.block_info(), @@ -266,7 +283,7 @@ mod tests { tf_sum: 0u32 } ); - assert!(skip_reader.advance()); + skip_reader.advance(); assert_eq!(skip_reader.last_doc_in_block(), 5u32); assert_eq!( skip_reader.block_info(), @@ -276,9 +293,12 @@ mod tests { tf_sum: 0u32 } ); - assert!(skip_reader.advance()); + skip_reader.advance(); assert_eq!(skip_reader.block_info(), BlockInfo::VInt(3u32)); - assert!(!skip_reader.advance()); + skip_reader.advance(); + assert_eq!(skip_reader.block_info(), BlockInfo::VInt(0u32)); + skip_reader.advance(); + assert_eq!(skip_reader.block_info(), BlockInfo::VInt(0u32)); } #[test] @@ -294,7 +314,6 @@ mod tests { doc_freq, IndexRecordOption::Basic, ); - assert!(skip_reader.advance()); assert_eq!(skip_reader.last_doc_in_block(), 1u32); assert_eq!( skip_reader.block_info(), @@ -304,6 +323,7 @@ mod tests { tf_sum: 0u32 } ); - assert!(!skip_reader.advance()); + skip_reader.advance(); + assert_eq!(skip_reader.block_info(), BlockInfo::VInt(0u32)); } } diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index 855ba8f64..a750ff129 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -43,7 +43,6 @@ where fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result> { let max_doc = reader.max_doc(); let mut doc_bitset = BitSet::with_max_value(max_doc); - let inverted_index = reader.inverted_index(self.field); let term_dict = inverted_index.terms(); let mut term_stream = self.automaton_stream(term_dict); @@ -52,12 +51,14 @@ where let mut block_segment_postings = inverted_index .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic); loop { - for &doc in block_segment_postings.docs() { - doc_bitset.insert(doc); - } - if !block_segment_postings.advance() { + let docs = block_segment_postings.docs(); + if docs.is_empty() { break; } + for &doc in docs { + doc_bitset.insert(doc); + } + block_segment_postings.advance(); } } let doc_bitset = BitSetDocSet::from(doc_bitset); diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 30993b11f..61aa46a5c 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -141,7 +141,6 @@ mod tests { .map(|doc| doc.1) .collect::>() }; - { let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a"))]); assert_eq!(matching_docs(&boolean_query), vec![0, 1, 3]); diff --git a/src/query/exclude.rs b/src/query/exclude.rs index 8dd35be95..d27a1c3e2 100644 --- a/src/query/exclude.rs +++ b/src/query/exclude.rs @@ -3,6 +3,11 @@ use crate::query::Scorer; use crate::DocId; use crate::Score; +#[inline(always)] +fn is_within(docset: &mut TDocSetExclude, doc: DocId) -> bool { + docset.doc() <= doc && docset.seek(doc) == doc +} + /// Filters a given `DocSet` by removing the docs from a given `DocSet`. /// /// The excluding docset has no impact on scoring. @@ -23,8 +28,7 @@ where ) -> Exclude { while underlying_docset.doc() != TERMINATED { let target = underlying_docset.doc(); - if excluding_docset.seek(target) != target { - // this document is not excluded. + if !is_within(&mut excluding_docset, target) { break; } underlying_docset.advance(); @@ -36,42 +40,30 @@ where } } -impl Exclude -where - TDocSet: DocSet, - TDocSetExclude: DocSet, -{ - /// Returns true iff the doc is not removed. - /// - /// The method has to be called with non strictly - /// increasing `doc`. - fn accept(&mut self) -> bool { - let doc = self.underlying_docset.doc(); - self.excluding_docset.seek(doc) != doc - } -} - impl DocSet for Exclude where TDocSet: DocSet, TDocSetExclude: DocSet, { fn advance(&mut self) -> DocId { - while self.underlying_docset.advance() != TERMINATED { - if self.accept() { - return self.doc(); + loop { + let candidate = self.underlying_docset.advance(); + if candidate == TERMINATED { + return TERMINATED; + } + if !is_within(&mut self.excluding_docset, candidate) { + return candidate; } } - TERMINATED } fn seek(&mut self, target: DocId) -> DocId { - let underlying_seek_result = self.underlying_docset.seek(target); - if underlying_seek_result == TERMINATED { + let candidate = self.underlying_docset.seek(target); + if candidate == TERMINATED { return TERMINATED; } - if self.accept() { - return underlying_seek_result; + if !is_within(&mut self.excluding_docset, candidate) { + return candidate; } self.advance() } diff --git a/src/query/intersection.rs b/src/query/intersection.rs index e8a510f07..fa57b85f4 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -53,7 +53,8 @@ pub struct Intersection> } fn go_to_first_doc(docsets: &mut [TDocSet]) -> DocId { - let mut candidate = 0; + assert!(!docsets.is_empty()); + let mut candidate = docsets.iter().map(TDocSet::doc).max().unwrap(); 'outer: loop { for docset in docsets.iter_mut() { let seek_doc = docset.seek(candidate); @@ -119,6 +120,9 @@ impl DocSet for Intersection DocSet for Intersection= target); + doc } fn doc(&self) -> DocId { diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 95eeefd70..fe3b2895a 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -12,10 +12,11 @@ pub mod tests { use super::*; use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, TEST_COLLECTOR_WITH_SCORE}; use crate::core::Index; + use crate::query::Weight; use crate::schema::{Schema, Term, TEXT}; use crate::tests::assert_nearly_equals; - use crate::DocAddress; use crate::DocId; + use crate::{DocAddress, TERMINATED}; pub fn create_index(texts: &[&'static str]) -> Index { let mut schema_builder = Schema::builder(); @@ -67,6 +68,23 @@ pub mod tests { assert!(test_query(vec!["g", "a"]).is_empty()); } + #[test] + pub fn test_phrase_query_simple() -> crate::Result<()> { + let index = create_index(&["a b b d c g c", "a b a b c"]); + let text_field = index.schema().get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let terms: Vec = vec!["a", "b", "c"] + .iter() + .map(|text| Term::from_field_text(text_field, text)) + .collect(); + let phrase_query = PhraseQuery::new(terms); + let phrase_weight = phrase_query.phrase_weight(&searcher, false)?; + let mut phrase_scorer = phrase_weight.scorer(searcher.segment_reader(0), 1.0f32)?; + assert_eq!(phrase_scorer.doc(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + Ok(()) + } + #[test] pub fn test_phrase_query_no_score() { let index = create_index(&[ diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 99712c5b6..a68116603 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -239,6 +239,7 @@ impl DocSet for PhraseScorer { } fn seek(&mut self, target: DocId) -> DocId { + debug_assert!(target >= self.doc()); let doc = self.intersection_docset.seek(target); if doc == TERMINATED || self.phrase_match() { return doc; @@ -266,7 +267,6 @@ impl Scorer for PhraseScorer { #[cfg(test)] mod tests { - use super::{intersection, intersection_count}; fn test_intersection_sym(left: &[u32], right: &[u32], expected: &[u32]) { diff --git a/src/query/range_query.rs b/src/query/range_query.rs index 4869fba9a..bcb9376c4 100644 --- a/src/query/range_query.rs +++ b/src/query/range_query.rs @@ -301,12 +301,14 @@ impl Weight for RangeWeight { let mut block_segment_postings = inverted_index .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic); loop { + let docs = block_segment_postings.docs(); + if docs.is_empty() { + break; + } for &doc in block_segment_postings.docs() { doc_bitset.insert(doc); } - if !block_segment_postings.advance() { - break; - } + block_segment_postings.advance(); } } let doc_bitset = BitSetDocSet::from(doc_bitset); diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index 0ea904cf1..69b011215 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -11,11 +11,12 @@ mod tests { use crate::collector::TopDocs; use crate::docset::DocSet; + use crate::postings::compression::COMPRESSION_BLOCK_SIZE; use crate::query::{Query, QueryParser, Scorer, TermQuery}; use crate::schema::{Field, IndexRecordOption, Schema, STRING, TEXT}; use crate::tests::assert_nearly_equals; - use crate::Index; use crate::Term; + use crate::{Index, TERMINATED}; #[test] pub fn test_term_query_no_freq() { @@ -42,6 +43,41 @@ mod tests { assert_eq!(term_scorer.score(), 0.28768212); } + #[test] + pub fn test_term_query_multiple_of_block_len() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", STRING); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + // writing the segment + let mut index_writer = index.writer_with_num_threads(1, 3_000_000)?; + for _ in 0..COMPRESSION_BLOCK_SIZE { + let doc = doc!(text_field => "a"); + index_writer.add_document(doc); + } + index_writer.commit()?; + } + let searcher = index.reader()?.searcher(); + let term_query = TermQuery::new( + Term::from_field_text(text_field, "a"), + IndexRecordOption::Basic, + ); + let term_weight = term_query.weight(&searcher, true)?; + let segment_reader = searcher.segment_reader(0); + let mut term_scorer = term_weight.scorer(segment_reader, 1.0f32)?; + for i in 0u32..COMPRESSION_BLOCK_SIZE as u32 { + assert_eq!(term_scorer.doc(), i); + if i == COMPRESSION_BLOCK_SIZE as u32 - 1u32 { + assert_eq!(term_scorer.advance(), TERMINATED); + } else { + assert_eq!(term_scorer.advance(), i + 1); + } + } + assert_eq!(term_scorer.doc(), TERMINATED); + Ok(()) + } + #[test] pub fn test_term_weight() { let mut schema_builder = Schema::builder(); @@ -112,6 +148,27 @@ mod tests { assert_eq!(term_query.count(&*reader.searcher()).unwrap(), 1); } + #[test] + fn test_term_query_simple_seek() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let text_field = schema_builder.add_text_field("text", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); + index_writer.add_document(doc!(text_field=>"a")); + index_writer.add_document(doc!(text_field=>"a")); + index_writer.commit()?; + let term_a = Term::from_field_text(text_field, "a"); + let term_query = TermQuery::new(term_a, IndexRecordOption::Basic); + let searcher = index.reader()?.searcher(); + let term_weight = term_query.weight(&searcher, false)?; + let mut term_scorer = term_weight.scorer(searcher.segment_reader(0u32), 1.0f32)?; + assert_eq!(term_scorer.doc(), 0u32); + term_scorer.seek(1u32); + assert_eq!(term_scorer.doc(), 1u32); + Ok(()) + } + #[test] fn test_term_query_debug() { let term_query = TermQuery::new(