don't count matching doc twice

This commit is contained in:
trinity.pointard
2026-06-30 16:09:11 +00:00
parent 5e4fe3520c
commit 348ca1e309
2 changed files with 29 additions and 11 deletions

View File

@@ -342,7 +342,7 @@ pub const MAX_NUM_TERMS_FOR_VEC: u64 = 100;
/// Average docs-per-bucket below which term counts cluster too tightly (mostly 1s and 2s) for
/// `select_nth_unstable` to beat `sort_unstable`'s adaptive paths, so we fall back to a full sort.
/// This is very low on purpose, and meant to catch unique or mostly unique terms.
const DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD: u32 = 2;
const DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD: u64 = 2;
/// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
@@ -999,6 +999,10 @@ where
let segment_size = term_req.req.segment_size as usize;
// Total doc count over all buckets. It is only computed in some cases; when available,
// `cut_off_buckets` reuses it to derive `sum_other_doc_count` without a second pass.
let mut total_doc_count: Option<u64> = None;
// select_nth_unstable_by_key(segment_size, ...) places the (k+1)-th element at
// entries[segment_size] and guarantees entries[0..segment_size] are the top-k,
// unordered. We need this to properly compute term_doc_count_before_cutoff.
@@ -1065,10 +1069,10 @@ where
if entries.len() > segment_size {
// Unique or near-unique fields create big runs of sorted values (ones),
// which is unfavorable to quickselect. Use sort_unstable, which is faster in that case.
let num_buckets = entries.len() as u32;
let num_docs: u32 = entries.iter().map(|(_, b)| b.count).sum();
let near_unique =
num_docs < num_buckets * DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD;
let num_buckets = entries.len() as u64;
let num_docs = entries.iter().map(|(_, b)| b.count as u64).sum();
total_doc_count = Some(num_docs);
let near_unique = num_docs < num_buckets * DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD;
if term_req.req.order.order == Order::Desc {
if near_unique {
entries.sort_unstable_by_key(|b| std::cmp::Reverse(b.1.count));
@@ -1087,7 +1091,7 @@ where
}
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, segment_size);
cut_off_buckets(&mut entries, segment_size, total_doc_count);
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
@@ -1296,19 +1300,33 @@ impl GetDocCount for (u64, Bucket) {
}
}
/// Truncates `entries` to the top `num_elem` and returns
/// `(term_doc_count_before_cutoff, sum_other_doc_count)`.
///
/// When `total_doc_count` is `Some`, `sum_other_doc_count` is derived as `total - sum(kept)`, which
/// only sums the `num_elem` kept entries instead of the (potentially far larger) cut-off tail.
pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
entries: &mut Vec<T>,
num_elem: usize,
total_doc_count: Option<u64>,
) -> (u64, u64) {
let term_doc_count_before_cutoff = entries
.get(num_elem)
.map(|entry| entry.doc_count())
.unwrap_or(0);
let sum_other_doc_count = entries
.get(num_elem..)
.map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum())
.unwrap_or(0);
let sum_other_doc_count = match total_doc_count {
// Reuse the precomputed total: sum_other = total - sum(kept top-k), summing only the
// (small) kept slice. Fewer than `num_elem` buckets means nothing is cut off, so 0.
Some(total) => entries
.get(..num_elem)
.map(|kept| total - kept.iter().map(|entry| entry.doc_count()).sum::<u64>())
.unwrap_or(0),
None => entries
.get(num_elem..)
.map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum())
.unwrap_or(0),
};
entries.truncate(num_elem);
(term_doc_count_before_cutoff, sum_other_doc_count)

View File

@@ -807,7 +807,7 @@ impl IntermediateTermBucketResult {
// This can be interesting as an indicator of the quality of the results, but it is not
// suitable for checking the actual error count of the returned terms.
let (_term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut buckets, req.size as usize);
cut_off_buckets(&mut buckets, req.size as usize, None);
let doc_count_error_upper_bound = if req.show_term_doc_count_error {
Some(self.doc_count_error_upper_bound)