diff --git a/Cargo.toml b/Cargo.toml index d7864be30..72bdb6d34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ murmurhash32 = "0.2" chrono = "0.4" smallvec = "1.0" rayon = "1" +ordered-float = "1" [target.'cfg(windows)'.dependencies] winapi = "0.3" @@ -58,6 +59,8 @@ winapi = "0.3" rand = "0.7" maplit = "1" matches = "0.1.8" +proptest = "0.9" +float-cmp = "0.6" [dev-dependencies.fail] version = "0.3" diff --git a/src/postings/block_max_postings.rs b/src/postings/block_max_postings.rs new file mode 100644 index 000000000..9ed93284c --- /dev/null +++ b/src/postings/block_max_postings.rs @@ -0,0 +1,16 @@ +use crate::postings::Postings; +use crate::DocId; + +/// Inverted list with additional information about the maximum term frequency +/// within a block, as well as globally within the list. +pub trait BlockMaxPostings: Postings { + /// Returns the maximum frequency in the entire list. + fn max_term_freq(&self) -> u32; + /// Returns the maximum frequency in the current block. + fn block_max_term_freq(&mut self) -> u32; + /// Returns the document with the largest frequency. + fn max_doc(&self) -> DocId; + /// Returns the document with the largest frequency within the current + /// block. + fn block_max_doc(&self) -> DocId; +} diff --git a/src/postings/block_max_segment_postings.rs b/src/postings/block_max_segment_postings.rs new file mode 100644 index 000000000..58c8a8733 --- /dev/null +++ b/src/postings/block_max_segment_postings.rs @@ -0,0 +1,76 @@ +use crate::postings::{BlockMaxPostings, Postings, SegmentPostings}; +use crate::{DocId, DocSet, SkipResult}; + +/// A wrapper over [`SegmentPostings`](./struct.SegmentPostings.html) +/// with max block frequencies. +pub struct BlockMaxSegmentPostings { + postings: SegmentPostings, + max_blocks: SegmentPostings, + doc_with_max_term_freq: DocId, + max_term_freq: u32, +} + +impl BlockMaxSegmentPostings { + /// Constructs a new segment postings with block-max information. 
+ pub fn new( + postings: SegmentPostings, + max_blocks: SegmentPostings, + doc_with_max_term_freq: DocId, + max_term_freq: u32, + ) -> Self { + Self { + postings, + max_blocks, + doc_with_max_term_freq, + max_term_freq, + } + } +} + +impl DocSet for BlockMaxSegmentPostings { + fn advance(&mut self) -> bool { + self.postings.advance() + } + + fn doc(&self) -> DocId { + self.postings.doc() + } + + fn size_hint(&self) -> u32 { + self.postings.size_hint() + } + + fn skip_next(&mut self, target: DocId) -> SkipResult { + self.postings.skip_next(target) + } +} + +impl Postings for BlockMaxSegmentPostings { + fn term_freq(&self) -> u32 { + self.postings.term_freq() + } + fn positions_with_offset(&mut self, offset: u32, output: &mut Vec) { + self.postings.positions_with_offset(offset, output); + } + fn positions(&mut self, output: &mut Vec) { + self.postings.positions(output); + } +} + +impl BlockMaxPostings for BlockMaxSegmentPostings { + fn max_term_freq(&self) -> u32 { + self.max_term_freq + } + fn block_max_term_freq(&mut self) -> u32 { + if matches!(self.max_blocks.skip_next(self.doc()), SkipResult::End) { + panic!("Max blocks corrupted: reached end of max block"); + } + self.max_blocks.term_freq() + } + fn max_doc(&self) -> DocId { + self.doc_with_max_term_freq + } + fn block_max_doc(&self) -> DocId { + self.max_blocks.doc() + } +} diff --git a/src/postings/mod.rs b/src/postings/mod.rs index b66beb413..aa25ae358 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -2,6 +2,8 @@ Postings module (also called inverted index) */ +mod block_max_postings; +mod block_max_segment_postings; mod block_search; pub(crate) mod compression; /// Postings module @@ -29,6 +31,9 @@ pub use self::term_info::TermInfo; pub use self::segment_postings::{BlockSegmentPostings, SegmentPostings}; +pub use self::block_max_postings::BlockMaxPostings; +pub use self::block_max_segment_postings::BlockMaxSegmentPostings; + pub(crate) use self::stacker::compute_table_size; pub use 
crate::common::HasLen; diff --git a/src/query/block_max_scorer.rs b/src/query/block_max_scorer.rs new file mode 100644 index 000000000..3ece3adbf --- /dev/null +++ b/src/query/block_max_scorer.rs @@ -0,0 +1,44 @@ +use crate::docset::DocSet; +use crate::DocId; +use crate::Score; +use downcast_rs::impl_downcast; +use std::ops::Deref; +use std::ops::DerefMut; + +/// A set of documents matching a query within a specific segment +/// and having a maximum score within certain blocks. +/// +/// See [`Query`](./trait.Query.html) and [`Scorer`](./trait.Scorer.html). +pub trait BlockMaxScorer: downcast_rs::Downcast + DocSet + 'static { + /// Returns the maximum score within the current block. + /// + /// The blocks are defined when indexing. For example, blocks can be + /// have a specific number postings each, or can be optimized for + /// retrieval speed. Read more in + /// [Faster BlockMax WAND with Variable-sized Blocks][vbmw] + /// + /// This method will perform a bit of computation and is not cached. + /// + /// [vbmw]: https://dl.acm.org/doi/abs/10.1145/3077136.3080780 + fn block_max_score(&mut self) -> Score; + + /// Returns the last document in the current block. + fn block_max_doc(&mut self) -> DocId; + + /// Returns the maximum possible score within the entire document set. 
+ fn max_score(&self) -> Score; +} + +impl_downcast!(BlockMaxScorer); + +impl BlockMaxScorer for Box { + fn block_max_score(&mut self) -> Score { + self.deref_mut().block_max_score() + } + fn max_score(&self) -> Score { + self.deref().max_score() + } + fn block_max_doc(&mut self) -> DocId { + self.deref_mut().block_max_doc() + } +} diff --git a/src/query/block_max_wand.rs b/src/query/block_max_wand.rs new file mode 100644 index 000000000..9e95d5a62 --- /dev/null +++ b/src/query/block_max_wand.rs @@ -0,0 +1,734 @@ +use crate::docset::{DocSet, SkipResult}; +use crate::query::score_combiner::ScoreCombiner; +use crate::query::{BlockMaxScorer, Scorer}; +use crate::DocId; +use crate::Score; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct Pivot { + position: usize, + first_occurrence: usize, + doc: DocId, +} + +enum AdvanceResult { + Found, + Skipped, + End, +} + +/// Find the position in the sorted list of posting lists of the **pivot**. +fn find_pivot_position<'a, TScorer, F>( + mut docsets: impl Iterator, + condition: &F, +) -> Option +where + F: Fn(&Score) -> bool, + TScorer: BlockMaxScorer + Scorer, +{ + let mut position = 0; + let mut uppder_bound = Score::default(); + while let Some(docset) = docsets.next() { + uppder_bound += docset.max_score(); + if condition(&uppder_bound) { + let pivot_doc = docset.doc(); + let first_occurrence = position; + while let Some(docset) = docsets.next() { + if docset.doc() != pivot_doc { + break; + } else { + position += 1; + } + } + return Some(Pivot { + position, + doc: pivot_doc, + first_occurrence, + }); + } + position += 1; + } + None +} + +/// Given an iterator over all ordered lists up to the pivot (inclusive) and the following list (if +/// exists), it returns the next document ID that can be possibly relevant, based on the block max +/// scores. 
fn find_next_relevant_doc<'a, T, TScorer>(
    docsets_up_to_pivot: &'a mut [T],
    pivot_docset: &'a mut T,
    docset_after_pivot: Option<&'a mut T>,
) -> DocId
where
    T: AsMut<TScorer>,
    TScorer: BlockMaxScorer + Scorer,
{
    let pivot_docset = pivot_docset.as_mut();
    // Smallest `block_max_doc()` among the pivot and every list before it.
    let pivot_block_end = pivot_docset.block_max_doc();
    let earliest_block_end = docsets_up_to_pivot
        .iter_mut()
        .map(|docset| docset.as_mut().block_max_doc())
        .fold(pivot_block_end, |earliest, block_end| earliest.min(block_end));
    let mut next_doc = earliest_block_end + 1;
    // The list right after the pivot may start before that boundary.
    if let Some(docset) = docset_after_pivot {
        next_doc = next_doc.min(docset.as_mut().doc());
    }
    // Always move strictly past the pivot's current document.
    if next_doc <= pivot_docset.doc() {
        pivot_docset.doc() + 1
    } else {
        next_doc
    }
}

/// Sifts down the first element of the slice.
///
/// Swaps neighbouring entries, starting at the front, for as long as an
/// entry's current document is smaller than the one of its predecessor, and
/// stops at the first pair that is already ordered.
fn sift_down<T, TScorer>(docsets: &mut [T])
where
    T: AsRef<TScorer>,
    TScorer: BlockMaxScorer + Scorer,
{
    let mut idx = 1;
    while idx < docsets.len() && docsets[idx].as_ref().doc() < docsets[idx - 1].as_ref().doc() {
        docsets.swap(idx, idx - 1);
        idx += 1;
    }
}

/// Creates a `DocSet` that iterates through the union of two or more `DocSet`s,
/// applying [BlockMaxWand] dynamic pruning.
///
/// `threshold_fn` is called with score upper bounds; documents are only
/// scored when it accepts the bound.
///
/// NOTE(review): the generic parameter list was lost in extraction; the order
/// `<TScorer, TScoreCombiner, ThresholdFn>` is reconstructed — confirm.
///
/// [BlockMaxWand]: https://dl.acm.org/doi/10.1145/2009916.2010048
pub struct BlockMaxWand<TScorer, TScoreCombiner, ThresholdFn> {
    /// Non-exhausted posting lists.
    docsets: Vec<Box<TScorer>>,
    /// Last document returned by `advance`.
    doc: DocId,
    /// Combined score of `doc`.
    score: Score,
    combiner: TScoreCombiner,
    threshold_fn: ThresholdFn,
}

impl<TScorer, TScoreCombiner, ThresholdFn> BlockMaxWand<TScorer, TScoreCombiner, ThresholdFn>
where
    TScoreCombiner: ScoreCombiner,
    TScorer: BlockMaxScorer + Scorer,
    ThresholdFn: Fn(&Score) -> bool + 'static,
{
    /// Advances every input once, drops the ones that are empty, and orders
    /// the rest by their first document.
    fn new(
        docsets: Vec<TScorer>,
        combiner: TScoreCombiner,
        threshold_fn: ThresholdFn,
    ) -> BlockMaxWand<TScorer, TScoreCombiner, ThresholdFn> {
        let mut non_empty_docsets: Vec<_> = docsets
            .into_iter()
            .filter_map(|mut docset| {
                if docset.advance() {
                    Some(Box::new(docset))
                } else {
                    None
                }
            })
            .collect();
        non_empty_docsets.sort_by_key(|docset| docset.doc());
        BlockMaxWand {
            docsets: non_empty_docsets,
            doc: 0,
            score: 0f32,
            combiner,
            threshold_fn,
        }
    }

    /// Find the position in the sorted list of posting lists of the **pivot**.
    fn find_pivot_position(&self) -> Option<Pivot> {
        find_pivot_position(
            self.docsets.iter().map(|docset| docset.as_ref()),
            &self.threshold_fn,
        )
    }

    /// Performs one step of the pruning loop for the given pivot.
    ///
    /// * If the summed block-max scores of the lists up to the pivot pass the
    ///   threshold and the first list already sits on the pivot document, the
    ///   document is scored and `Found` is returned.
    /// * If they pass but the first list is behind the pivot, the list right
    ///   before the pivot's first occurrence is advanced by one posting.
    /// * Otherwise the first list is skipped to the next document that can
    ///   possibly be relevant.
    fn advance_with_pivot(&mut self, pivot: Pivot) -> AdvanceResult {
        let block_upper_bound: Score = self.docsets[..=pivot.position]
            .iter_mut()
            .map(|docset| docset.block_max_score())
            .sum();
        if (self.threshold_fn)(&block_upper_bound) {
            if pivot.doc == self.docsets[0].doc() {
                // NOTE(elshize): One additional check needs to be done to improve performance:
                // update block-wise bound while accumulating score with the actual score,
                // and check each time if still above threshold.
                self.combiner.clear();
                // Walk backwards so that `swap_remove` only ever pulls in an
                // element from an index that was already handled or that lies
                // past the pivot.
                for idx in (0..=pivot.position).rev() {
                    self.combiner.update(self.docsets[idx].as_mut());
                    if !self.docsets[idx].advance() {
                        self.docsets.swap_remove(idx);
                    }
                }
                self.score = self.combiner.score();
                self.doc = pivot.doc;
                self.docsets.sort_by_key(|docset| docset.doc());
                AdvanceResult::Found
            } else {
                // The subtraction is correct because otherwise we would go to the other branch:
                // the first list is not on the pivot document, so `first_occurrence >= 1`.
                let advanced_idx = pivot.first_occurrence - 1;
                if !self.docsets[advanced_idx].advance() {
                    self.docsets.swap_remove(advanced_idx);
                }
                if self.docsets.is_empty() {
                    AdvanceResult::End
                } else {
                    sift_down(&mut self.docsets[advanced_idx..]);
                    AdvanceResult::Skipped
                }
            }
        } else {
            let (up_to_pivot, pivot_and_rest) = self.docsets.split_at_mut(pivot.position);
            let (pivot, after_pivot) = pivot_and_rest.split_first_mut().unwrap();
            let next_doc = find_next_relevant_doc(up_to_pivot, pivot, after_pivot.first_mut());
            // NOTE(elshize): It might be more efficient to advance the list with the higher
            // max score, but let's advance the first one for now for simplicity.
            if self.docsets[0].skip_next(next_doc) == SkipResult::End {
                self.docsets.swap_remove(0);
            }
            if self.docsets.is_empty() {
                AdvanceResult::End
            } else {
                sift_down(&mut self.docsets[..]);
                AdvanceResult::Skipped
            }
        }
    }
}

impl<TScorer, TScoreCombiner, ThresholdFn> DocSet
    for BlockMaxWand<TScorer, TScoreCombiner, ThresholdFn>
where
    TScorer: BlockMaxScorer + Scorer,
    TScoreCombiner: ScoreCombiner,
    ThresholdFn: Fn(&Score) -> bool + 'static,
{
    /// Repeats pivot selection and `advance_with_pivot` until a document is
    /// found (`true`), or until no pivot exists or every list is exhausted
    /// (`false`).
    fn advance(&mut self) -> bool {
        loop {
            let outcome = match self.find_pivot_position() {
                Some(pivot) => self.advance_with_pivot(pivot),
                None => AdvanceResult::End,
            };
            match outcome {
                AdvanceResult::End => return false,
                AdvanceResult::Found => return true,
                AdvanceResult::Skipped => {}
            }
        }
    }

    /// Calls `advance` until the current document is at least `target`.
    fn skip_next(&mut self, target: DocId) -> SkipResult {
        while self.doc() < target {
            if !self.advance() {
                return SkipResult::End;
            }
        }
        if self.doc() == target {
            SkipResult::Reached
        } else {
            SkipResult::OverStep
        }
    }

    // TODO implement `count` efficiently.
+ + fn doc(&self) -> DocId { + self.doc + } + + fn size_hint(&self) -> u32 { + 0u32 + } + + fn count_including_deleted(&mut self) -> u32 { + unimplemented!(); + } +} + +impl Scorer + for BlockMaxWand +where + TScoreCombiner: ScoreCombiner, + TScorer: Scorer + BlockMaxScorer, + ThresholdFn: Fn(&Score) -> bool + 'static, +{ + fn score(&mut self) -> Score { + self.score + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::common::HasLen; + use crate::docset::DocSet; + use crate::query::score_combiner::SumCombiner; + use crate::query::Union; + use crate::query::{BlockMaxScorer, Scorer}; + use crate::{DocId, Score}; + use float_cmp::approx_eq; + use ordered_float::OrderedFloat; + use proptest::strategy::Strategy; + use std::cmp::Ordering; + use std::collections::BinaryHeap; + use std::num::Wrapping; + + #[derive(Debug, Clone)] + pub struct VecDocSet { + postings: Vec<(DocId, Score)>, + cursor: Wrapping, + block_max_scores: Vec<(DocId, Score)>, + max_score: Score, + block_size: usize, + } + + impl VecDocSet { + fn new(postings: Vec<(DocId, Score)>, block_size: usize) -> VecDocSet { + let block_max_scores: Vec<_> = postings + .chunks(block_size) + .into_iter() + .map(|block| { + ( + block.iter().last().unwrap().0, + block + .iter() + .map(|(_, s)| OrderedFloat(*s)) + .max() + .unwrap() + .into_inner(), + ) + }) + .collect(); + let max_score = block_max_scores + .iter() + .copied() + .map(|(_, s)| OrderedFloat(s)) + .max() + .unwrap(); + VecDocSet { + postings, + cursor: Wrapping(0_usize) - Wrapping(1_usize), + block_max_scores, + max_score: max_score.into_inner(), + block_size, + } + } + } + + impl DocSet for VecDocSet { + fn advance(&mut self) -> bool { + self.cursor += Wrapping(1); + self.postings.len() > self.cursor.0 + } + + fn doc(&self) -> DocId { + self.postings[self.cursor.0].0 + } + + fn size_hint(&self) -> u32 { + self.len() as u32 + } + } + + impl HasLen for VecDocSet { + fn len(&self) -> usize { + self.postings.len() + } + } + + impl 
BlockMaxScorer for VecDocSet { + fn max_score(&self) -> Score { + self.max_score + } + fn block_max_score(&mut self) -> Score { + self.block_max_scores[self.cursor.0 / self.block_size].1 + } + fn block_max_doc(&mut self) -> DocId { + self.block_max_scores[self.cursor.0 / self.block_size].0 + } + } + + impl Scorer for VecDocSet { + fn score(&mut self) -> Score { + self.postings[self.cursor.0].1 + } + } + + #[derive(Debug)] + struct ComparableDoc { + feature: T, + doc: D, + } + + impl PartialOrd for ComparableDoc { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for ComparableDoc { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + // Reversed to make BinaryHeap work as a min-heap + let by_feature = other + .feature + .partial_cmp(&self.feature) + .unwrap_or(Ordering::Equal); + + let lazy_by_doc_address = + || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal); + + // In case of a tie on the feature, we sort by ascending + // `DocAddress` in order to ensure a stable sorting of the + // documents. + by_feature.then_with(lazy_by_doc_address) + } + } + + impl PartialEq for ComparableDoc { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } + } + + impl Eq for ComparableDoc {} + + #[derive(Debug)] + struct TopSegmentCollector { + limit: usize, + heap: BinaryHeap>, + } + + impl TopSegmentCollector { + fn new(limit: usize) -> TopSegmentCollector { + TopSegmentCollector { + limit, + heap: BinaryHeap::with_capacity(limit), + } + } + } + + impl TopSegmentCollector { + pub fn harvest(self) -> Vec<(T, DocId)> { + self.heap + .into_sorted_vec() + .into_iter() + .map(|comparable_doc| (comparable_doc.feature, comparable_doc.doc)) + .collect() + } + + /// Return true iff at least K documents have gone through + /// the collector. 
+ #[inline(always)] + pub(crate) fn at_capacity(&self) -> bool { + self.heap.len() >= self.limit + } + + #[inline(always)] + pub(crate) fn above_threshold(&self, elem: &T) -> bool { + if self.at_capacity() { + elem > &self.heap.peek().unwrap().feature + } else { + true + } + } + + /// Collects a document scored by the given feature + /// + /// It collects documents until it has reached the max capacity. Once it reaches capacity, it + /// will compare the lowest scoring item with the given one and keep whichever is greater. + #[inline(always)] + pub fn collect(&mut self, doc: DocId, feature: T) { + if self.at_capacity() { + // It's ok to unwrap as long as a limit of 0 is forbidden. + if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) { + if limit_feature < feature { + if let Some(mut head) = self.heap.peek_mut() { + head.feature = feature; + head.doc = doc; + } + } + } + } else { + // we have not reached capacity yet, so we can just push the + // element. + self.heap.push(ComparableDoc { feature, doc }); + } + } + } + + fn union_vs_bmw(posting_lists: Vec) { + let mut union = Union::::from(posting_lists.clone()); + let mut top_union = TopSegmentCollector::::new(10); + while union.advance() { + top_union.collect(union.doc(), union.score()); + } + let top_bmw = std::rc::Rc::new(std::cell::RefCell::new(TopSegmentCollector::::new( + 10, + ))); + let inner = std::rc::Rc::clone(&top_bmw); + let mut bmw = BlockMaxWand::new(posting_lists, SumCombiner::default(), move |score| { + let is_above = inner.borrow().above_threshold(score); + println!("{} (above = {})", score, is_above); + is_above + }); + while bmw.advance() { + top_bmw.borrow_mut().collect(bmw.doc(), bmw.score()); + } + drop(bmw); + for ((expected_score, expected_doc), (actual_score, actual_doc)) in + top_union.harvest().into_iter().zip( + std::rc::Rc::try_unwrap(top_bmw) + .unwrap() + .into_inner() + .harvest(), + ) + { + assert!(approx_eq!( + f32, + expected_score, + actual_score, + 
epsilon = 0.0001 + )); + assert_eq!(expected_doc, actual_doc); + } + } + + #[test] + fn test_bmw_0() { + union_vs_bmw(vec![ + VecDocSet { + postings: vec![ + (0, 1.0), + (23, 1.0), + (28, 1.0), + (56, 1.0), + (59, 1.0), + (66, 1.0), + (93, 1.0), + ], + cursor: Wrapping(0_usize) - Wrapping(1_usize), + block_max_scores: vec![(93, 1.0)], + max_score: 1.0, + block_size: 16, + }, + VecDocSet { + postings: vec![ + (2, 1.6549665), + (43, 2.6958032), + (53, 3.5309567), + (71, 2.7688136), + (87, 3.4279852), + (96, 3.9028034), + ], + cursor: Wrapping(0_usize) - Wrapping(1_usize), + block_max_scores: vec![(96, 3.9028034)], + max_score: 3.9028034, + block_size: 16, + }, + ]) + } + + #[test] + fn test_bmw_1() { + union_vs_bmw(vec![ + VecDocSet { + postings: vec![(73, 1.0), (82, 1.0)], + cursor: Wrapping(0_usize) - Wrapping(1_usize), + block_max_scores: vec![(82, 1.0)], + max_score: 1.0, + block_size: 16, + }, + VecDocSet { + postings: vec![ + (21, 3.582513), + (23, 1.6928024), + (27, 3.887647), + (42, 1.5469292), + (61, 1.7317574), + (62, 1.2968783), + (82, 2.4040694), + (85, 3.1487892), + ], + cursor: Wrapping(0_usize) - Wrapping(1_usize), + block_max_scores: vec![(85, 3.887647)], + max_score: 3.887647, + block_size: 16, + }, + ]) + } + + proptest::proptest! 
{ + #[test] + fn test_union_vs_bmw(postings in proptest::collection::vec( + proptest::collection::vec(0_u32..100, 1..10) + .prop_flat_map(|v| { + let scores = proptest::collection::vec(1_f32..4_f32, v.len()..=v.len()); + scores.prop_map(move |s| { + let mut postings: Vec<_> = v.iter().copied().zip(s.iter().copied()).collect(); + postings.sort_by_key(|p| p.0); + postings.dedup_by_key(|p| p.0); + VecDocSet::new(postings, 16) + }) + }), + 2..5) + ) { + union_vs_bmw(postings); + } + } + + #[test] + fn test_find_pivot_position() { + let postings = vec![ + VecDocSet::new(vec![(0, 2.0)], 1), + VecDocSet::new(vec![(1, 3.0)], 1), + VecDocSet::new(vec![(2, 4.0)], 1), + VecDocSet::new(vec![(3, 5.0)], 1), + VecDocSet::new(vec![(3, 6.0)], 1), + ]; + assert_eq!( + find_pivot_position(postings.iter(), &|&score| score > 2.0), + Some(Pivot { + position: 1, + doc: 1, + first_occurrence: 1, + }) + ); + assert_eq!( + find_pivot_position(postings.iter(), &|&score| score > 5.0), + Some(Pivot { + position: 2, + doc: 2, + first_occurrence: 2, + }) + ); + assert_eq!( + find_pivot_position(postings.iter(), &|&score| score > 9.0), + Some(Pivot { + position: 4, + doc: 3, + first_occurrence: 3, + }) + ); + assert_eq!( + find_pivot_position(postings.iter(), &|&score| score > 20.0), + None + ); + } + + #[test] + fn test_find_next_relevant_doc_before_pivot() { + let mut postings = vec![ + Box::new(VecDocSet::new(vec![(0, 0.0), (3, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(1, 0.0), (4, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(2, 0.0), (6, 0.0)], 2)), // pivot + Box::new(VecDocSet::new(vec![(6, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(6, 0.0), (8, 0.0)], 2)), + ]; + let (up_to_pivot, rest) = postings.split_at_mut(2); + let (pivot, after_pivot) = rest.split_first_mut().unwrap(); + let next_doc = find_next_relevant_doc(up_to_pivot, pivot, Some(&mut after_pivot[0])); + assert_eq!(next_doc, 4); + } + + #[test] + fn test_find_next_relevant_doc_prefix_smaller_than_pivot() { + let mut 
postings = vec![ + Box::new(VecDocSet::new(vec![(0, 0.0), (3, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(1, 0.0), (4, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(5, 0.0), (8, 0.0)], 2)), // pivot + Box::new(VecDocSet::new(vec![(6, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(6, 0.0), (8, 0.0)], 2)), + ]; + let (up_to_pivot, rest) = postings.split_at_mut(2); + let (pivot, after_pivot) = rest.split_first_mut().unwrap(); + let next_doc = find_next_relevant_doc(up_to_pivot, pivot, Some(&mut after_pivot[0])); + assert_eq!(next_doc, 6); + } + + #[test] + fn test_find_next_relevant_doc_after_pivot() { + let mut postings = vec![ + Box::new(VecDocSet::new(vec![(0, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(1, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(2, 0.0), (8, 0.0)], 2)), // pivot + Box::new(VecDocSet::new(vec![(5, 0.0), (7, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(6, 0.0), (7, 0.0)], 2)), + ]; + let (up_to_pivot, rest) = postings.split_at_mut(2); + let (pivot, after_pivot) = rest.split_first_mut().unwrap(); + let next_doc = find_next_relevant_doc(up_to_pivot, pivot, Some(&mut after_pivot[0])); + assert_eq!(next_doc, 5); + } + + #[test] + fn test_sift_down_already_sifted() { + let mut postings = vec![ + Box::new(VecDocSet::new(vec![(0, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(1, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(2, 0.0), (8, 0.0)], 2)), // pivot + Box::new(VecDocSet::new(vec![(5, 0.0), (7, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(6, 0.0), (7, 0.0)], 2)), + ]; + sift_down(&mut postings[2..]); + assert_eq!( + postings.into_iter().map(|p| p.doc()).collect::>(), + vec![0, 1, 2, 5, 6] + ); + } + + #[test] + fn test_sift_down_sift_one_down() { + let mut postings = vec![ + Box::new(VecDocSet::new(vec![(0, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(1, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(6, 0.0), (8, 0.0)], 2)), // pivot + Box::new(VecDocSet::new(vec![(5, 0.0), (7, 0.0)], 2)), + 
Box::new(VecDocSet::new(vec![(7, 0.0), (7, 0.0)], 2)), + ]; + sift_down(&mut postings[2..]); + assert_eq!( + postings.into_iter().map(|p| p.doc()).collect::>(), + vec![0, 1, 5, 6, 7] + ); + } + + #[test] + fn test_sift_down_to_bottom() { + let mut postings = vec![ + Box::new(VecDocSet::new(vec![(0, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(1, 0.0), (8, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(7, 0.0), (8, 0.0)], 2)), // pivot + Box::new(VecDocSet::new(vec![(5, 0.0), (7, 0.0)], 2)), + Box::new(VecDocSet::new(vec![(6, 0.0), (7, 0.0)], 2)), + ]; + sift_down(&mut postings[2..]); + assert_eq!( + postings.into_iter().map(|p| p.doc()).collect::>(), + vec![0, 1, 5, 6, 7] + ); + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs index 6f52784a5..5624aeffc 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -5,6 +5,8 @@ Query mod all_query; mod automaton_weight; mod bitset; +mod block_max_scorer; +mod block_max_wand; mod bm25; mod boolean_query; mod boost_query; @@ -24,7 +26,6 @@ mod term_query; mod union; mod weight; - #[cfg(test)] mod vec_docset; @@ -38,13 +39,14 @@ pub use self::vec_docset::VecDocSet; pub use self::all_query::{AllQuery, AllScorer, AllWeight}; pub use self::automaton_weight::AutomatonWeight; pub use self::bitset::BitSetDocSet; +pub use self::block_max_scorer::BlockMaxScorer; pub use self::boolean_query::BooleanQuery; pub use self::boost_query::BoostQuery; pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight}; pub use self::exclude::Exclude; pub use self::explanation::Explanation; -pub use self::fuzzy_query::FuzzyTermQuery; pub(crate) use self::fuzzy_query::DFAWrapper; +pub use self::fuzzy_query::FuzzyTermQuery; pub use self::intersection::intersect_scorers; pub use self::phrase_query::PhraseQuery; pub use self::query::Query; diff --git a/src/query/term_query/block_max_term_scorer.rs b/src/query/term_query/block_max_term_scorer.rs new file mode 100644 index 000000000..33346ad00 --- /dev/null +++ 
b/src/query/term_query/block_max_term_scorer.rs @@ -0,0 +1,98 @@ +use crate::docset::{DocSet, SkipResult}; +use crate::query::{Explanation, Scorer}; +use crate::DocId; +use crate::Score; + +use crate::fieldnorm::FieldNormReader; +use crate::postings::Postings; +use crate::postings::{BlockMaxPostings, BlockMaxSegmentPostings}; +use crate::query::bm25::BM25Weight; +use crate::query::BlockMaxScorer; + +pub struct BlockMaxTermScorer { + postings: BlockMaxSegmentPostings, + fieldnorm_reader: FieldNormReader, + similarity_weight: BM25Weight, +} + +impl BlockMaxTermScorer { + pub fn new( + postings: BlockMaxSegmentPostings, + fieldnorm_reader: FieldNormReader, + similarity_weight: BM25Weight, + ) -> Self { + Self { + postings, + fieldnorm_reader, + similarity_weight, + } + } +} + +impl BlockMaxTermScorer { + fn _score(&self, fieldnorm_id: u8, term_freq: u32) -> Score { + self.similarity_weight.score(fieldnorm_id, term_freq) + } + + pub fn term_freq(&self) -> u32 { + self.postings.term_freq() + } + + pub fn fieldnorm_id(&self) -> u8 { + self.fieldnorm_reader.fieldnorm_id(self.doc()) + } + + pub fn explain(&self) -> Explanation { + let fieldnorm_id = self.fieldnorm_id(); + let term_freq = self.term_freq(); + self.similarity_weight.explain(fieldnorm_id, term_freq) + } +} + +impl DocSet for BlockMaxTermScorer { + fn advance(&mut self) -> bool { + self.postings.advance() + } + + fn skip_next(&mut self, target: DocId) -> SkipResult { + self.postings.skip_next(target) + } + + fn doc(&self) -> DocId { + self.postings.doc() + } + + fn size_hint(&self) -> u32 { + self.postings.size_hint() + } +} + +impl Scorer for BlockMaxTermScorer { + fn score(&mut self) -> Score { + self._score( + self.fieldnorm_reader.fieldnorm_id(self.doc()), + self.postings.term_freq(), + ) + } +} + +impl BlockMaxScorer for BlockMaxTermScorer { + fn block_max_score(&mut self) -> Score { + self._score( + self.fieldnorm_reader + .fieldnorm_id(self.postings.block_max_doc()), + self.postings.term_freq(), + ) + } 
+ + fn block_max_doc(&mut self) -> DocId { + self.postings.block_max_doc() + } + + fn max_score(&self) -> Score { + self._score( + self.fieldnorm_reader.fieldnorm_id(self.postings.max_doc()), + self.postings.max_term_freq(), + ) + } +} diff --git a/src/query/term_query/mod.rs b/src/query/term_query/mod.rs index d38756b7e..8349bd6c4 100644 --- a/src/query/term_query/mod.rs +++ b/src/query/term_query/mod.rs @@ -1,7 +1,9 @@ +mod block_max_term_scorer; mod term_query; mod term_scorer; mod term_weight; +pub use self::block_max_term_scorer::BlockMaxTermScorer; pub use self::term_query::TermQuery; pub use self::term_scorer::TermScorer; pub use self::term_weight::TermWeight;