Use a strict comparison in TopNComputer (#2777)

* Remove `(Partial)Ord` from `ComparableDoc`, and unify comparison between `TopNComputer` and `Comparator`.

* Doc cleanups.

* Require Ord for `ComparableDoc`.

* Semantics are actually _ascending_ DocId order.

* Adjust docs again for ascending DocId order.

* minor change

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
This commit is contained in:
Stu Hood
2025-12-18 03:13:23 -08:00
committed by GitHub
parent 73657dff77
commit c0f21a45ae
6 changed files with 116 additions and 92 deletions

View File

@@ -11,7 +11,26 @@ pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
mod tests {
pub(crate) mod tests {
// By spec, regardless of whether ascending or descending order was requested, in presence of a
// tie, we sort by ascending doc id/doc address.
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
hits: &mut [ComparableDoc<TSortKey, D>],
order: Order,
) {
if order.is_asc() {
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
} else {
hits.sort_by(|l, r| {
l.sort_key
.cmp(&r.sort_key)
.reverse() // This is descending
.then(l.doc.cmp(&r.doc))
});
}
}
use std::collections::HashMap;
use std::ops::Range;
@@ -372,15 +391,10 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
sort_hits(&mut comparable_docs, order);
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

View File

@@ -30,7 +30,8 @@ impl<T: PartialOrd> Comparator<T> for NaturalComparator {
/// first.
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the comparator.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;

View File

@@ -1,64 +1,22 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// The feature of the document. In practice, this is
/// is any type that implements `PartialOrd`.
/// is a type which can be compared with a `Comparator<T>`.
pub sort_key: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
f.debug_struct("ComparableDoc")
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
// In case of a tie on the feature, we sort by ascending
// `DocAddress` in order to ensure a stable sorting of the
// documents.
by_feature.then_with(lazy_by_doc_address)
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}

View File

@@ -23,10 +23,9 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader};
/// The theoretical complexity for collecting the top `K` out of `N` documents
/// is `O(N + K)`.
///
/// This collector does not guarantee a stable sorting in case of a tie on the
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
/// like docid in case of score equality.
/// Only then, it is suitable for pagination.
/// This collector guarantees a stable sorting in case of a tie on the
/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker.
/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`.
///
/// ```rust
/// use tantivy::collector::TopDocs;
@@ -500,8 +499,13 @@ where
///
/// For TopN == 0, it will be relative expensive.
///
/// When using the natural comparator, the top N computer returns the top N elements in
/// descending order, as expected for a top N.
/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress):
/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in
/// ascending order, regardless of the `Comparator` used for the `Score` type.
///
/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the
/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C> {
@@ -580,6 +584,18 @@ where
}
}
#[inline(always)]
fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
c: &C,
lhs: &ComparableDoc<TSortKey, D>,
rhs: &ComparableDoc<TSortKey, D>,
) -> std::cmp::Ordering {
c.compare(&lhs.sort_key, &rhs.sort_key)
.reverse() // Reverse here because we want top K.
.then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we
// sort by doc id
}
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
where
D: Ord,
@@ -600,10 +616,13 @@ where
/// Push a new document to the top n.
/// If the document is below the current threshold, it will be ignored.
///
/// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order.
#[inline]
pub fn push(&mut self, sort_key: TSortKey, doc: D) {
if let Some(last_median) = &self.threshold {
if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
// See the struct docs for an explanation of why this comparison is strict.
if self.comparator.compare(&sort_key, last_median) != Ordering::Greater {
return;
}
}
@@ -629,9 +648,7 @@ where
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
self.comparator
.compare(&rhs.sort_key, &lhs.sort_key)
.then_with(|| lhs.doc.cmp(&rhs.doc))
compare_for_top_k(&self.comparator, lhs, rhs)
});
let median_score = median_el.sort_key.clone();
@@ -646,11 +663,8 @@ where
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
self.buffer.sort_unstable_by(|left, right| {
self.comparator
.compare(&right.sort_key, &left.sort_key)
.then_with(|| left.doc.cmp(&right.doc))
});
self.buffer
.sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs));
self.buffer
}
@@ -755,6 +769,33 @@ mod tests {
);
}
#[test]
fn test_topn_computer_duplicates() {
let mut computer: TopNComputer<u32, u32, NaturalComparator> =
TopNComputer::new_with_comparator(2, NaturalComparator);
computer.push(1u32, 1u32);
computer.push(1u32, 2u32);
computer.push(1u32, 3u32);
computer.push(1u32, 4u32);
computer.push(1u32, 5u32);
// In the presence of duplicates, DocIds are always ascending order.
assert_eq!(
computer.into_sorted_vec(),
&[
ComparableDoc {
sort_key: 1u32,
doc: 1u32,
},
ComparableDoc {
sort_key: 1u32,
doc: 2u32,
}
]
);
}
#[test]
fn test_topn_computer_no_panic() {
for top_n in 0..10 {
@@ -772,14 +813,17 @@ mod tests {
#[test]
fn test_topn_computer_asc_prop(
limit in 0..10_usize,
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
) {
// NB: TopNComputer must receive inputs in ascending DocId order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
for (feature, doc) in &docs {
computer.push(*feature, *doc);
}
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort();
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc);
comparable_docs.truncate(limit);
prop_assert_eq!(
computer.into_sorted_vec(),
@@ -1406,15 +1450,10 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order);
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

View File

@@ -23,13 +23,18 @@ struct InnerDeleteQueue {
last_block: Weak<Block>,
}
/// The delete queue is a linked list storing delete operations.
///
/// Several consumers can hold a reference to it. Delete operations
/// get dropped/gc'ed when no more consumers are holding a reference
/// to them.
#[derive(Clone)]
pub struct DeleteQueue {
inner: Arc<RwLock<InnerDeleteQueue>>,
}
impl DeleteQueue {
// Creates a new delete queue.
/// Creates a new empty delete queue.
pub fn new() -> DeleteQueue {
DeleteQueue {
inner: Arc::default(),
@@ -58,10 +63,10 @@ impl DeleteQueue {
block
}
// Creates a new cursor that makes it possible to
// consume future delete operations.
//
// Past delete operations are not accessible.
/// Creates a new cursor that makes it possible to
/// consume future delete operations.
///
/// Past delete operations are not accessible.
pub fn cursor(&self) -> DeleteCursor {
let last_block = self.get_last_block();
let operations_len = last_block.operations.len();
@@ -71,7 +76,7 @@ impl DeleteQueue {
}
}
// Appends a new delete operations.
/// Appends a new delete operations.
pub fn push(&self, delete_operation: DeleteOperation) {
self.inner
.write()
@@ -169,6 +174,7 @@ struct Block {
next: NextBlock,
}
/// As we process delete operations, keeps track of our position.
#[derive(Clone)]
pub struct DeleteCursor {
block: Arc<Block>,

View File

@@ -5,14 +5,20 @@ use crate::Opstamp;
/// Timestamped Delete operation.
pub struct DeleteOperation {
/// Operation stamp.
/// It is used to check whether the delete operation
/// applies to an added document operation.
pub opstamp: Opstamp,
/// Weight is used to define the set of documents to be deleted.
pub target: Box<dyn Weight>,
}
/// Timestamped Add operation.
#[derive(Eq, PartialEq, Debug)]
pub struct AddOperation<D: Document = TantivyDocument> {
/// Operation stamp.
pub opstamp: Opstamp,
/// Document to be added.
pub document: D,
}