diff --git a/src/postings/block_segment_postings.rs b/src/postings/block_segment_postings.rs index 17145e809..9c55e52c3 100644 --- a/src/postings/block_segment_postings.rs +++ b/src/postings/block_segment_postings.rs @@ -20,10 +20,10 @@ pub struct BlockSegmentPostings { freq_decoder: BlockDecoder, freq_reading_option: FreqReadingOption, - doc_freq: usize, + doc_freq: u32, data: ReadOnlySource, - skip_reader: SkipReader, + pub(crate) skip_reader: SkipReader, } fn decode_bitpacked_block( @@ -89,7 +89,6 @@ impl BlockSegmentPostings { None => SkipReader::new(ReadOnlySource::empty(), doc_freq, record_option), }; - let doc_freq = doc_freq as usize; let mut block_segment_postings = BlockSegmentPostings { doc_decoder: BlockDecoder::with_val(TERMINATED), loaded_offset: std::usize::MAX, @@ -123,14 +122,14 @@ impl BlockSegmentPostings { } else { self.skip_reader.reset(ReadOnlySource::empty(), doc_freq); } - self.doc_freq = doc_freq as usize; + self.doc_freq = doc_freq; } /// Returns the document frequency associated to this block postings. /// /// This `doc_freq` is simply the sum of the length of all of the blocks /// length, and it does not take in account deleted documents. - pub fn doc_freq(&self) -> usize { + pub fn doc_freq(&self) -> u32 { self.doc_freq } diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index f55e1cf1e..7c2a266cf 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -38,8 +38,8 @@ impl SegmentPostings { } } - pub fn doc_freq(&self) -> usize { - self.block_cursor.doc_freq + pub fn doc_freq(&self) -> u32 { + self.block_cursor.doc_freq() } /// Creates a segment postings object with the given documents @@ -143,7 +143,7 @@ impl DocSet for SegmentPostings { impl HasLen for SegmentPostings { fn len(&self) -> usize { - self.block_cursor.doc_freq() + self.block_cursor.doc_freq() as usize } } diff --git a/src/postings/skip.rs b/src/postings/skip.rs index d5a67de70..6613618b7 100644 --- a/src/postings/skip.rs +++ b/src/postings/skip.rs @@ -2,7 +2,7 @@ use crate::common::BinarySerializable; use crate::directory::ReadOnlySource; use crate::postings::compression::{compressed_block_size, COMPRESSION_BLOCK_SIZE}; use crate::schema::IndexRecordOption; -use crate::{DocId, TERMINATED}; +use crate::{DocId, Score, TERMINATED}; use owned_read::OwnedRead; pub struct SkipSerializer { @@ -102,8 +102,11 @@ impl SkipReader { self.remaining_docs = doc_freq; } - #[inline(always)] - pub(crate) fn last_doc_in_block(&self) -> DocId { + pub fn block_max_score(&self) -> Score { + unimplemented!(); + } + + pub fn last_doc_in_block(&self) -> DocId { self.last_doc_in_block } @@ -159,7 +162,7 @@ impl SkipReader { /// If the target is larger than all documents, the skip_reader /// then advance to the last Variable In block. pub fn seek(&mut self, target: DocId) { - while self.last_doc_in_block < target { + while self.last_doc_in_block() < target { self.advance(); } } diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand.rs index 3878092ee..3b5fb74bc 100644 --- a/src/query/boolean_query/block_wand.rs +++ b/src/query/boolean_query/block_wand.rs @@ -1,73 +1,106 @@ -use crate::{Score, DocId, TERMINATED, DocSet}; use crate::query::term_query::TermScorer; -use crate::postings::BlockSegmentPostings; -use futures::AsyncSeekExt; +use crate::query::Scorer; +use crate::{DocId, DocSet, Score, TERMINATED}; - -struct BlockWAND { - term_scorers : Vec, -} - - -fn find_pivot_doc(term_scorers: &[TermScorer], threshold: f32) -> DocId { +/// Returns the lowest document that has a chance of exceeding the +/// threshold score. +/// +/// term_scorers are assumed sorted by .doc(). +fn find_pivot_doc(term_scorers: &[TermScorer], threshold: f32) -> Option { let mut max_score = 0.0f32; - for term_scorer in term_scorers.iter() { + for term_scorer in term_scorers { max_score += term_scorer.max_score(); if max_score > threshold { - return term_scorer.doc(); + return Some(term_scorer.doc()); } } - TERMINATED + None } fn shallow_advance(scorers: &mut Vec, pivot: DocId) -> Score { - let mut max_block_score = 0.0f32; - let mut i = 0; - while i < scorers.len() { - if scorers[i].doc() > pivot { + let mut block_max_score_upperbound = 0.0f32; + for scorer in scorers { + if scorer.doc() > pivot { break; } - while scorers[i].postings.block_cursor.skip_reader.doc() < pivot { - if scorers[i].postings.block_cursor.skip_reader.advance() { - max_block_score += scorers[i].postings.block_cursor.skip_reader.block_max_score(); - i += 1; - } else { - scorers.swap_remove(i); - } - } + scorer.postings.block_cursor.seek(pivot); + block_max_score_upperbound += scorer.postings.block_cursor.skip_reader.block_max_score(); } - max_block_score + block_max_score_upperbound } -pub fn block_wand(mut scorers: Vec, mut threshold: f32, callback: &mut dyn FnMut(u32, Score) -> Score) { +fn compute_score(scorers: &mut Vec, doc: DocId) -> Score { + let mut i = 0; + let mut score = 0.0f32; + while i < scorers.len() { + if scorers[i].doc() > doc { + break; + } + if scorers[i].seek(doc) == TERMINATED { + scorers.swap_remove(i); + } else { + score += scorers[i].score(); + i += 1; + } + } + score +} + +fn advance_all_scorers(scorers: &mut Vec) { + let mut i = 0; + while i < scorers.len() { + if scorers[i].advance() == TERMINATED { + scorers.swap_remove(i); + } else { + i += 1; + } + } +} + +pub fn block_wand( + mut scorers: Vec, + mut threshold: f32, + callback: &mut dyn FnMut(u32, Score) -> Score, +) { loop { scorers.sort_by_key(|scorer| scorer.doc()); - let pivot_doc = find_pivot_doc(&scorers, threshold); - if pivot_doc == TERMINATED { - return; - } - if shallow_advance(&mut scorers, pivot_doc) > threshold { - if scorers[0].doc() == pivot_doc { - // EvaluatePartial(d , p); - // Move all pointers from lists[0] to lists[p] by calling - // Next(list, d + 1) - } else { - let scorer_id = scorers.iter_mut() - .take_while(|term_scorer| term_scorer.doc() < pivot_doc) + let pivot_opt = find_pivot_doc(&scorers, threshold); + if let Some(pivot_doc) = pivot_opt { + let block_max_score_upperbound = shallow_advance(&mut scorers, pivot_doc); + // TODO bug: more than one scorer can point on the pivot. + if block_max_score_upperbound <= threshold { + // TODO choose a better candidate. + if scorers[0].seek(pivot_doc + 1) == TERMINATED { + scorers.swap_remove(0); + } + continue; + } + + if scorers[0].doc() != pivot_doc { + // all scorers are not aligned on pivot_doc. + if let Some(scorer_ord) = scorers + .iter_mut() + .take_while(|scorer| scorer.doc() < pivot_doc) .enumerate() - .min_by_key(|(scorer_id, scorer)| scorer.doc_freq()) - .map(|(scorer_id, scorer)| scorer_id) - .unwrap(); - if scorers[scorer_id].seek(pivot_doc) == TERMINATED { - scorers.swap_remove(scorer_id); + .min_by_key(|(_ord, scorer)| scorer.doc_freq()) + .map(|(ord, _scorer)| ord) + { + // TODOD FIX seek, right now the block will never get loaded. + if scorers[scorer_ord].seek(pivot_doc) == TERMINATED { + scorers.swap_remove(scorer_ord); + } + continue; } } + // TODO no need to fully score? + let score = compute_score(&mut scorers, pivot_doc); + if score > threshold { + threshold = callback(pivot_doc, score); + } + advance_all_scorers(&mut scorers); } else { - //d = GetNewCandidate(); - //Choose one list from the lists before and including lists[p] - //with the largest IDF, move it by calling Next(list, d) + return; } } - -} \ No newline at end of file +} diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index f62cc911c..a46266230 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -7,9 +7,9 @@ use crate::query::Exclude; use crate::query::Occur; use crate::query::RequiredOptionalScorer; use crate::query::Scorer; -use crate::query::{Union, TermUnion}; use crate::query::Weight; use crate::query::{intersect_scorers, Explanation}; +use crate::query::{TermUnion, Union}; use crate::DocId; use std::collections::HashMap; @@ -29,8 +29,7 @@ where .into_iter() .map(|scorer| *(scorer.downcast::().map_err(|_| ()).unwrap())) .collect(); - let scorer: Box = - Box::new(TermUnion::::from(scorers)); + let scorer: Box = Box::new(TermUnion::::from(scorers)); return scorer; } } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 7f3330d99..3cc1ef58b 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -1,9 +1,9 @@ +mod block_wand; mod boolean_query; mod boolean_weight; -mod block_wand; -pub use self::boolean_query::BooleanQuery; pub(crate) use self::block_wand::block_wand; +pub use self::boolean_query::BooleanQuery; #[cfg(test)] mod tests { diff --git a/src/query/mod.rs b/src/query/mod.rs index b5fa68d64..8a40f4604 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -27,7 +27,7 @@ mod vec_docset; pub(crate) mod score_combiner; pub use self::intersection::Intersection; -pub use self::union::{Union, TermUnion}; +pub use self::union::{TermUnion, Union}; #[cfg(test)] pub use self::vec_docset::VecDocSet; diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index e1a3b8eff..bd2391918 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -33,8 +33,8 @@ impl TermScorer { self.postings.term_freq() } - pub fn doc_freq(&self,) -> usize { - self.postings.doc_freq() + pub fn doc_freq(&self) -> usize { + self.postings.doc_freq() as usize } pub fn fieldnorm_id(&self) -> u8 { @@ -47,7 +47,7 @@ impl TermScorer { self.similarity_weight.explain(fieldnorm_id, term_freq) } - pub fn max_score(&self, ) -> f32 { + pub fn max_score(&self) -> f32 { unimplemented!(); } } diff --git a/src/query/union.rs b/src/query/union.rs index e3abc79f7..43fa407a2 100644 --- a/src/query/union.rs +++ b/src/query/union.rs @@ -1,12 +1,12 @@ use crate::common::TinySet; -use crate::query::boolean_query::block_wand; use crate::docset::{DocSet, TERMINATED}; +use crate::fastfield::DeleteBitSet; +use crate::query::boolean_query::block_wand; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; +use crate::query::term_query::TermScorer; use crate::query::Scorer; use crate::DocId; use crate::Score; -use crate::query::term_query::TermScorer; -use crate::fastfield::DeleteBitSet; const HORIZON_NUM_TINYBITSETS: usize = 64; const HORIZON: u32 = 64u32 * HORIZON_NUM_TINYBITSETS as u32; @@ -41,8 +41,6 @@ pub struct Union { score: Score, } - - impl From> for Union where TScoreCombiner: ScoreCombiner, @@ -208,7 +206,11 @@ where } fn size_hint(&self) -> u32 { - self.docsets.iter().map(|docset| docset.size_hint()).max().unwrap_or(0u32) + self.docsets + .iter() + .map(|docset| docset.size_hint()) + .max() + .unwrap_or(0u32) } fn count_including_deleted(&mut self) -> u32 { @@ -234,7 +236,6 @@ where } } - impl Scorer for Union where TScoreCombiner: ScoreCombiner, @@ -246,13 +247,13 @@ where } pub struct TermUnion { - underlying: Union + underlying: Union, } impl From> for TermUnion { fn from(scorers: Vec) -> Self { TermUnion { - underlying: Union::from(scorers) + underlying: Union::from(scorers), } } } @@ -267,7 +268,7 @@ impl DocSet for TermUnion { } fn fill_buffer(&mut self, buffer: &mut [u32]) -> usize { - self.underlying.fill_buffer(buffer) + self.underlying.fill_buffer(buffer) } fn doc(&self) -> u32 { @@ -275,30 +276,33 @@ impl DocSet for TermUnion { } fn size_hint(&self) -> u32 { - self.underlying.size_hint() + self.underlying.size_hint() } fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 { - self.underlying.count(delete_bitset) + self.underlying.count(delete_bitset) } fn count_including_deleted(&mut self) -> u32 { - self.underlying.count_including_deleted() + self.underlying.count_including_deleted() } } impl Scorer for TermUnion { fn score(&mut self) -> f32 { - self.underlying.score() + self.underlying.score() } - fn for_each_pruning(&mut self, mut threshold: f32, callback: &mut dyn FnMut(DocId, Score) -> Score) { + fn for_each_pruning( + &mut self, + threshold: f32, + callback: &mut dyn FnMut(DocId, Score) -> Score, + ) { let term_scorers = std::mem::replace(&mut self.underlying.docsets, vec![]); block_wand(term_scorers, threshold, callback); } } - #[cfg(test)] mod tests {