Using the regular segment collector in blockwand. Removing the condition function from blockwand

This commit is contained in:
Paul Masurel
2020-05-05 00:19:44 +09:00
parent 61fc1e353a
commit c3ccb8aa81
4 changed files with 59 additions and 118 deletions

View File

@@ -100,6 +100,9 @@ mod top_collector;
mod top_score_collector;
pub use self::top_score_collector::TopDocs;
#[cfg(test)]
pub(crate) use self::top_score_collector::TopScoreSegmentCollector;
mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};

View File

@@ -128,7 +128,7 @@ pub(crate) struct TopSegmentCollector<T> {
}
impl<T: PartialOrd> TopSegmentCollector<T> {
fn new(segment_id: SegmentLocalId, limit: usize) -> TopSegmentCollector<T> {
pub fn new(segment_id: SegmentLocalId, limit: usize) -> TopSegmentCollector<T> {
TopSegmentCollector {
limit,
heap: BinaryHeap::with_capacity(limit),

View File

@@ -429,6 +429,12 @@ impl Collector for TopDocs {
/// Segment Collector associated to `TopDocs`.
pub struct TopScoreSegmentCollector(TopSegmentCollector<Score>);
impl TopScoreSegmentCollector {
pub fn new(segment_id: SegmentLocalId, limit: usize) -> Self {
TopScoreSegmentCollector(TopSegmentCollector::new(segment_id, limit))
}
}
impl SegmentCollector for TopScoreSegmentCollector {
type Fruit = Vec<(Score, DocAddress)>;

View File

@@ -3,6 +3,7 @@ use crate::query::score_combiner::ScoreCombiner;
use crate::query::{BlockMaxScorer, Scorer};
use crate::DocId;
use crate::Score;
use crate::query::scorer::ScorerWithPruning;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct Pivot {
@@ -17,19 +18,18 @@ struct Pivot {
/// docsets need to be advanced, and are required to be sorted by the doc they point to.
///
/// The pivot is then defined as the lowest DocId that has a chance of matching our condition.
fn find_pivot_position<'a, TScorer, F>(
fn find_pivot_position<'a, TScorer>(
mut docsets: impl Iterator<Item = &'a TScorer>,
condition: &F,
lower_bound_score: Score,
) -> Option<Pivot>
where
F: Fn(&Score) -> bool,
TScorer: BlockMaxScorer + Scorer,
{
let mut position = 0;
let mut upper_bound = Score::default();
while let Some(docset) = docsets.next() {
upper_bound += docset.max_score();
if condition(&upper_bound) {
if lower_bound_score < upper_bound {
let pivot_doc = docset.doc();
let first_occurrence = position;
while let Some(docset) = docsets.next() {
@@ -101,25 +101,22 @@ fn sift_down<T, TScorer>(docsets: &mut [T])
/// applying [BlockMaxWand] dynamic pruning.
///
/// [BlockMaxWand]: https://dl.acm.org/doi/10.1145/2009916.2010048
pub struct BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner> {
pub struct BlockMaxWand<TScorer, TScoreCombiner> {
docsets: Vec<Box<TScorer>>,
doc: DocId,
score: Score,
combiner: TScoreCombiner,
threshold_fn: ThresholdFn,
}
impl<TScorer, ThresholdFn, TScoreCombiner> BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
impl<TScorer, TScoreCombiner> BlockMaxWand<TScorer, TScoreCombiner>
where
TScoreCombiner: ScoreCombiner,
TScorer: BlockMaxScorer + Scorer,
ThresholdFn: Fn(&Score) -> bool + 'static,
{
fn new(
docsets: Vec<TScorer>,
combiner: TScoreCombiner,
threshold_fn: ThresholdFn,
) -> BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner> {
) -> BlockMaxWand<TScorer, TScoreCombiner> {
let mut non_empty_docsets: Vec<_> = docsets
.into_iter()
.flat_map(|mut docset| {
@@ -134,26 +131,24 @@ impl<TScorer, ThresholdFn, TScoreCombiner> BlockMaxWand<TScorer, ThresholdFn, TS
BlockMaxWand {
docsets: non_empty_docsets,
combiner,
threshold_fn,
doc: 0u32,
score: 0f32
}
}
/// Find the position in the sorted list of posting lists of the **pivot**.
fn find_pivot_position(&self) -> Option<Pivot> {
fn find_pivot_position(&self, lower_bound_score: Score) -> Option<Pivot> {
find_pivot_position(
self.docsets.iter().map(|docset| docset.as_ref()),
&self.threshold_fn,
)
lower_bound_score)
}
fn advance_with_pivot(&mut self, pivot: Pivot) -> SkipResult {
fn advance_with_pivot(&mut self, pivot: Pivot, lower_bound_score: Score) -> SkipResult {
let block_upper_bound: Score = self.docsets[..=pivot.position]
.iter_mut()
.map(|docset| docset.block_max_score())
.sum();
if (self.threshold_fn)(&block_upper_bound) {
if block_upper_bound > lower_bound_score {
if pivot.doc == self.docsets[0].doc() {
// Since self.docsets is sorted by their current doc, in this branch, all
// docsets in [0..=pivot] are positioned on pivot.doc.
@@ -206,22 +201,14 @@ impl<TScorer, ThresholdFn, TScoreCombiner> BlockMaxWand<TScorer, ThresholdFn, TS
}
}
impl<TScorer, ThresholdFn, TScoreCombiner> DocSet
for BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
impl<TScorer, TScoreCombiner> DocSet
for BlockMaxWand<TScorer, TScoreCombiner>
where
TScorer: BlockMaxScorer + Scorer,
TScoreCombiner: ScoreCombiner,
ThresholdFn: Fn(&Score) -> bool + 'static,
{
fn advance(&mut self) -> bool {
while let Some(pivot) = self.find_pivot_position() {
match self.advance_with_pivot(pivot) {
SkipResult::End => { return false },
SkipResult::Reached=> { return true; }
SkipResult::OverStep => {}
}
}
false
unimplemented!();
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
@@ -252,16 +239,38 @@ for BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
}
}
impl<TScorer, ThresholdFn, TScoreCombiner> Scorer
for BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
impl<TScorer, TScoreCombiner> Scorer
for BlockMaxWand<TScorer, TScoreCombiner>
where
TScoreCombiner: ScoreCombiner,
TScorer: Scorer + BlockMaxScorer,
ThresholdFn: Fn(&Score) -> bool + 'static,
{
fn score(&mut self) -> Score {
self.score
}
/// Returns `Some(&mut self)` if pruning is supported by the current scorer.
/// None, if pruning is score is not supported.
fn get_pruning_scorer(&mut self) -> Option<&mut dyn ScorerWithPruning> {
Some(self)
}
}
impl<TScorer, TScoreCombiner> ScorerWithPruning
for BlockMaxWand<TScorer, TScoreCombiner>
where
TScoreCombiner: ScoreCombiner,
TScorer: Scorer + BlockMaxScorer {
fn advance_with_pruning(&mut self, lower_bound_score: f32) -> bool {
while let Some(pivot) = self.find_pivot_position(lower_bound_score) {
match self.advance_with_pivot(pivot, lower_bound_score) {
SkipResult::End => { return false },
SkipResult::Reached=> { return true; }
SkipResult::OverStep => {}
}
}
false
}
}
#[cfg(test)]
@@ -277,8 +286,8 @@ mod tests {
use float_cmp::approx_eq;
use proptest::strategy::Strategy;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::num::Wrapping;
use crate::collector::{SegmentCollector, TopScoreSegmentCollector};
#[derive(Debug, Clone)]
pub struct VecDocSet {
@@ -403,94 +412,17 @@ mod tests {
impl<T: PartialOrd, D: PartialOrd> Eq for ComparableDoc<T, D> {}
#[derive(Debug)]
struct TopSegmentCollector<T> {
limit: usize,
heap: BinaryHeap<ComparableDoc<T, DocId>>,
}
impl<T: PartialOrd> TopSegmentCollector<T> {
fn new(limit: usize) -> TopSegmentCollector<T> {
TopSegmentCollector {
limit,
heap: BinaryHeap::with_capacity(limit),
}
}
}
impl<T: PartialOrd + Clone + Copy + Default> TopSegmentCollector<T> {
pub fn harvest(self) -> Vec<(T, DocId)> {
self.heap
.into_sorted_vec()
.into_iter()
.map(|comparable_doc| (comparable_doc.feature, comparable_doc.doc))
.collect()
}
/// Return true iff at least K documents have gone through
/// the collector.
#[inline(always)]
pub(crate) fn at_capacity(&self) -> bool {
self.heap.len() >= self.limit
}
#[inline(always)]
pub(crate) fn above_threshold(&self, elem: &T) -> bool {
if self.at_capacity() {
elem > &self.heap.peek().unwrap().feature
} else {
true
}
}
/// Collects a document scored by the given feature
///
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
/// will compare the lowest scoring item with the given one and keep whichever is greater.
#[inline(always)]
pub fn collect(&mut self, doc: DocId, feature: T) {
if self.at_capacity() {
// It's ok to unwrap as long as a limit of 0 is forbidden.
if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) {
if limit_feature < feature {
if let Some(mut head) = self.heap.peek_mut() {
head.feature = feature;
head.doc = doc;
}
}
}
} else {
// we have not reached capacity yet, so we can just push the
// element.
self.heap.push(ComparableDoc { feature, doc });
}
}
}
fn union_vs_bmw(posting_lists: Vec<VecDocSet>) {
let mut union = Union::<VecDocSet, SumCombiner>::from(posting_lists.clone());
let mut top_union = TopSegmentCollector::<Score>::new(10);
let mut top_union = TopScoreSegmentCollector::new(0, 10);
while union.advance() {
top_union.collect(union.doc(), union.score());
}
let top_bmw = std::rc::Rc::new(std::cell::RefCell::new(TopSegmentCollector::<Score>::new(
10,
)));
let inner = std::rc::Rc::clone(&top_bmw);
let mut bmw = BlockMaxWand::new(posting_lists, SumCombiner::default(), move |score| {
inner.borrow().above_threshold(score)
});
while bmw.advance() {
top_bmw.borrow_mut().collect(bmw.doc(), bmw.score());
}
drop(bmw);
let top_bmw = TopScoreSegmentCollector::new(0, 10 );
let mut bmw = BlockMaxWand::new(posting_lists, SumCombiner::default());
let top_docs_bnw = top_bmw.collect_scorer(&mut bmw, None);
for ((expected_score, expected_doc), (actual_score, actual_doc)) in
top_union.harvest().into_iter().zip(
std::rc::Rc::try_unwrap(top_bmw)
.unwrap()
.into_inner()
.harvest(),
)
top_union.harvest().into_iter().zip( top_docs_bnw )
{
assert!(approx_eq!(
f32,
@@ -595,7 +527,7 @@ mod tests {
VecDocSet::started(vec![(3, 6.0)], 1),
];
assert_eq!(
find_pivot_position(postings.iter(), &|&score| score > 2.0),
find_pivot_position(postings.iter(), 2.0f32),
Some(Pivot {
position: 1,
doc: 1,
@@ -603,7 +535,7 @@ mod tests {
})
);
assert_eq!(
find_pivot_position(postings.iter(), &|&score| score > 5.0),
find_pivot_position(postings.iter(), 5.0f32),
Some(Pivot {
position: 2,
doc: 2,
@@ -611,7 +543,7 @@ mod tests {
})
);
assert_eq!(
find_pivot_position(postings.iter(), &|&score| score > 9.0),
find_pivot_position(postings.iter(), 9.0f32),
Some(Pivot {
position: 4,
doc: 3,
@@ -619,7 +551,7 @@ mod tests {
})
);
assert_eq!(
find_pivot_position(postings.iter(), &|&score| score > 20.0),
find_pivot_position(postings.iter(), 20.0f32),
None
);
}