mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-06-05 01:50:42 +00:00
replace BinaryHeap for TopN (#2186)
* replace BinaryHeap for TopN replace BinaryHeap for TopN with variant that selects the median with QuickSort, which runs in O(n) time. add merge_fruits fast path * call truncate unconditionally, extend test * remove special early exit * add TODO, fmt * truncate top n instead median, return vec * simplify code
This commit is contained in:
@@ -97,7 +97,7 @@ pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
|
||||
mod top_collector;
|
||||
|
||||
mod top_score_collector;
|
||||
pub use self::top_score_collector::TopDocs;
|
||||
pub use self::top_score_collector::{TopDocs, TopNComputer};
|
||||
|
||||
mod custom_score_top_collector;
|
||||
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use super::top_score_collector::TopNComputer;
|
||||
use crate::{DocAddress, DocId, SegmentOrdinal, SegmentReader};
|
||||
|
||||
/// Contains a feature (field, score, etc.) of a document along with the document address.
|
||||
@@ -20,6 +20,14 @@ pub(crate) struct ComparableDoc<T, D> {
|
||||
pub feature: T,
|
||||
pub doc: D,
|
||||
}
|
||||
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ComparableDoc")
|
||||
.field("feature", &self.feature)
|
||||
.field("doc", &self.doc)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
@@ -91,18 +99,13 @@ where T: PartialOrd + Clone
|
||||
if self.limit == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut top_collector = BinaryHeap::new();
|
||||
let mut top_collector = TopNComputer::new(self.limit + self.offset);
|
||||
for child_fruit in children {
|
||||
for (feature, doc) in child_fruit {
|
||||
if top_collector.len() < (self.limit + self.offset) {
|
||||
top_collector.push(ComparableDoc { feature, doc });
|
||||
} else if let Some(mut head) = top_collector.peek_mut() {
|
||||
if head.feature < feature {
|
||||
*head = ComparableDoc { feature, doc };
|
||||
}
|
||||
}
|
||||
top_collector.push(ComparableDoc { feature, doc });
|
||||
}
|
||||
}
|
||||
|
||||
Ok(top_collector
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
@@ -111,7 +114,7 @@ where T: PartialOrd + Clone
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub(crate) fn for_segment<F: PartialOrd>(
|
||||
pub(crate) fn for_segment<F: PartialOrd + Clone>(
|
||||
&self,
|
||||
segment_id: SegmentOrdinal,
|
||||
_: &SegmentReader,
|
||||
@@ -136,20 +139,18 @@ where T: PartialOrd + Clone
|
||||
/// The Top Collector keeps track of the K documents
|
||||
/// sorted by type `T`.
|
||||
///
|
||||
/// The implementation is based on a `BinaryHeap`.
|
||||
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
|
||||
/// The theoretical complexity for collecting the top `K` out of `n` documents
|
||||
/// is `O(n log K)`.
|
||||
/// is `O(n + K)`.
|
||||
pub(crate) struct TopSegmentCollector<T> {
|
||||
limit: usize,
|
||||
heap: BinaryHeap<ComparableDoc<T, DocId>>,
|
||||
topn_computer: TopNComputer<T, DocId>,
|
||||
segment_ord: u32,
|
||||
}
|
||||
|
||||
impl<T: PartialOrd> TopSegmentCollector<T> {
|
||||
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
|
||||
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
|
||||
TopSegmentCollector {
|
||||
limit,
|
||||
heap: BinaryHeap::with_capacity(limit),
|
||||
topn_computer: TopNComputer::new(limit),
|
||||
segment_ord,
|
||||
}
|
||||
}
|
||||
@@ -158,7 +159,7 @@ impl<T: PartialOrd> TopSegmentCollector<T> {
|
||||
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
|
||||
pub fn harvest(self) -> Vec<(T, DocAddress)> {
|
||||
let segment_ord = self.segment_ord;
|
||||
self.heap
|
||||
self.topn_computer
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.map(|comparable_doc| {
|
||||
@@ -173,33 +174,13 @@ impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Return true if more documents have been collected than the limit.
|
||||
#[inline]
|
||||
pub(crate) fn at_capacity(&self) -> bool {
|
||||
self.heap.len() >= self.limit
|
||||
}
|
||||
|
||||
/// Collects a document scored by the given feature
|
||||
///
|
||||
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
|
||||
/// will compare the lowest scoring item with the given one and keep whichever is greater.
|
||||
#[inline]
|
||||
pub fn collect(&mut self, doc: DocId, feature: T) {
|
||||
if self.at_capacity() {
|
||||
// It's ok to unwrap as long as a limit of 0 is forbidden.
|
||||
if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) {
|
||||
if limit_feature < feature {
|
||||
if let Some(mut head) = self.heap.peek_mut() {
|
||||
head.feature = feature;
|
||||
head.doc = doc;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// we have not reached capacity yet, so we can just push the
|
||||
// element.
|
||||
self.heap.push(ComparableDoc { feature, doc });
|
||||
}
|
||||
self.topn_computer.push(ComparableDoc { feature, doc });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::collections::BinaryHeap;
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
@@ -86,12 +85,15 @@ where
|
||||
/// The `TopDocs` collector keeps track of the top `K` documents
|
||||
/// sorted by their score.
|
||||
///
|
||||
/// The implementation is based on a `BinaryHeap`.
|
||||
/// The theoretical complexity for collecting the top `K` out of `n` documents
|
||||
/// is `O(n log K)`.
|
||||
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
|
||||
/// with pattern defeating QuickSort.
|
||||
/// The theoretical complexity for collecting the top `K` out of `N` documents
|
||||
/// is `O(N + K)`.
|
||||
///
|
||||
/// This collector guarantees a stable sorting in case of a tie on the
|
||||
/// document score. As such, it is suitable to implement pagination.
|
||||
/// This collector does not guarantee a stable sorting in case of a tie on the
|
||||
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
|
||||
/// like docid in case of score equality.
|
||||
/// Only then, it is suitable for pagination.
|
||||
///
|
||||
/// ```rust
|
||||
/// use tantivy::collector::TopDocs;
|
||||
@@ -661,50 +663,35 @@ impl Collector for TopDocs {
|
||||
reader: &SegmentReader,
|
||||
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
|
||||
let heap_len = self.0.limit + self.0.offset;
|
||||
let mut heap: BinaryHeap<ComparableDoc<Score, DocId>> = BinaryHeap::with_capacity(heap_len);
|
||||
let mut top_n = TopNComputer::new(heap_len);
|
||||
|
||||
if let Some(alive_bitset) = reader.alive_bitset() {
|
||||
let mut threshold = Score::MIN;
|
||||
weight.for_each_pruning(threshold, reader, &mut |doc, score| {
|
||||
top_n.threshold = Some(threshold);
|
||||
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
|
||||
if alive_bitset.is_deleted(doc) {
|
||||
return threshold;
|
||||
}
|
||||
let heap_item = ComparableDoc {
|
||||
let doc = ComparableDoc {
|
||||
feature: score,
|
||||
doc,
|
||||
};
|
||||
if heap.len() < heap_len {
|
||||
heap.push(heap_item);
|
||||
if heap.len() == heap_len {
|
||||
threshold = heap.peek().map(|el| el.feature).unwrap_or(Score::MIN);
|
||||
}
|
||||
return threshold;
|
||||
}
|
||||
*heap.peek_mut().unwrap() = heap_item;
|
||||
threshold = heap.peek().map(|el| el.feature).unwrap_or(Score::MIN);
|
||||
top_n.push(doc);
|
||||
threshold = top_n.threshold.unwrap_or(Score::MIN);
|
||||
threshold
|
||||
})?;
|
||||
} else {
|
||||
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
|
||||
let heap_item = ComparableDoc {
|
||||
let doc = ComparableDoc {
|
||||
feature: score,
|
||||
doc,
|
||||
};
|
||||
if heap.len() < heap_len {
|
||||
heap.push(heap_item);
|
||||
// TODO the threshold is suboptimal for heap.len == heap_len
|
||||
if heap.len() == heap_len {
|
||||
return heap.peek().map(|el| el.feature).unwrap_or(Score::MIN);
|
||||
} else {
|
||||
return Score::MIN;
|
||||
}
|
||||
}
|
||||
*heap.peek_mut().unwrap() = heap_item;
|
||||
heap.peek().map(|el| el.feature).unwrap_or(Score::MIN)
|
||||
top_n.push(doc);
|
||||
top_n.threshold.unwrap_or(Score::MIN)
|
||||
})?;
|
||||
}
|
||||
|
||||
let fruit = heap
|
||||
let fruit = top_n
|
||||
.into_sorted_vec()
|
||||
.into_iter()
|
||||
.map(|cid| {
|
||||
@@ -736,9 +723,81 @@ impl SegmentCollector for TopScoreSegmentCollector {
|
||||
}
|
||||
}
|
||||
|
||||
/// Fast TopN Computation
|
||||
///
|
||||
/// For TopN == 0, it will be relative expensive.
|
||||
pub struct TopNComputer<Score, DocId> {
|
||||
buffer: Vec<ComparableDoc<Score, DocId>>,
|
||||
top_n: usize,
|
||||
pub(crate) threshold: Option<Score>,
|
||||
}
|
||||
|
||||
impl<Score, DocId> TopNComputer<Score, DocId>
|
||||
where
|
||||
Score: PartialOrd + Clone,
|
||||
DocId: Ord + Clone,
|
||||
{
|
||||
/// Create a new `TopNComputer`.
|
||||
/// Internally it will allocate a buffer of size `2 * top_n`.
|
||||
pub fn new(top_n: usize) -> Self {
|
||||
let vec_cap = top_n.max(1) * 2;
|
||||
TopNComputer {
|
||||
buffer: Vec::with_capacity(vec_cap),
|
||||
top_n,
|
||||
threshold: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn push(&mut self, doc: ComparableDoc<Score, DocId>) {
|
||||
if let Some(last_median) = self.threshold.clone() {
|
||||
if doc.feature < last_median {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if self.buffer.len() == self.buffer.capacity() {
|
||||
let median = self.truncate_top_n();
|
||||
self.threshold = Some(median);
|
||||
}
|
||||
|
||||
// This is faster since it avoids the buffer resizing to be inlined from vec.push()
|
||||
// (this is in the hot path)
|
||||
// TODO: Replace with `push_within_capacity` when it's stabilized
|
||||
let uninit = self.buffer.spare_capacity_mut();
|
||||
// This cannot panic, because we truncate_median will at least remove one element, since
|
||||
// the min capacity is 2.
|
||||
uninit[0].write(doc);
|
||||
// This is safe because it would panic in the line above
|
||||
unsafe {
|
||||
self.buffer.set_len(self.buffer.len() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
fn truncate_top_n(&mut self) -> Score {
|
||||
// Use select_nth_unstable to find the top nth score
|
||||
let (_, median_el, _) = self.buffer.select_nth_unstable(self.top_n);
|
||||
|
||||
let median_score = median_el.feature.clone();
|
||||
// Remove all elements below the top_n
|
||||
self.buffer.truncate(self.top_n);
|
||||
|
||||
median_score
|
||||
}
|
||||
|
||||
pub(crate) fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, DocId>> {
|
||||
if self.buffer.len() > self.top_n {
|
||||
self.truncate_top_n();
|
||||
}
|
||||
self.buffer.sort_unstable();
|
||||
self.buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::TopDocs;
|
||||
use super::{TopDocs, TopNComputer};
|
||||
use crate::collector::top_collector::ComparableDoc;
|
||||
use crate::collector::Collector;
|
||||
use crate::query::{AllQuery, Query, QueryParser};
|
||||
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
|
||||
@@ -767,6 +826,78 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_topn_computer() {
|
||||
let mut computer: TopNComputer<u32, u32> = TopNComputer::new(0);
|
||||
|
||||
computer.push(ComparableDoc {
|
||||
feature: 1u32,
|
||||
doc: 1u32,
|
||||
});
|
||||
computer.push(ComparableDoc {
|
||||
feature: 1u32,
|
||||
doc: 2u32,
|
||||
});
|
||||
computer.push(ComparableDoc {
|
||||
feature: 1u32,
|
||||
doc: 3u32,
|
||||
});
|
||||
assert!(computer.into_sorted_vec().is_empty());
|
||||
}
|
||||
#[test]
|
||||
fn test_topn_computer() {
|
||||
let mut computer: TopNComputer<u32, u32> = TopNComputer::new(2);
|
||||
|
||||
computer.push(ComparableDoc {
|
||||
feature: 1u32,
|
||||
doc: 1u32,
|
||||
});
|
||||
computer.push(ComparableDoc {
|
||||
feature: 2u32,
|
||||
doc: 2u32,
|
||||
});
|
||||
computer.push(ComparableDoc {
|
||||
feature: 3u32,
|
||||
doc: 3u32,
|
||||
});
|
||||
computer.push(ComparableDoc {
|
||||
feature: 2u32,
|
||||
doc: 4u32,
|
||||
});
|
||||
computer.push(ComparableDoc {
|
||||
feature: 1u32,
|
||||
doc: 5u32,
|
||||
});
|
||||
assert_eq!(
|
||||
computer.into_sorted_vec(),
|
||||
&[
|
||||
ComparableDoc {
|
||||
feature: 3u32,
|
||||
doc: 3u32,
|
||||
},
|
||||
ComparableDoc {
|
||||
feature: 2u32,
|
||||
doc: 2u32,
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topn_computer_no_panic() {
|
||||
for top_n in 0..10 {
|
||||
let mut computer: TopNComputer<u32, u32> = TopNComputer::new(top_n);
|
||||
|
||||
for _ in 0..1 + top_n * 2 {
|
||||
computer.push(ComparableDoc {
|
||||
feature: 1u32,
|
||||
doc: 1u32,
|
||||
});
|
||||
}
|
||||
let _vals = computer.into_sorted_vec();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_collector_not_at_capacity_without_offset() -> crate::Result<()> {
|
||||
let index = make_index()?;
|
||||
@@ -852,20 +983,25 @@ mod tests {
|
||||
// using AllQuery to get a constant score
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
|
||||
let page_0 = searcher.search(&AllQuery, &TopDocs::with_limit(1)).unwrap();
|
||||
|
||||
let page_1 = searcher.search(&AllQuery, &TopDocs::with_limit(2)).unwrap();
|
||||
|
||||
let page_2 = searcher.search(&AllQuery, &TopDocs::with_limit(3)).unwrap();
|
||||
|
||||
// precondition for the test to be meaningful: we did get documents
|
||||
// with the same score
|
||||
assert!(page_0.iter().all(|result| result.0 == page_1[0].0));
|
||||
assert!(page_1.iter().all(|result| result.0 == page_1[0].0));
|
||||
assert!(page_2.iter().all(|result| result.0 == page_2[0].0));
|
||||
|
||||
// sanity check since we're relying on make_index()
|
||||
assert_eq!(page_0.len(), 1);
|
||||
assert_eq!(page_1.len(), 2);
|
||||
assert_eq!(page_2.len(), 3);
|
||||
|
||||
assert_eq!(page_1, &page_2[..page_1.len()]);
|
||||
assert_eq!(page_0, &page_2[..page_0.len()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user