From 348ca1e3097e807cd27ebc08adf69941f2b97e8e Mon Sep 17 00:00:00 2001 From: "trinity.pointard" Date: Tue, 30 Jun 2026 16:09:11 +0000 Subject: [PATCH] don't count matching doc twice --- src/aggregation/bucket/term_agg/mod.rs | 38 ++++++++++++++++------ src/aggregation/intermediate_agg_result.rs | 2 +- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/aggregation/bucket/term_agg/mod.rs b/src/aggregation/bucket/term_agg/mod.rs index bd146baca..1197e9f8c 100644 --- a/src/aggregation/bucket/term_agg/mod.rs +++ b/src/aggregation/bucket/term_agg/mod.rs @@ -342,7 +342,7 @@ pub const MAX_NUM_TERMS_FOR_VEC: u64 = 100; /// Average docs-per-bucket below which term counts cluster too tightly (mostly 1s and 2s) for /// `select_nth_unstable` to beat `sort_unstable`'s adaptive paths, so we fall back to a full sort. /// This is very low on purpose, and meant to catch unique or mostly unique terms. -const DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD: u32 = 2; +const DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD: u64 = 2; /// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed /// bucket storage, depending on the column type and aggregation level. @@ -999,6 +999,10 @@ where let segment_size = term_req.req.segment_size as usize; + // Total doc count over all buckets, computed in some cases, and which can be reused by + // `cut_off_buckets` to derive `sum_other_doc_count` without a second pass. + let mut total_doc_count: Option = None; + // select_nth_unstable_by_key(segment_size, ...) places the (k+1)-th element at // entries[segment_size] and guarantees entries[0..segment_size] are the top-k, // unordered. We need this to properly compute term_doc_count_before_cutoff. @@ -1065,10 +1069,10 @@ where if entries.len() > segment_size { // unique or near-unique fields create big runs of sorted values (ones), // which is defavorable to quickselect. use the then faster sort_unstable. 
- let num_buckets = entries.len() as u32; - let num_docs: u32 = entries.iter().map(|(_, b)| b.count).sum(); - let near_unique = - num_docs < num_buckets * DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD; + let num_buckets = entries.len() as u64; + let num_docs = entries.iter().map(|(_, b)| b.count as u64).sum(); + total_doc_count = Some(num_docs); + let near_unique = num_docs < num_buckets * DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD; if term_req.req.order.order == Order::Desc { if near_unique { entries.sort_unstable_by_key(|b| std::cmp::Reverse(b.1.count)); @@ -1087,7 +1091,7 @@ where } let (term_doc_count_before_cutoff, sum_other_doc_count) = - cut_off_buckets(&mut entries, segment_size); + cut_off_buckets(&mut entries, segment_size, total_doc_count); let mut dict: FxHashMap = Default::default(); dict.reserve(entries.len()); @@ -1296,19 +1300,33 @@ impl GetDocCount for (u64, Bucket) { } } +/// Truncates `entries` to the top `num_elem` and returns +/// `(term_doc_count_before_cutoff, sum_other_doc_count)`. +/// +/// When `total_doc_count` is `Some`, `sum_other_doc_count` is derived as `total - sum(kept)`, which +/// only sums the `num_elem` kept entries instead of the (potentially far larger) cut-off tail. pub(crate) fn cut_off_buckets( entries: &mut Vec, num_elem: usize, + total_doc_count: Option, ) -> (u64, u64) { let term_doc_count_before_cutoff = entries .get(num_elem) .map(|entry| entry.doc_count()) .unwrap_or(0); - let sum_other_doc_count = entries - .get(num_elem..) - .map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum()) - .unwrap_or(0); + let sum_other_doc_count = match total_doc_count { + // Reuse the precomputed total: sum_other = total - sum(kept top-k), summing only the + // (small) kept slice. Fewer than `num_elem` buckets means nothing is cut off, so 0. + Some(total) => entries + .get(..num_elem) + .map(|kept| total - kept.iter().map(|entry| entry.doc_count()).sum::()) + .unwrap_or(0), + None => entries + .get(num_elem..) 
+ .map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum()) + .unwrap_or(0), + }; entries.truncate(num_elem); (term_doc_count_before_cutoff, sum_other_doc_count) diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index f639cdb68..b2cd83d81 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -807,7 +807,7 @@ impl IntermediateTermBucketResult { // This can be interesting, as a value of quality of the results, but not good to check the // actual error count for the returned terms. let (_term_doc_count_before_cutoff, sum_other_doc_count) = - cut_off_buckets(&mut buckets, req.size as usize); + cut_off_buckets(&mut buckets, req.size as usize, None); let doc_count_error_upper_bound = if req.show_term_doc_count_error { Some(self.doc_count_error_upper_bound)