mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-07-05 16:50:42 +00:00
don't count matching doc twice
This commit is contained in:
@@ -342,7 +342,7 @@ pub const MAX_NUM_TERMS_FOR_VEC: u64 = 100;
|
||||
/// Average docs-per-bucket below which term counts cluster too tightly (mostly 1s and 2s) for
|
||||
/// `select_nth_unstable` to beat `sort_unstable`'s adaptive paths, so we fall back to a full sort.
|
||||
/// This is very low on purpose, and meant to catch unique or mostly unique terms.
|
||||
const DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD: u32 = 2;
|
||||
const DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD: u64 = 2;
|
||||
|
||||
/// Build a concrete `SegmentTermCollector` with either a Vec- or HashMap-backed
|
||||
/// bucket storage, depending on the column type and aggregation level.
|
||||
@@ -999,6 +999,10 @@ where
|
||||
|
||||
let segment_size = term_req.req.segment_size as usize;
|
||||
|
||||
// Total doc count over all buckets, computed is some case, and which can be reused by
|
||||
// `cut_off_buckets` to derive `sum_other_doc_count` without a second pass.
|
||||
let mut total_doc_count: Option<u64> = None;
|
||||
|
||||
// select_nth_unstable_by_key(segment_size, ...) places the (k+1)-th element at
|
||||
// entries[segment_size] and guarantees entries[0..segment_size] are the top-k,
|
||||
// unordered. We need this to properly compute term_doc_count_before_cutoff.
|
||||
@@ -1065,10 +1069,10 @@ where
|
||||
if entries.len() > segment_size {
|
||||
// unique or near-unique fields create big runs of sorted values (ones),
|
||||
// which is defavorable to quickselect. use the then faster sort_unstable.
|
||||
let num_buckets = entries.len() as u32;
|
||||
let num_docs: u32 = entries.iter().map(|(_, b)| b.count).sum();
|
||||
let near_unique =
|
||||
num_docs < num_buckets * DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD;
|
||||
let num_buckets = entries.len() as u64;
|
||||
let num_docs = entries.iter().map(|(_, b)| b.count as u64).sum();
|
||||
total_doc_count = Some(num_docs);
|
||||
let near_unique = num_docs < num_buckets * DOCS_PER_BUCKET_QUICKSELECT_THRESHOLD;
|
||||
if term_req.req.order.order == Order::Desc {
|
||||
if near_unique {
|
||||
entries.sort_unstable_by_key(|b| std::cmp::Reverse(b.1.count));
|
||||
@@ -1087,7 +1091,7 @@ where
|
||||
}
|
||||
|
||||
let (term_doc_count_before_cutoff, sum_other_doc_count) =
|
||||
cut_off_buckets(&mut entries, segment_size);
|
||||
cut_off_buckets(&mut entries, segment_size, total_doc_count);
|
||||
|
||||
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
|
||||
dict.reserve(entries.len());
|
||||
@@ -1296,19 +1300,33 @@ impl GetDocCount for (u64, Bucket) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates `entries` to the top `num_elem` and returns
|
||||
/// `(term_doc_count_before_cutoff, sum_other_doc_count)`.
|
||||
///
|
||||
/// When `total_doc_count` is `Some`, `sum_other_doc_count` is derived as `total - sum(kept)`, which
|
||||
/// only sums the `num_elem` kept entries instead of the (potentially far larger) cut-off tail.
|
||||
pub(crate) fn cut_off_buckets<T: GetDocCount + Debug>(
|
||||
entries: &mut Vec<T>,
|
||||
num_elem: usize,
|
||||
total_doc_count: Option<u64>,
|
||||
) -> (u64, u64) {
|
||||
let term_doc_count_before_cutoff = entries
|
||||
.get(num_elem)
|
||||
.map(|entry| entry.doc_count())
|
||||
.unwrap_or(0);
|
||||
|
||||
let sum_other_doc_count = entries
|
||||
.get(num_elem..)
|
||||
.map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum())
|
||||
.unwrap_or(0);
|
||||
let sum_other_doc_count = match total_doc_count {
|
||||
// Reuse the precomputed total: sum_other = total - sum(kept top-k), summing only the
|
||||
// (small) kept slice. Fewer than `num_elem` buckets means nothing is cut off, so 0.
|
||||
Some(total) => entries
|
||||
.get(..num_elem)
|
||||
.map(|kept| total - kept.iter().map(|entry| entry.doc_count()).sum::<u64>())
|
||||
.unwrap_or(0),
|
||||
None => entries
|
||||
.get(num_elem..)
|
||||
.map(|cut_off_range| cut_off_range.iter().map(|entry| entry.doc_count()).sum())
|
||||
.unwrap_or(0),
|
||||
};
|
||||
|
||||
entries.truncate(num_elem);
|
||||
(term_doc_count_before_cutoff, sum_other_doc_count)
|
||||
|
||||
@@ -807,7 +807,7 @@ impl IntermediateTermBucketResult {
|
||||
// This can be interesting, as a value of quality of the results, but not good to check the
|
||||
// actual error count for the returned terms.
|
||||
let (_term_doc_count_before_cutoff, sum_other_doc_count) =
|
||||
cut_off_buckets(&mut buckets, req.size as usize);
|
||||
cut_off_buckets(&mut buckets, req.size as usize, None);
|
||||
|
||||
let doc_count_error_upper_bound = if req.show_term_doc_count_error {
|
||||
Some(self.doc_count_error_upper_bound)
|
||||
|
||||
Reference in New Issue
Block a user