replace BinaryHeap for TopN (#2186)

* replace BinaryHeap for TopN

replace BinaryHeap for TopN with variant that selects the median with QuickSort,
which runs in O(n) time.

add merge_fruits fast path

* call truncate unconditionally, extend test

* remove special early exit

* add TODO, fmt

* truncate top n instead median, return vec

* simplify code
This commit is contained in:
PSeitz
2023-09-27 09:25:30 +02:00
committed by GitHub
parent 90586bc1e2
commit b525f653c0
3 changed files with 189 additions and 72 deletions

View File

@@ -97,7 +97,7 @@ pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_collector;
mod top_score_collector;
pub use self::top_score_collector::TopDocs;
pub use self::top_score_collector::{TopDocs, TopNComputer};
mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};

View File

@@ -1,7 +1,7 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::marker::PhantomData;
use super::top_score_collector::TopNComputer;
use crate::{DocAddress, DocId, SegmentOrdinal, SegmentReader};
/// Contains a feature (field, score, etc.) of a document along with the document address.
@@ -20,6 +20,14 @@ pub(crate) struct ComparableDoc<T, D> {
pub feature: T,
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ComparableDoc")
.field("feature", &self.feature)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
@@ -91,18 +99,13 @@ where T: PartialOrd + Clone
if self.limit == 0 {
return Ok(Vec::new());
}
let mut top_collector = BinaryHeap::new();
let mut top_collector = TopNComputer::new(self.limit + self.offset);
for child_fruit in children {
for (feature, doc) in child_fruit {
if top_collector.len() < (self.limit + self.offset) {
top_collector.push(ComparableDoc { feature, doc });
} else if let Some(mut head) = top_collector.peek_mut() {
if head.feature < feature {
*head = ComparableDoc { feature, doc };
}
}
top_collector.push(ComparableDoc { feature, doc });
}
}
Ok(top_collector
.into_sorted_vec()
.into_iter()
@@ -111,7 +114,7 @@ where T: PartialOrd + Clone
.collect())
}
pub(crate) fn for_segment<F: PartialOrd>(
pub(crate) fn for_segment<F: PartialOrd + Clone>(
&self,
segment_id: SegmentOrdinal,
_: &SegmentReader,
@@ -136,20 +139,18 @@ where T: PartialOrd + Clone
/// The Top Collector keeps track of the K documents
/// sorted by type `T`.
///
/// The implementation is based on a `BinaryHeap`.
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
/// The theoretical complexity for collecting the top `K` out of `n` documents
/// is `O(n log K)`.
/// is `O(n + K)`.
pub(crate) struct TopSegmentCollector<T> {
limit: usize,
heap: BinaryHeap<ComparableDoc<T, DocId>>,
topn_computer: TopNComputer<T, DocId>,
segment_ord: u32,
}
impl<T: PartialOrd> TopSegmentCollector<T> {
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
TopSegmentCollector {
limit,
heap: BinaryHeap::with_capacity(limit),
topn_computer: TopNComputer::new(limit),
segment_ord,
}
}
@@ -158,7 +159,7 @@ impl<T: PartialOrd> TopSegmentCollector<T> {
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
pub fn harvest(self) -> Vec<(T, DocAddress)> {
let segment_ord = self.segment_ord;
self.heap
self.topn_computer
.into_sorted_vec()
.into_iter()
.map(|comparable_doc| {
@@ -173,33 +174,13 @@ impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
.collect()
}
/// Return true if more documents have been collected than the limit.
#[inline]
pub(crate) fn at_capacity(&self) -> bool {
self.heap.len() >= self.limit
}
/// Collects a document scored by the given feature
///
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
/// will compare the lowest scoring item with the given one and keep whichever is greater.
#[inline]
pub fn collect(&mut self, doc: DocId, feature: T) {
if self.at_capacity() {
// It's ok to unwrap as long as a limit of 0 is forbidden.
if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) {
if limit_feature < feature {
if let Some(mut head) = self.heap.peek_mut() {
head.feature = feature;
head.doc = doc;
}
}
}
} else {
// we have not reached capacity yet, so we can just push the
// element.
self.heap.push(ComparableDoc { feature, doc });
}
self.topn_computer.push(ComparableDoc { feature, doc });
}
}

View File

@@ -1,4 +1,3 @@
use std::collections::BinaryHeap;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
@@ -86,12 +85,15 @@ where
/// The `TopDocs` collector keeps track of the top `K` documents
/// sorted by their score.
///
/// The implementation is based on a `BinaryHeap`.
/// The theoretical complexity for collecting the top `K` out of `n` documents
/// is `O(n log K)`.
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
/// with pattern defeating QuickSort.
/// The theoretical complexity for collecting the top `K` out of `N` documents
/// is `O(N + K)`.
///
/// This collector guarantees a stable sorting in case of a tie on the
/// document score. As such, it is suitable to implement pagination.
/// This collector does not guarantee a stable sorting in case of a tie on the
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
/// like docid in case of score equality.
/// Only then, it is suitable for pagination.
///
/// ```rust
/// use tantivy::collector::TopDocs;
@@ -661,50 +663,35 @@ impl Collector for TopDocs {
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let heap_len = self.0.limit + self.0.offset;
let mut heap: BinaryHeap<ComparableDoc<Score, DocId>> = BinaryHeap::with_capacity(heap_len);
let mut top_n = TopNComputer::new(heap_len);
if let Some(alive_bitset) = reader.alive_bitset() {
let mut threshold = Score::MIN;
weight.for_each_pruning(threshold, reader, &mut |doc, score| {
top_n.threshold = Some(threshold);
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
if alive_bitset.is_deleted(doc) {
return threshold;
}
let heap_item = ComparableDoc {
let doc = ComparableDoc {
feature: score,
doc,
};
if heap.len() < heap_len {
heap.push(heap_item);
if heap.len() == heap_len {
threshold = heap.peek().map(|el| el.feature).unwrap_or(Score::MIN);
}
return threshold;
}
*heap.peek_mut().unwrap() = heap_item;
threshold = heap.peek().map(|el| el.feature).unwrap_or(Score::MIN);
top_n.push(doc);
threshold = top_n.threshold.unwrap_or(Score::MIN);
threshold
})?;
} else {
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
let heap_item = ComparableDoc {
let doc = ComparableDoc {
feature: score,
doc,
};
if heap.len() < heap_len {
heap.push(heap_item);
// TODO the threshold is suboptimal for heap.len == heap_len
if heap.len() == heap_len {
return heap.peek().map(|el| el.feature).unwrap_or(Score::MIN);
} else {
return Score::MIN;
}
}
*heap.peek_mut().unwrap() = heap_item;
heap.peek().map(|el| el.feature).unwrap_or(Score::MIN)
top_n.push(doc);
top_n.threshold.unwrap_or(Score::MIN)
})?;
}
let fruit = heap
let fruit = top_n
.into_sorted_vec()
.into_iter()
.map(|cid| {
@@ -736,9 +723,81 @@ impl SegmentCollector for TopScoreSegmentCollector {
}
}
/// Fast TopN Computation
///
/// For TopN == 0, it will be relative expensive.
pub struct TopNComputer<Score, DocId> {
buffer: Vec<ComparableDoc<Score, DocId>>,
top_n: usize,
pub(crate) threshold: Option<Score>,
}
impl<Score, DocId> TopNComputer<Score, DocId>
where
Score: PartialOrd + Clone,
DocId: Ord + Clone,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
pub fn new(top_n: usize) -> Self {
let vec_cap = top_n.max(1) * 2;
TopNComputer {
buffer: Vec::with_capacity(vec_cap),
top_n,
threshold: None,
}
}
#[inline]
pub(crate) fn push(&mut self, doc: ComparableDoc<Score, DocId>) {
if let Some(last_median) = self.threshold.clone() {
if doc.feature < last_median {
return;
}
}
if self.buffer.len() == self.buffer.capacity() {
let median = self.truncate_top_n();
self.threshold = Some(median);
}
// This is faster since it avoids the buffer resizing to be inlined from vec.push()
// (this is in the hot path)
// TODO: Replace with `push_within_capacity` when it's stabilized
let uninit = self.buffer.spare_capacity_mut();
// This cannot panic, because we truncate_median will at least remove one element, since
// the min capacity is 2.
uninit[0].write(doc);
// This is safe because it would panic in the line above
unsafe {
self.buffer.set_len(self.buffer.len() + 1);
}
}
#[inline(never)]
fn truncate_top_n(&mut self) -> Score {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self.buffer.select_nth_unstable(self.top_n);
let median_score = median_el.feature.clone();
// Remove all elements below the top_n
self.buffer.truncate(self.top_n);
median_score
}
pub(crate) fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, DocId>> {
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
self.buffer.sort_unstable();
self.buffer
}
}
#[cfg(test)]
mod tests {
use super::TopDocs;
use super::{TopDocs, TopNComputer};
use crate::collector::top_collector::ComparableDoc;
use crate::collector::Collector;
use crate::query::{AllQuery, Query, QueryParser};
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
@@ -767,6 +826,78 @@ mod tests {
}
}
#[test]
fn test_empty_topn_computer() {
let mut computer: TopNComputer<u32, u32> = TopNComputer::new(0);
computer.push(ComparableDoc {
feature: 1u32,
doc: 1u32,
});
computer.push(ComparableDoc {
feature: 1u32,
doc: 2u32,
});
computer.push(ComparableDoc {
feature: 1u32,
doc: 3u32,
});
assert!(computer.into_sorted_vec().is_empty());
}
#[test]
fn test_topn_computer() {
let mut computer: TopNComputer<u32, u32> = TopNComputer::new(2);
computer.push(ComparableDoc {
feature: 1u32,
doc: 1u32,
});
computer.push(ComparableDoc {
feature: 2u32,
doc: 2u32,
});
computer.push(ComparableDoc {
feature: 3u32,
doc: 3u32,
});
computer.push(ComparableDoc {
feature: 2u32,
doc: 4u32,
});
computer.push(ComparableDoc {
feature: 1u32,
doc: 5u32,
});
assert_eq!(
computer.into_sorted_vec(),
&[
ComparableDoc {
feature: 3u32,
doc: 3u32,
},
ComparableDoc {
feature: 2u32,
doc: 2u32,
}
]
);
}
#[test]
fn test_topn_computer_no_panic() {
for top_n in 0..10 {
let mut computer: TopNComputer<u32, u32> = TopNComputer::new(top_n);
for _ in 0..1 + top_n * 2 {
computer.push(ComparableDoc {
feature: 1u32,
doc: 1u32,
});
}
let _vals = computer.into_sorted_vec();
}
}
#[test]
fn test_top_collector_not_at_capacity_without_offset() -> crate::Result<()> {
let index = make_index()?;
@@ -852,20 +983,25 @@ mod tests {
// using AllQuery to get a constant score
let searcher = index.reader().unwrap().searcher();
let page_0 = searcher.search(&AllQuery, &TopDocs::with_limit(1)).unwrap();
let page_1 = searcher.search(&AllQuery, &TopDocs::with_limit(2)).unwrap();
let page_2 = searcher.search(&AllQuery, &TopDocs::with_limit(3)).unwrap();
// precondition for the test to be meaningful: we did get documents
// with the same score
assert!(page_0.iter().all(|result| result.0 == page_1[0].0));
assert!(page_1.iter().all(|result| result.0 == page_1[0].0));
assert!(page_2.iter().all(|result| result.0 == page_2[0].0));
// sanity check since we're relying on make_index()
assert_eq!(page_0.len(), 1);
assert_eq!(page_1.len(), 2);
assert_eq!(page_2.len(), 3);
assert_eq!(page_1, &page_2[..page_1.len()]);
assert_eq!(page_0, &page_2[..page_0.len()]);
}
#[test]