mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-05-31 07:30:39 +00:00
Using the regular segment collector in blockwand. Removing the condition function from blockwand
This commit is contained in:
@@ -100,6 +100,9 @@ mod top_collector;
|
||||
|
||||
mod top_score_collector;
|
||||
pub use self::top_score_collector::TopDocs;
|
||||
#[cfg(test)]
|
||||
pub(crate) use self::top_score_collector::TopScoreSegmentCollector;
|
||||
|
||||
|
||||
mod custom_score_top_collector;
|
||||
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
|
||||
|
||||
@@ -128,7 +128,7 @@ pub(crate) struct TopSegmentCollector<T> {
|
||||
}
|
||||
|
||||
impl<T: PartialOrd> TopSegmentCollector<T> {
|
||||
fn new(segment_id: SegmentLocalId, limit: usize) -> TopSegmentCollector<T> {
|
||||
pub fn new(segment_id: SegmentLocalId, limit: usize) -> TopSegmentCollector<T> {
|
||||
TopSegmentCollector {
|
||||
limit,
|
||||
heap: BinaryHeap::with_capacity(limit),
|
||||
|
||||
@@ -429,6 +429,12 @@ impl Collector for TopDocs {
|
||||
/// Segment Collector associated to `TopDocs`.
|
||||
pub struct TopScoreSegmentCollector(TopSegmentCollector<Score>);
|
||||
|
||||
impl TopScoreSegmentCollector {
|
||||
pub fn new(segment_id: SegmentLocalId, limit: usize) -> Self {
|
||||
TopScoreSegmentCollector(TopSegmentCollector::new(segment_id, limit))
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentCollector for TopScoreSegmentCollector {
|
||||
type Fruit = Vec<(Score, DocAddress)>;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::query::score_combiner::ScoreCombiner;
|
||||
use crate::query::{BlockMaxScorer, Scorer};
|
||||
use crate::DocId;
|
||||
use crate::Score;
|
||||
use crate::query::scorer::ScorerWithPruning;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
struct Pivot {
|
||||
@@ -17,19 +18,18 @@ struct Pivot {
|
||||
/// docsets need to be advanced, and are required to be sorted by the doc they point to.
|
||||
///
|
||||
/// The pivot is then defined as the lowest DocId that has a chance of matching our condition.
|
||||
fn find_pivot_position<'a, TScorer, F>(
|
||||
fn find_pivot_position<'a, TScorer>(
|
||||
mut docsets: impl Iterator<Item = &'a TScorer>,
|
||||
condition: &F,
|
||||
lower_bound_score: Score,
|
||||
) -> Option<Pivot>
|
||||
where
|
||||
F: Fn(&Score) -> bool,
|
||||
TScorer: BlockMaxScorer + Scorer,
|
||||
{
|
||||
let mut position = 0;
|
||||
let mut upper_bound = Score::default();
|
||||
while let Some(docset) = docsets.next() {
|
||||
upper_bound += docset.max_score();
|
||||
if condition(&upper_bound) {
|
||||
if lower_bound_score < upper_bound {
|
||||
let pivot_doc = docset.doc();
|
||||
let first_occurrence = position;
|
||||
while let Some(docset) = docsets.next() {
|
||||
@@ -101,25 +101,22 @@ fn sift_down<T, TScorer>(docsets: &mut [T])
|
||||
/// applying [BlockMaxWand] dynamic pruning.
|
||||
///
|
||||
/// [BlockMaxWand]: https://dl.acm.org/doi/10.1145/2009916.2010048
|
||||
pub struct BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner> {
|
||||
pub struct BlockMaxWand<TScorer, TScoreCombiner> {
|
||||
docsets: Vec<Box<TScorer>>,
|
||||
doc: DocId,
|
||||
score: Score,
|
||||
combiner: TScoreCombiner,
|
||||
threshold_fn: ThresholdFn,
|
||||
}
|
||||
|
||||
impl<TScorer, ThresholdFn, TScoreCombiner> BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
|
||||
impl<TScorer, TScoreCombiner> BlockMaxWand<TScorer, TScoreCombiner>
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
TScorer: BlockMaxScorer + Scorer,
|
||||
ThresholdFn: Fn(&Score) -> bool + 'static,
|
||||
{
|
||||
fn new(
|
||||
docsets: Vec<TScorer>,
|
||||
combiner: TScoreCombiner,
|
||||
threshold_fn: ThresholdFn,
|
||||
) -> BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner> {
|
||||
) -> BlockMaxWand<TScorer, TScoreCombiner> {
|
||||
let mut non_empty_docsets: Vec<_> = docsets
|
||||
.into_iter()
|
||||
.flat_map(|mut docset| {
|
||||
@@ -134,26 +131,24 @@ impl<TScorer, ThresholdFn, TScoreCombiner> BlockMaxWand<TScorer, ThresholdFn, TS
|
||||
BlockMaxWand {
|
||||
docsets: non_empty_docsets,
|
||||
combiner,
|
||||
threshold_fn,
|
||||
doc: 0u32,
|
||||
score: 0f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the position in the sorted list of posting lists of the **pivot**.
|
||||
fn find_pivot_position(&self) -> Option<Pivot> {
|
||||
fn find_pivot_position(&self, lower_bound_score: Score) -> Option<Pivot> {
|
||||
find_pivot_position(
|
||||
self.docsets.iter().map(|docset| docset.as_ref()),
|
||||
&self.threshold_fn,
|
||||
)
|
||||
lower_bound_score)
|
||||
}
|
||||
|
||||
fn advance_with_pivot(&mut self, pivot: Pivot) -> SkipResult {
|
||||
fn advance_with_pivot(&mut self, pivot: Pivot, lower_bound_score: Score) -> SkipResult {
|
||||
let block_upper_bound: Score = self.docsets[..=pivot.position]
|
||||
.iter_mut()
|
||||
.map(|docset| docset.block_max_score())
|
||||
.sum();
|
||||
if (self.threshold_fn)(&block_upper_bound) {
|
||||
if block_upper_bound > lower_bound_score {
|
||||
if pivot.doc == self.docsets[0].doc() {
|
||||
// Since self.docsets is sorted by their current doc, in this branch, all
|
||||
// docsets in [0..=pivot] are positioned on pivot.doc.
|
||||
@@ -206,22 +201,14 @@ impl<TScorer, ThresholdFn, TScoreCombiner> BlockMaxWand<TScorer, ThresholdFn, TS
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer, ThresholdFn, TScoreCombiner> DocSet
|
||||
for BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
|
||||
impl<TScorer, TScoreCombiner> DocSet
|
||||
for BlockMaxWand<TScorer, TScoreCombiner>
|
||||
where
|
||||
TScorer: BlockMaxScorer + Scorer,
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
ThresholdFn: Fn(&Score) -> bool + 'static,
|
||||
{
|
||||
fn advance(&mut self) -> bool {
|
||||
while let Some(pivot) = self.find_pivot_position() {
|
||||
match self.advance_with_pivot(pivot) {
|
||||
SkipResult::End => { return false },
|
||||
SkipResult::Reached=> { return true; }
|
||||
SkipResult::OverStep => {}
|
||||
}
|
||||
}
|
||||
false
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
fn skip_next(&mut self, target: DocId) -> SkipResult {
|
||||
@@ -252,16 +239,38 @@ for BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer, ThresholdFn, TScoreCombiner> Scorer
|
||||
for BlockMaxWand<TScorer, ThresholdFn, TScoreCombiner>
|
||||
impl<TScorer, TScoreCombiner> Scorer
|
||||
for BlockMaxWand<TScorer, TScoreCombiner>
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
TScorer: Scorer + BlockMaxScorer,
|
||||
ThresholdFn: Fn(&Score) -> bool + 'static,
|
||||
{
|
||||
fn score(&mut self) -> Score {
|
||||
self.score
|
||||
}
|
||||
|
||||
/// Returns `Some(&mut self)` if pruning is supported by the current scorer.
|
||||
/// Returns `None` if pruning is not supported.
|
||||
fn get_pruning_scorer(&mut self) -> Option<&mut dyn ScorerWithPruning> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer, TScoreCombiner> ScorerWithPruning
|
||||
for BlockMaxWand<TScorer, TScoreCombiner>
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
TScorer: Scorer + BlockMaxScorer {
|
||||
fn advance_with_pruning(&mut self, lower_bound_score: f32) -> bool {
|
||||
while let Some(pivot) = self.find_pivot_position(lower_bound_score) {
|
||||
match self.advance_with_pivot(pivot, lower_bound_score) {
|
||||
SkipResult::End => { return false },
|
||||
SkipResult::Reached=> { return true; }
|
||||
SkipResult::OverStep => {}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -277,8 +286,8 @@ mod tests {
|
||||
use float_cmp::approx_eq;
|
||||
use proptest::strategy::Strategy;
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::num::Wrapping;
|
||||
use crate::collector::{SegmentCollector, TopScoreSegmentCollector};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VecDocSet {
|
||||
@@ -403,94 +412,17 @@ mod tests {
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd> Eq for ComparableDoc<T, D> {}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TopSegmentCollector<T> {
|
||||
limit: usize,
|
||||
heap: BinaryHeap<ComparableDoc<T, DocId>>,
|
||||
}
|
||||
|
||||
impl<T: PartialOrd> TopSegmentCollector<T> {
|
||||
fn new(limit: usize) -> TopSegmentCollector<T> {
|
||||
TopSegmentCollector {
|
||||
limit,
|
||||
heap: BinaryHeap::with_capacity(limit),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd + Clone + Copy + Default> TopSegmentCollector<T> {
|
||||
pub fn harvest(self) -> Vec<(T, DocId)> {
|
||||
self.heap
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.map(|comparable_doc| (comparable_doc.feature, comparable_doc.doc))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Return true iff at least K documents have gone through
|
||||
/// the collector.
|
||||
#[inline(always)]
|
||||
pub(crate) fn at_capacity(&self) -> bool {
|
||||
self.heap.len() >= self.limit
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn above_threshold(&self, elem: &T) -> bool {
|
||||
if self.at_capacity() {
|
||||
elem > &self.heap.peek().unwrap().feature
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Collects a document scored by the given feature
|
||||
///
|
||||
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
|
||||
/// will compare the lowest scoring item with the given one and keep whichever is greater.
|
||||
#[inline(always)]
|
||||
pub fn collect(&mut self, doc: DocId, feature: T) {
|
||||
if self.at_capacity() {
|
||||
// It's ok to unwrap as long as a limit of 0 is forbidden.
|
||||
if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) {
|
||||
if limit_feature < feature {
|
||||
if let Some(mut head) = self.heap.peek_mut() {
|
||||
head.feature = feature;
|
||||
head.doc = doc;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// we have not reached capacity yet, so we can just push the
|
||||
// element.
|
||||
self.heap.push(ComparableDoc { feature, doc });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn union_vs_bmw(posting_lists: Vec<VecDocSet>) {
|
||||
let mut union = Union::<VecDocSet, SumCombiner>::from(posting_lists.clone());
|
||||
let mut top_union = TopSegmentCollector::<Score>::new(10);
|
||||
let mut top_union = TopScoreSegmentCollector::new(0, 10);
|
||||
while union.advance() {
|
||||
top_union.collect(union.doc(), union.score());
|
||||
}
|
||||
let top_bmw = std::rc::Rc::new(std::cell::RefCell::new(TopSegmentCollector::<Score>::new(
|
||||
10,
|
||||
)));
|
||||
let inner = std::rc::Rc::clone(&top_bmw);
|
||||
let mut bmw = BlockMaxWand::new(posting_lists, SumCombiner::default(), move |score| {
|
||||
inner.borrow().above_threshold(score)
|
||||
});
|
||||
while bmw.advance() {
|
||||
top_bmw.borrow_mut().collect(bmw.doc(), bmw.score());
|
||||
}
|
||||
drop(bmw);
|
||||
let top_bmw = TopScoreSegmentCollector::new(0, 10 );
|
||||
let mut bmw = BlockMaxWand::new(posting_lists, SumCombiner::default());
|
||||
let top_docs_bnw = top_bmw.collect_scorer(&mut bmw, None);
|
||||
for ((expected_score, expected_doc), (actual_score, actual_doc)) in
|
||||
top_union.harvest().into_iter().zip(
|
||||
std::rc::Rc::try_unwrap(top_bmw)
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
.harvest(),
|
||||
)
|
||||
top_union.harvest().into_iter().zip( top_docs_bnw )
|
||||
{
|
||||
assert!(approx_eq!(
|
||||
f32,
|
||||
@@ -595,7 +527,7 @@ mod tests {
|
||||
VecDocSet::started(vec![(3, 6.0)], 1),
|
||||
];
|
||||
assert_eq!(
|
||||
find_pivot_position(postings.iter(), &|&score| score > 2.0),
|
||||
find_pivot_position(postings.iter(), 2.0f32),
|
||||
Some(Pivot {
|
||||
position: 1,
|
||||
doc: 1,
|
||||
@@ -603,7 +535,7 @@ mod tests {
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
find_pivot_position(postings.iter(), &|&score| score > 5.0),
|
||||
find_pivot_position(postings.iter(), 5.0f32),
|
||||
Some(Pivot {
|
||||
position: 2,
|
||||
doc: 2,
|
||||
@@ -611,7 +543,7 @@ mod tests {
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
find_pivot_position(postings.iter(), &|&score| score > 9.0),
|
||||
find_pivot_position(postings.iter(), 9.0f32),
|
||||
Some(Pivot {
|
||||
position: 4,
|
||||
doc: 3,
|
||||
@@ -619,7 +551,7 @@ mod tests {
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
find_pivot_position(postings.iter(), &|&score| score > 20.0),
|
||||
find_pivot_position(postings.iter(), 20.0f32),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user