From b525f653c09726c9f97d35ee15aea1d557bce4f2 Mon Sep 17 00:00:00 2001 From: PSeitz Date: Wed, 27 Sep 2023 09:25:30 +0200 Subject: [PATCH] replace BinaryHeap for TopN (#2186) * replace BinaryHeap for TopN replace BinaryHeap for TopN with variant that selects the median with QuickSort, which runs in O(n) time. add merge_fruits fast path * call truncate unconditionally, extend test * remove special early exit * add TODO, fmt * truncate top n instead median, return vec * simplify code --- src/collector/mod.rs | 2 +- src/collector/top_collector.rs | 59 +++----- src/collector/top_score_collector.rs | 200 ++++++++++++++++++++++----- 3 files changed, 189 insertions(+), 72 deletions(-) diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 6e58ed6c0..4d9b43d65 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -97,7 +97,7 @@ pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit}; mod top_collector; mod top_score_collector; -pub use self::top_score_collector::TopDocs; +pub use self::top_score_collector::{TopDocs, TopNComputer}; mod custom_score_top_collector; pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer}; diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 691ef324b..ddb78c7b1 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,7 +1,7 @@ use std::cmp::Ordering; -use std::collections::BinaryHeap; use std::marker::PhantomData; +use super::top_score_collector::TopNComputer; use crate::{DocAddress, DocId, SegmentOrdinal, SegmentReader}; /// Contains a feature (field, score, etc.) of a document along with the document address. @@ -20,6 +20,14 @@ pub(crate) struct ComparableDoc { pub feature: T, pub doc: D, } +impl std::fmt::Debug for ComparableDoc { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ComparableDoc") + .field("feature", &self.feature) + .field("doc", &self.doc) + .finish() + } +} impl PartialOrd for ComparableDoc { fn partial_cmp(&self, other: &Self) -> Option { @@ -91,18 +99,13 @@ where T: PartialOrd + Clone if self.limit == 0 { return Ok(Vec::new()); } - let mut top_collector = BinaryHeap::new(); + let mut top_collector = TopNComputer::new(self.limit + self.offset); for child_fruit in children { for (feature, doc) in child_fruit { - if top_collector.len() < (self.limit + self.offset) { - top_collector.push(ComparableDoc { feature, doc }); - } else if let Some(mut head) = top_collector.peek_mut() { - if head.feature < feature { - *head = ComparableDoc { feature, doc }; - } - } + top_collector.push(ComparableDoc { feature, doc }); } } + Ok(top_collector .into_sorted_vec() .into_iter() @@ -111,7 +114,7 @@ where T: PartialOrd + Clone .collect()) } - pub(crate) fn for_segment( + pub(crate) fn for_segment( &self, segment_id: SegmentOrdinal, _: &SegmentReader, @@ -136,20 +139,18 @@ where T: PartialOrd + Clone /// The Top Collector keeps track of the K documents /// sorted by type `T`. /// -/// The implementation is based on a `BinaryHeap`. +/// The implementation is based on a repeatedly truncating on the median after K * 2 documents /// The theoretical complexity for collecting the top `K` out of `n` documents -/// is `O(n log K)`. +/// is `O(n + K)`. pub(crate) struct TopSegmentCollector { - limit: usize, - heap: BinaryHeap>, + topn_computer: TopNComputer, segment_ord: u32, } -impl TopSegmentCollector { +impl TopSegmentCollector { fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector { TopSegmentCollector { - limit, - heap: BinaryHeap::with_capacity(limit), + topn_computer: TopNComputer::new(limit), segment_ord, } } @@ -158,7 +159,7 @@ impl TopSegmentCollector { impl TopSegmentCollector { pub fn harvest(self) -> Vec<(T, DocAddress)> { let segment_ord = self.segment_ord; - self.heap + self.topn_computer .into_sorted_vec() .into_iter() .map(|comparable_doc| { @@ -173,33 +174,13 @@ impl TopSegmentCollector { .collect() } - /// Return true if more documents have been collected than the limit. - #[inline] - pub(crate) fn at_capacity(&self) -> bool { - self.heap.len() >= self.limit - } - /// Collects a document scored by the given feature /// /// It collects documents until it has reached the max capacity. Once it reaches capacity, it /// will compare the lowest scoring item with the given one and keep whichever is greater. #[inline] pub fn collect(&mut self, doc: DocId, feature: T) { - if self.at_capacity() { - // It's ok to unwrap as long as a limit of 0 is forbidden. - if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) { - if limit_feature < feature { - if let Some(mut head) = self.heap.peek_mut() { - head.feature = feature; - head.doc = doc; - } - } - } - } else { - // we have not reached capacity yet, so we can just push the - // element. - self.heap.push(ComparableDoc { feature, doc }); - } + self.topn_computer.push(ComparableDoc { feature, doc }); } } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 23152ddeb..484882046 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -1,4 +1,3 @@ -use std::collections::BinaryHeap; use std::fmt; use std::marker::PhantomData; use std::sync::Arc; @@ -86,12 +85,15 @@ where /// The `TopDocs` collector keeps track of the top `K` documents /// sorted by their score. /// -/// The implementation is based on a `BinaryHeap`. -/// The theoretical complexity for collecting the top `K` out of `n` documents -/// is `O(n log K)`. +/// The implementation is based on a repeatedly truncating on the median after K * 2 documents +/// with pattern defeating QuickSort. +/// The theoretical complexity for collecting the top `K` out of `N` documents +/// is `O(N + K)`. /// -/// This collector guarantees a stable sorting in case of a tie on the -/// document score. As such, it is suitable to implement pagination. +/// This collector does not guarantee a stable sorting in case of a tie on the +/// document score, for stable sorting `PartialOrd` needs to resolve on other fields +/// like docid in case of score equality. +/// Only then, it is suitable for pagination. /// /// ```rust /// use tantivy::collector::TopDocs; @@ -661,50 +663,35 @@ impl Collector for TopDocs { reader: &SegmentReader, ) -> crate::Result<::Fruit> { let heap_len = self.0.limit + self.0.offset; - let mut heap: BinaryHeap> = BinaryHeap::with_capacity(heap_len); + let mut top_n = TopNComputer::new(heap_len); if let Some(alive_bitset) = reader.alive_bitset() { let mut threshold = Score::MIN; - weight.for_each_pruning(threshold, reader, &mut |doc, score| { + top_n.threshold = Some(threshold); + weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| { if alive_bitset.is_deleted(doc) { return threshold; } - let heap_item = ComparableDoc { + let doc = ComparableDoc { feature: score, doc, }; - if heap.len() < heap_len { - heap.push(heap_item); - if heap.len() == heap_len { - threshold = heap.peek().map(|el| el.feature).unwrap_or(Score::MIN); - } - return threshold; - } - *heap.peek_mut().unwrap() = heap_item; - threshold = heap.peek().map(|el| el.feature).unwrap_or(Score::MIN); + top_n.push(doc); + threshold = top_n.threshold.unwrap_or(Score::MIN); threshold })?; } else { weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| { - let heap_item = ComparableDoc { + let doc = ComparableDoc { feature: score, doc, }; - if heap.len() < heap_len { - heap.push(heap_item); - // TODO the threshold is suboptimal for heap.len == heap_len - if heap.len() == heap_len { - return heap.peek().map(|el| el.feature).unwrap_or(Score::MIN); - } else { - return Score::MIN; - } - } - *heap.peek_mut().unwrap() = heap_item; - heap.peek().map(|el| el.feature).unwrap_or(Score::MIN) + top_n.push(doc); + top_n.threshold.unwrap_or(Score::MIN) })?; } - let fruit = heap + let fruit = top_n .into_sorted_vec() .into_iter() .map(|cid| { @@ -736,9 +723,81 @@ impl SegmentCollector for TopScoreSegmentCollector { } } +/// Fast TopN Computation +/// +/// For TopN == 0, it will be relative expensive. +pub struct TopNComputer { + buffer: Vec>, + top_n: usize, + pub(crate) threshold: Option, +} + +impl TopNComputer +where + Score: PartialOrd + Clone, + DocId: Ord + Clone, +{ + /// Create a new `TopNComputer`. + /// Internally it will allocate a buffer of size `2 * top_n`. + pub fn new(top_n: usize) -> Self { + let vec_cap = top_n.max(1) * 2; + TopNComputer { + buffer: Vec::with_capacity(vec_cap), + top_n, + threshold: None, + } + } + + #[inline] + pub(crate) fn push(&mut self, doc: ComparableDoc) { + if let Some(last_median) = self.threshold.clone() { + if doc.feature < last_median { + return; + } + } + if self.buffer.len() == self.buffer.capacity() { + let median = self.truncate_top_n(); + self.threshold = Some(median); + } + + // This is faster since it avoids the buffer resizing to be inlined from vec.push() + // (this is in the hot path) + // TODO: Replace with `push_within_capacity` when it's stabilized + let uninit = self.buffer.spare_capacity_mut(); + // This cannot panic, because we truncate_median will at least remove one element, since + // the min capacity is 2. + uninit[0].write(doc); + // This is safe because it would panic in the line above + unsafe { + self.buffer.set_len(self.buffer.len() + 1); + } + } + + #[inline(never)] + fn truncate_top_n(&mut self) -> Score { + // Use select_nth_unstable to find the top nth score + let (_, median_el, _) = self.buffer.select_nth_unstable(self.top_n); + + let median_score = median_el.feature.clone(); + // Remove all elements below the top_n + self.buffer.truncate(self.top_n); + + median_score + } + + pub(crate) fn into_sorted_vec(mut self) -> Vec> { + if self.buffer.len() > self.top_n { + self.truncate_top_n(); + } + self.buffer.sort_unstable(); + self.buffer + } +} + #[cfg(test)] mod tests { - use super::TopDocs; + use super::{TopDocs, TopNComputer}; + use crate::collector::top_collector::ComparableDoc; use crate::collector::Collector; use crate::query::{AllQuery, Query, QueryParser}; use crate::schema::{Field, Schema, FAST, STORED, TEXT}; @@ -767,6 +826,78 @@ mod tests { } } + #[test] + fn test_empty_topn_computer() { + let mut computer: TopNComputer = TopNComputer::new(0); + + computer.push(ComparableDoc { + feature: 1u32, + doc: 1u32, + }); + computer.push(ComparableDoc { + feature: 1u32, + doc: 2u32, + }); + computer.push(ComparableDoc { + feature: 1u32, + doc: 3u32, + }); + assert!(computer.into_sorted_vec().is_empty()); + } + #[test] + fn test_topn_computer() { + let mut computer: TopNComputer = TopNComputer::new(2); + + computer.push(ComparableDoc { + feature: 1u32, + doc: 1u32, + }); + computer.push(ComparableDoc { + feature: 2u32, + doc: 2u32, + }); + computer.push(ComparableDoc { + feature: 3u32, + doc: 3u32, + }); + computer.push(ComparableDoc { + feature: 2u32, + doc: 4u32, + }); + computer.push(ComparableDoc { + feature: 1u32, + doc: 5u32, + }); + assert_eq!( + computer.into_sorted_vec(), + &[ + ComparableDoc { + feature: 3u32, + doc: 3u32, + }, + ComparableDoc { + feature: 2u32, + doc: 2u32, + } + ] + ); + } + + #[test] + fn test_topn_computer_no_panic() { + for top_n in 0..10 { + let mut computer: TopNComputer = TopNComputer::new(top_n); + + for _ in 0..1 + top_n * 2 { + computer.push(ComparableDoc { + feature: 1u32, + doc: 1u32, + }); + } + let _vals = computer.into_sorted_vec(); + } + } + #[test] fn test_top_collector_not_at_capacity_without_offset() -> crate::Result<()> { let index = make_index()?; @@ -852,20 +983,25 @@ mod tests { // using AllQuery to get a constant score let searcher = index.reader().unwrap().searcher(); + let page_0 = searcher.search(&AllQuery, &TopDocs::with_limit(1)).unwrap(); + let page_1 = searcher.search(&AllQuery, &TopDocs::with_limit(2)).unwrap(); let page_2 = searcher.search(&AllQuery, &TopDocs::with_limit(3)).unwrap(); // precondition for the test to be meaningful: we did get documents // with the same score + assert!(page_0.iter().all(|result| result.0 == page_1[0].0)); assert!(page_1.iter().all(|result| result.0 == page_1[0].0)); assert!(page_2.iter().all(|result| result.0 == page_2[0].0)); // sanity check since we're relying on make_index() + assert_eq!(page_0.len(), 1); assert_eq!(page_1.len(), 2); assert_eq!(page_2.len(), 3); assert_eq!(page_1, &page_2[..page_1.len()]); + assert_eq!(page_0, &page_2[..page_0.len()]); } #[test]