Fix TopN performance regression.

https://github.com/quickwit-oss/tantivy/pull/2777
This commit is contained in:
Stu Hood
2025-12-15 16:15:12 -08:00
parent c56ddcb6d7
commit 9ffe4af096
4 changed files with 99 additions and 90 deletions

View File

@@ -18,7 +18,8 @@ mod tests {
use std::ops::Range;
use crate::collector::sort_key::{
SortByErasedType, SortBySimilarityScore, SortByStaticFastValue, SortByString,
Comparator, NaturalComparator, ReverseComparator, SortByErasedType, SortBySimilarityScore,
SortByStaticFastValue, SortByString,
};
use crate::collector::{ComparableDoc, DocSetCollector, TopDocs};
use crate::indexer::NoMergePolicy;
@@ -419,15 +420,14 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
if order.is_desc() {
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
} else {
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
}
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

View File

@@ -3,7 +3,7 @@ use std::cmp::Ordering;
use columnar::MonotonicallyMappableToU64;
use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::collector::{ComparableDoc, SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::{OwnedValue, Schema};
use crate::{DocId, Order, Score};
@@ -69,6 +69,23 @@ fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedVal
pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
/// Return the order between two values.
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
/// Return the order between two ComparableDoc values, using the semantics which are
/// implemented by TopNComputer.
#[inline(always)]
fn compare_doc<D: Ord>(
&self,
lhs: &ComparableDoc<T, D>,
rhs: &ComparableDoc<T, D>,
) -> Ordering {
// TopNComputer sorts in descending order of the SortKey by default: we apply that ordering
// here to ease comparison in testing.
self.compare(&rhs.sort_key, &lhs.sort_key).then_with(|| {
// In case of a tie on the sort key, we always sort by ascending `DocAddress` in order
// to ensure a stable sorting of the documents, regardless of the sort key's order.
// See the TopNComputer docs for more information.
lhs.doc.cmp(&rhs.doc)
})
}
}
/// With the natural comparator, the top k collector will return
@@ -100,7 +117,8 @@ impl Comparator<OwnedValue> for NaturalComparator {
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the comparator.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;

View File

@@ -1,64 +1,22 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// The feature of the document. In practice, this is
/// is any type that implements `PartialOrd`.
/// is a type which can be compared with a `Comparator<T>`.
pub sort_key: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
f.debug_struct("ComparableDoc")
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
// In case of a tie on the feature, we sort by ascending
// `DocAddress` in order to ensure a stable sorting of the
// documents.
by_feature.then_with(lazy_by_doc_address)
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}

View File

@@ -23,10 +23,9 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader};
/// The theoretical complexity for collecting the top `K` out of `N` documents
/// is `O(N + K)`.
///
/// This collector does not guarantee a stable sorting in case of a tie on the
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
/// like docid in case of score equality.
/// Only then, it is suitable for pagination.
/// This collector guarantees a stable sorting in case of a tie on the
/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker.
/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`.
///
/// ```rust
/// use tantivy::collector::TopDocs;
@@ -501,8 +500,13 @@ where
///
/// For TopN == 0, it will be relative expensive.
///
/// When using the natural comparator, the top N computer returns the top N elements in
/// descending order, as expected for a top N.
/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress):
/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in
/// ascending order, regardless of the `Comparator` used for the `Score` type.
///
/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the
/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C> {
@@ -601,10 +605,13 @@ where
/// Push a new document to the top n.
/// If the document is below the current threshold, it will be ignored.
///
/// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order.
#[inline]
pub fn push(&mut self, sort_key: TSortKey, doc: D) {
if let Some(last_median) = &self.threshold {
if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
// See the struct docs for an explanation of why this comparison is strict.
if self.comparator.compare(&sort_key, last_median) != Ordering::Greater {
return;
}
}
@@ -629,11 +636,9 @@ where
#[inline(never)]
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
self.comparator
.compare(&rhs.sort_key, &lhs.sort_key)
.then_with(|| lhs.doc.cmp(&rhs.doc))
});
let (_, median_el, _) = self
.buffer
.select_nth_unstable_by(self.top_n, |lhs, rhs| self.comparator.compare_doc(lhs, rhs));
let median_score = median_el.sort_key.clone();
// Remove all elements below the top_n
@@ -647,11 +652,8 @@ where
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
self.buffer.sort_unstable_by(|left, right| {
self.comparator
.compare(&right.sort_key, &left.sort_key)
.then_with(|| left.doc.cmp(&right.doc))
});
self.buffer
.sort_unstable_by(|left, right| self.comparator.compare_doc(left, right));
self.buffer
}
@@ -687,7 +689,9 @@ mod tests {
use proptest::prelude::*;
use super::{TopDocs, TopNComputer};
use crate::collector::sort_key::{ComparatorEnum, NaturalComparator, ReverseComparator};
use crate::collector::sort_key::{
Comparator, ComparatorEnum, NaturalComparator, ReverseComparator,
};
use crate::collector::top_collector::ComparableDoc;
use crate::collector::{Collector, DocSetCollector};
use crate::query::{AllQuery, Query, QueryParser};
@@ -756,6 +760,33 @@ mod tests {
);
}
#[test]
fn test_topn_computer_duplicates() {
let mut computer: TopNComputer<u32, u32, NaturalComparator> =
TopNComputer::new_with_comparator(2, NaturalComparator);
computer.push(1u32, 1u32);
computer.push(1u32, 2u32);
computer.push(1u32, 3u32);
computer.push(1u32, 4u32);
computer.push(1u32, 5u32);
// In the presence of duplicates, DocIds are always ascending order.
assert_eq!(
computer.into_sorted_vec(),
&[
ComparableDoc {
sort_key: 1u32,
doc: 1u32,
},
ComparableDoc {
sort_key: 1u32,
doc: 2u32,
}
]
);
}
#[test]
fn test_topn_computer_no_panic() {
for top_n in 0..10 {
@@ -773,14 +804,17 @@ mod tests {
#[test]
fn test_topn_computer_asc_prop(
limit in 0..10_usize,
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
) {
// NB: TopNComputer must receive inputs in ascending DocId order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
for (feature, doc) in &docs {
computer.push(*feature, *doc);
}
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort();
let mut comparable_docs =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
comparable_docs.truncate(limit);
prop_assert_eq!(
computer.into_sorted_vec(),
@@ -1407,15 +1441,14 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
if order.is_desc() {
comparable_docs.sort_by(|l, r| NaturalComparator.compare_doc(l, r));
} else {
comparable_docs.sort_by(|l, r| ReverseComparator.compare_doc(l, r));
}
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();