Compare commits

..

7 Commits

Author SHA1 Message Date
pascal
72a1d4a33c Avoid scoring buffered unions when scores are ignored
BufferedUnionScorer can use score_doc during refill only when the score combiner needs scores. DoNothingCombiner now advertises that scoring is unnecessary, preserving the no-score path for count collectors and avoiding wasted score_doc calls.

Add a regression test that verifies DoNothingCombiner does not invoke score() or score_doc() while counting a buffered union.
2026-06-12 09:29:15 +02:00
pascal
91db5c55dc cargo fmt, remove impl 2026-06-12 09:29:15 +02:00
pascal
0cbed3dd3f Clarify postings copy variable names 2026-06-12 09:29:15 +02:00
pascal
8e20e5914a Share BM25 fieldnorm caches per thread
Reuse BM25 TF normalization caches for weights with the same average fieldnorm using a bounded thread-local LRU. This avoids recomputing and duplicating the cache for many terms on the same field without adding cross-thread contention.
2026-06-12 09:29:15 +02:00
pascal
b99de692c1 Optimize buffered union scoring with block refills
Add horizon-limited buffering APIs for docsets and scorers so buffered union can refill from block-oriented postings while preserving term frequencies. This lets term scorers score buffered docs directly and reduces per-document refill overhead for dense unions.
2026-06-12 09:29:14 +02:00
pascal
eb1aabf22c Split buffered refill from scorer removal 2026-06-12 09:29:14 +02:00
pascal
cc28fb70e8 Defer terminated scorer removal during buffered refill 2026-06-12 09:29:14 +02:00
32 changed files with 774 additions and 1706 deletions

View File

@@ -66,9 +66,6 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_status_with_terms_zipf_1000_sub_agg);
register!(group, terms_zipf_1000_with_terms_status_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_status_with_date_histogram);
register!(group, terms_status_with_date_histogram_hard_bounds);
register!(group, terms_status_with_date_histogram_and_sibling_terms);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
@@ -393,57 +390,6 @@ fn terms_status_with_histogram(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_status_with_date_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"over_time": { "date_histogram": { "field": "timestamp", "fixed_interval": "1h" } }
}
}
});
execute_agg(index, agg_req);
}
/// Same fused terms × date_histogram, but with `hard_bounds`. The timestamps span 0..120h; the
/// bounds drop only the first and last hour (ms: 1h=3_600_000, 119h=428_400_000), so almost every
/// doc is in-bounds. This exercises the collector's hard-bounds path: `bounds.contains` runs per
/// doc (the `all_docs_in_bounds` short-circuit is off) and the rare out-of-bounds doc takes the
/// `term_counts` branch.
fn terms_status_with_date_histogram_hard_bounds(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"over_time": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": "1h",
"hard_bounds": { "min": 3_600_000, "max": 428_400_000 }
}
}
}
}
});
execute_agg(index, agg_req);
}
/// Same fused terms × date_histogram, but with a sibling terms aggregation next to it. The fused
/// fast path should still trigger for `my_texts` (sibling aggregations are independent top-level
/// aggregations, so they don't change its eligibility).
fn terms_status_with_date_histogram_and_sibling_terms(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"over_time": { "date_histogram": { "field": "timestamp", "fixed_interval": "1h" } }
}
},
"other_texts": { "terms": { "field": "text_few_terms" } }
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -837,9 +783,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
doc_with_value /= 20;
}
let _val_max = 1_000_000.0;
const SPAN_MS: i64 = 120 * 3600 * 1000; // 120 hours in ms
const NOISE_MS: i64 = 2 * 3600 * 1000; // ±2h noise
for i in 0..doc_with_value {
for _ in 0..doc_with_value {
let val: f64 = rng.random_range(0.0..1_000_000.0);
let json = if rng.random_bool(0.1) {
// 10% are numeric values
@@ -847,9 +791,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
} else {
json!({"mixed_type": many_terms_data.choose(&mut rng).unwrap().to_string()})
};
let base_ms = (i as i64 * SPAN_MS) / doc_with_value as i64;
let noise_ms = rng.random_range(-NOISE_MS..NOISE_MS);
let ts_ms = (base_ms + noise_ms).clamp(0, SPAN_MS);
index_writer.add_document(doc!(
single_term => "single_term",
text_field => "cool",
@@ -862,7 +803,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
date_field => DateTime::from_timestamp_millis(ts_ms),
date_field => DateTime::from_timestamp_millis((val * 1_000_000.) as i64),
))?;
if cardinality == Cardinality::OptionalSparse {
for _ in 0..20 {

View File

@@ -110,31 +110,43 @@ fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
(
"dense and 0.1% a".to_string(),
5_000_000,
0.001,
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
0,
9,
),
("dense and 1% a".to_string(), 5_000_000, 0.01, "dense", 0, 9),
("dense and 10% a".to_string(), 5_000_000, 0.1, "dense", 0, 9),
(
"sparse and 50% a".to_string(),
5_000_000,
"dense and 99% a".to_string(),
10_000_000,
0.99,
"dense",
990,
999,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
0,
9,
),
(
"sparse and 99% a".to_string(),
10_000_000,
0.99,
"sparse",
9_999_990,
9_999_999,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, num_docs, p_title_a, num_rand_distribution, range_low, range_high) in
scenarios
{
for (scenario_id, n, p_title_a, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(num_docs, p_title_a, num_rand_distribution);
let bench_index = build_shared_indices(n, p_title_a, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
@@ -146,7 +158,7 @@ fn main() {
let field_names = ["num_rand", "num_asc", "num_rand_fast", "num_asc_fast"];
// Define the three terms we want to test with
let terms = ["a"];
let terms = ["a", "b", "z"];
// Generate all combinations of terms and field names
let mut queries = Vec::new();
@@ -191,7 +203,7 @@ fn run_benchmark_tasks(
bench_index,
query_str,
DocSetCollector,
"all_results",
"all results",
);
// Test top 100 by the field (if it's a FAST field)

View File

@@ -15,37 +15,9 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
{
#[inline]
pub fn fetch_block<'a>(&'a mut self, docs: &'a [u32], accessor: &Column<T>) {
self.fetch_block_with_is_full(docs, accessor, accessor.index.get_cardinality().is_full());
}
/// Like [`Self::fetch_block`] but takes the column's fullness instead of querying
/// `accessor.index.get_cardinality()` each call — for callers that know it up front (e.g.
/// checked once at construction). `is_full` must equal
/// `accessor.index.get_cardinality().is_full()`.
#[inline]
pub fn fetch_block_with_is_full<'a>(
&'a mut self,
docs: &'a [u32],
accessor: &Column<T>,
is_full: bool,
) {
if is_full {
// Skip the resize when already the right length (common case: fixed-size blocks).
if self.val_cache.len() != docs.len() {
self.val_cache.resize(docs.len(), T::default());
}
// When the docs form a contiguous ascending run we can fetch the values
// as a single range. This lets codecs (e.g. bitpacked) bulk-decode the
// slice instead of gathering value-by-value, and avoids per-value dynamic
// dispatch. `docs` is always sorted ascending and free of duplicates here,
// so comparing the endpoints is enough to detect contiguity.
if is_contiguous(docs) {
accessor
.values
.get_range(docs[0] as u64, &mut self.val_cache);
} else {
accessor.values.get_vals(docs, &mut self.val_cache);
}
if accessor.index.get_cardinality().is_full() {
self.val_cache.resize(docs.len(), T::default());
accessor.values.get_vals(docs, &mut self.val_cache);
} else {
self.docid_cache.clear();
self.row_id_cache.clear();
@@ -186,22 +158,6 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
/// Returns true if `docs` is a contiguous ascending run `[d, d + 1, ..., d + n - 1]`.
///
/// Assumes `docs` is sorted ascending and free of duplicates (the invariant for the
/// doc blocks passed to `fetch_block`), so comparing the endpoints is sufficient.
#[inline]
fn is_contiguous(docs: &[u32]) -> bool {
let (Some(&first), Some(&last)) = (docs.first(), docs.last()) else {
return false;
};
debug_assert!(
docs.windows(2).all(|w| w[0] < w[1]),
"fetch_block requires docs sorted ascending without duplicates"
);
(last - first) as usize + 1 == docs.len()
}
/// Given two sorted lists of docids `docs` and `hits`, hits is a subset of `docs`.
/// Return all docs that are not in `hits`.
fn find_missing_docs<F>(docs: &[u32], hits: &[u32], mut callback: F)
@@ -332,46 +288,4 @@ mod tests {
assert_eq!(accessor.docid_cache, vec![0]);
assert_eq!(accessor.val_cache, vec![1]);
}
#[test]
fn test_is_contiguous() {
assert!(!is_contiguous(&[]));
assert!(is_contiguous(&[5]));
assert!(is_contiguous(&[5, 6, 7, 8]));
assert!(is_contiguous(&[0, 1, 2]));
assert!(!is_contiguous(&[5, 7, 8]));
assert!(!is_contiguous(&[0, 1, 3]));
}
#[test]
fn test_fetch_block_contiguous_and_gather_match() {
use crate::column_index::ColumnIndex;
use crate::column_values::{
ALL_U64_CODEC_TYPES, serialize_and_load_u64_based_column_values,
};
let vals: Vec<u64> = (0..200u64).map(|i| i * 7 + 3).collect();
let values =
serialize_and_load_u64_based_column_values::<u64>(&&vals[..], &ALL_U64_CODEC_TYPES);
let column = Column {
index: ColumnIndex::Full,
values,
};
let check = |accessor: &mut ColumnBlockAccessor<u64>, docs: &[u32]| {
accessor.fetch_block(docs, &column);
let got: Vec<(u32, u64)> = accessor.iter_docid_vals(docs, &column).collect();
let expected: Vec<(u32, u64)> = docs.iter().map(|&d| (d, vals[d as usize])).collect();
assert_eq!(got, expected);
};
let mut accessor = ColumnBlockAccessor::<u64>::default();
// Contiguous block -> get_range fast path.
check(&mut accessor, &(10..74).collect::<Vec<u32>>());
// Non-contiguous block -> get_vals gather path.
check(&mut accessor, &[0, 5, 9, 100, 199]);
// Single doc and full span.
check(&mut accessor, &[42]);
check(&mut accessor, &(0..200).collect::<Vec<u32>>());
}
}

View File

@@ -119,18 +119,8 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync + DowncastSync {
/// the segment's `maxdoc`.
#[inline(always)]
fn get_range(&self, start: u64, output: &mut [T]) {
let mut out_chunks = output.chunks_exact_mut(4);
let mut idx = start;
for out_x4 in out_chunks.by_ref() {
out_x4[0] = self.get_val(idx as u32);
out_x4[1] = self.get_val((idx + 1) as u32);
out_x4[2] = self.get_val((idx + 2) as u32);
out_x4[3] = self.get_val((idx + 3) as u32);
idx += 4;
}
for out in out_chunks.into_remainder() {
for (out, idx) in output.iter_mut().zip(start..) {
*out = self.get_val(idx as u32);
idx += 1;
}
}

View File

@@ -121,22 +121,6 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
reader.get_vals(&all_docs, &mut buffer);
assert_eq!(vals, buffer);
// Validate `get_range` over the full column and a sub-range. The sub-range starts
// at a non-zero offset to exercise the entrance-ramp alignment of the batch decode.
buffer.resize(all_docs.len(), 0);
reader.get_range(0, &mut buffer);
assert_eq!(vals, buffer, "get_range (full) mismatch in data set {name}");
if vals.len() >= 2 {
let start = 1usize;
buffer.resize(vals.len() - start, 0);
reader.get_range(start as u64, &mut buffer);
assert_eq!(
&vals[start..],
&buffer[..],
"get_range (sub-range) mismatch in data set {name}"
);
}
if !vals.is_empty() {
let test_rand_idx = rand::rng().random_range(0..=vals.len() - 1);
let expected_positions: Vec<u32> = vals

View File

@@ -327,9 +327,7 @@ fn exists(inp: &str) -> IResult<&str, UserInputLeaf> {
peek(alt((
value(
"",
satisfy(|c: char| {
c.is_whitespace() || (ESCAPE_IN_WORD.contains(&c) && c != '\\')
}),
satisfy(|c: char| c.is_whitespace() || ESCAPE_IN_WORD.contains(&c)),
),
eof,
))),
@@ -347,9 +345,7 @@ fn exists_precond(inp: &str) -> IResult<&str, (), ()> {
peek(alt((
value(
"",
satisfy(|c: char| {
c.is_whitespace() || (ESCAPE_IN_WORD.contains(&c) && c != '\\')
}),
satisfy(|c: char| c.is_whitespace() || ESCAPE_IN_WORD.contains(&c)),
),
eof,
))), // we need to check this isn't a wildcard query
@@ -711,7 +707,6 @@ fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
peek(alt((
value((), multispace1),
value((), char(')')),
value((), char('^')),
value((), eof),
))),
),
@@ -733,10 +728,9 @@ fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
peek(alt((
value((), multispace1),
value((), char(')')),
value((), char('^')),
value((), eof),
))),
"expected whitespace, closing parenthesis, boost, or end of input",
"expected whitespace, closing parenthesis, or end of input",
),
)(inp)
{
@@ -779,10 +773,6 @@ fn leaf(inp: &str) -> IResult<&str, UserInputAst> {
value((), multispace1),
value((), char(')')),
value((), eof),
value(
(),
satisfy(|c: char| ESCAPE_IN_WORD.contains(&c) && c != '\\'),
),
))),
),
|_| UserInputAst::from(UserInputLeaf::All),
@@ -815,10 +805,6 @@ fn leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
value((), multispace1),
value((), char(')')),
value((), eof),
value(
(),
satisfy(|c: char| ESCAPE_IN_WORD.contains(&c) && c != '\\'),
),
))),
),
),
@@ -1765,8 +1751,6 @@ mod test {
test_parse_query_to_ast_helper("*", "*");
test_parse_query_to_ast_helper("(*)", "*");
test_parse_query_to_ast_helper("(* )", "*");
// All query with boost
test_parse_query_to_ast_helper("*^2", "(*)^2");
}
#[test]
@@ -1829,7 +1813,6 @@ mod test {
test_parse_query_to_ast_helper("a:b*", "\"a\":b*");
test_parse_query_to_ast_helper("a:*b", "\"a\":*b");
test_parse_query_to_ast_helper(r#"a:*def*"#, "\"a\":*def*");
test_parse_query_to_ast_helper("a:*\\:foo", "\"a\":*:foo");
}
#[test]
@@ -1894,8 +1877,6 @@ mod test {
},
_ => panic!("Expected a leaf"),
}
// Regex followed by `^boost`
test_parse_query_to_ast_helper(r#"foo:/bar/^2"#, r#"("foo":/bar/)^2"#);
}
#[test]

View File

@@ -10,11 +10,11 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_histogram_collector,
build_segment_range_collector, CompositeAggReqData, CompositeAggregation,
CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData, HistogramBounds,
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, TermMissingAgg, TermsAggReqData,
TermsAggregation, TermsAggregationInternal,
build_segment_filter_collector, build_segment_range_collector, CompositeAggReqData,
CompositeAggregation, CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData,
HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData,
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
@@ -338,8 +338,12 @@ pub(crate) fn build_segment_agg_collector(
req_data.segment_ordinal,
)))
}
AggKind::Histogram => build_segment_histogram_collector(req, node),
AggKind::DateHistogram => build_segment_histogram_collector(req, node),
AggKind::Histogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
AggKind::Composite => Ok(Box::new(

View File

@@ -299,12 +299,6 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_sum(&self) -> Option<&SumAggregation> {
match &self {
AggregationVariants::Sum(sum) => Some(sum),
_ => None,
}
}
}
#[cfg(test)]

View File

@@ -244,52 +244,19 @@ impl Display for HistogramBounds {
}
impl HistogramBounds {
pub(crate) fn contains(&self, val: f64) -> bool {
fn contains(&self, val: f64) -> bool {
val >= self.min && val <= self.max
}
}
/// The per-bucket identifier stored in a [`SegmentHistogramBucketEntry`].
///
/// It is [`BucketId`] when the histogram has sub aggregations (which key their state by it), and
/// the zero-sized `()` when it does not. Without sub aggregations the id is never read, so storing
/// `()` drops 8 bytes per bucket (24 -> 16) and turns id assignment into a no-op.
pub trait BucketIdSlot: Copy + Default + std::fmt::Debug + PartialEq {
/// Assigns the next id from the provider, called once when a bucket is first filled.
fn assign(provider: &mut BucketIdProvider) -> Self;
/// Resolves to the `BucketId` for sub-aggregation bookkeeping.
///
/// Only ever called for the [`BucketId`] slot: the `()` slot is used exactly when there are no
/// sub aggregations, so every call site is guarded by `sub_agg.is_some()` and is dead for `()`.
fn to_bucket_id(self) -> BucketId;
}
impl BucketIdSlot for BucketId {
#[inline(always)]
fn assign(provider: &mut BucketIdProvider) -> Self {
provider.next_bucket_id()
}
#[inline(always)]
fn to_bucket_id(self) -> BucketId {
self
}
}
impl BucketIdSlot for () {
#[inline(always)]
fn assign(_provider: &mut BucketIdProvider) -> Self {}
#[inline(always)]
fn to_bucket_id(self) -> BucketId {
unreachable!("bucket ids are only resolved when sub aggregations are present")
}
}
#[derive(Default, Clone, Debug, PartialEq)]
pub(crate) struct SegmentHistogramBucketEntry<B> {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: B,
pub bucket_id: BucketId,
}
impl<B: BucketIdSlot> SegmentHistogramBucketEntry<B> {
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardBufferedSubAggs>,
@@ -302,7 +269,7 @@ impl<B: BucketIdSlot> SegmentHistogramBucketEntry<B> {
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id.to_bucket_id(),
self.bucket_id,
)?;
}
Ok(IntermediateHistogramBucketEntry {
@@ -313,140 +280,29 @@ impl<B: BucketIdSlot> SegmentHistogramBucketEntry<B> {
}
}
/// The contiguous bucket range a histogram can span, derived from the column min/max (clamped to
/// the histogram bounds). Buckets in `[base_pos, base_pos + len)` can be stored in a flat `Vec`
/// indexed by `bucket_pos - base_pos`, avoiding the hash map on the hot path.
#[derive(Clone, Copy, Debug)]
pub(crate) struct DenseRange {
/// `bucket_pos` mapped to index 0 of the dense `Vec`.
pub(crate) base_pos: i64,
/// Number of bucket positions in the range.
pub(crate) len: usize,
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// Storage for the histogram buckets of a single parent bucket.
///
/// Starts out sparse (a hash map keyed by `bucket_pos`). Once enough distinct buckets have been
/// filled that we are clearly going to cover most of the column's theoretical range, it switches
/// to a dense `Vec` indexed by `bucket_pos - base_pos`, which removes hashing from the hot loop.
#[derive(Clone, Debug)]
enum HistogramBuckets<B> {
Sparse(FxHashMap<i64, SegmentHistogramBucketEntry<B>>),
Dense {
base_pos: i64,
/// One slot per bucket position; a slot with `doc_count == 0` has not been hit yet.
buckets: Vec<SegmentHistogramBucketEntry<B>>,
},
}
impl<B> Default for HistogramBuckets<B> {
fn default() -> Self {
HistogramBuckets::Sparse(FxHashMap::default())
}
}
impl<B: BucketIdSlot> HistogramBuckets<B> {
impl HistogramBuckets {
fn memory_consumption(&self) -> u64 {
let num_slots = match self {
HistogramBuckets::Sparse(map) => map.capacity(),
HistogramBuckets::Dense { buckets, .. } => buckets.capacity(),
};
num_slots as u64 * std::mem::size_of::<SegmentHistogramBucketEntry<B>>() as u64
}
/// Switches from sparse to dense storage once the dense `Vec` would use no more memory than the
/// hash map does now, so the switch never increases memory. Called at block boundaries.
///
/// The `Vec` holds one `Entry` per bucket position in the range. The map additionally stores
/// the key and a control byte per slot, at a load factor of 7/16..7/8, so for a dense histogram
/// its footprint grows past the `Vec` well before full coverage. And since the `Vec` never
/// grows afterwards while the map would keep growing, dense only gets relatively cheaper — so
/// no upper bound on the range is needed: a large but sparse range simply never crosses over.
#[inline]
fn maybe_densify(&mut self, dense_range: Option<DenseRange>) {
let Some(range) = dense_range else { return };
let HistogramBuckets::Sparse(map) = self else {
return;
};
let dense_bytes = range
.len
.saturating_mul(std::mem::size_of::<SegmentHistogramBucketEntry<B>>());
let sparse_bytes = map
.capacity()
.saturating_mul(std::mem::size_of::<(i64, SegmentHistogramBucketEntry<B>)>() + 1);
if dense_bytes > sparse_bytes {
return;
}
let map = std::mem::take(map);
let mut buckets = vec![SegmentHistogramBucketEntry::<B>::default(); range.len];
for (bucket_pos, entry) in map {
buckets[(bucket_pos - range.base_pos) as usize] = entry;
}
*self = HistogramBuckets::Dense {
base_pos: range.base_pos,
buckets,
};
}
/// Returns the bucket entry for `bucket_pos`, setting its key (and `bucket_id`, when `B` is
/// [`BucketId`]) on first use.
///
/// For the dense variant `bucket_pos` is guaranteed to be inside the range, since it is
/// derived from the column min/max that bounds every value (see [`compute_dense_range`]).
#[inline]
fn get_or_create(
&mut self,
bucket_pos: i64,
bucket_id_provider: &mut BucketIdProvider,
key_from_pos: impl FnOnce(i64) -> f64,
) -> &mut SegmentHistogramBucketEntry<B> {
match self {
HistogramBuckets::Sparse(map) => {
map.entry(bucket_pos)
.or_insert_with(|| SegmentHistogramBucketEntry {
key: key_from_pos(bucket_pos),
doc_count: 0,
bucket_id: B::assign(bucket_id_provider),
})
}
HistogramBuckets::Dense { base_pos, buckets } => {
let idx = (bucket_pos - *base_pos) as usize;
debug_assert!(idx < buckets.len(), "bucket_pos outside the dense range");
let entry = &mut buckets[idx];
if entry.doc_count == 0 {
entry.key = key_from_pos(bucket_pos);
entry.bucket_id = B::assign(bucket_id_provider);
}
entry
}
}
}
/// Consumes the storage, yielding all non-empty bucket entries.
fn into_filled_entries(self) -> Vec<SegmentHistogramBucketEntry<B>> {
match self {
HistogramBuckets::Sparse(map) => map.into_values().collect(),
HistogramBuckets::Dense { buckets, .. } => {
buckets.into_iter().filter(|b| b.doc_count > 0).collect()
}
}
self.buckets.capacity() as u64 * std::mem::size_of::<SegmentHistogramBucketEntry>() as u64
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
pub struct SegmentHistogramCollector<B> {
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets<B>>,
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardBufferedSubAggs>,
req_data: HistogramAggReqData,
bucket_id_provider: BucketIdProvider,
/// Theoretical bucket range derived from the column min/max, if dense `Vec` storage is
/// viable. `None` keeps every parent bucket in the sparse hash map.
dense_range: Option<DenseRange>,
}
impl<B: BucketIdSlot> SegmentAggregationCollector for SegmentHistogramCollector<B> {
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -471,10 +327,7 @@ impl<B: BucketIdSlot> SegmentAggregationCollector for SegmentHistogramCollector<
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let dense_range = self.dense_range;
let store = &mut self.parent_buckets[parent_bucket_id as usize];
// Upgrade to dense storage before processing the block if the buckets are dense enough.
store.maybe_densify(dense_range);
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let req = &self.req_data;
let bounds = req.bounds;
@@ -485,42 +338,30 @@ impl<B: BucketIdSlot> SegmentAggregationCollector for SegmentHistogramCollector<
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
// special path for nested buckets
if let Some(sub_agg) = &mut self.sub_agg {
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
if bounds.contains(val) {
let bucket = store.get_or_create(
get_bucket_pos(val),
&mut self.bucket_id_provider,
|pos| get_bucket_key_from_pos(pos as f64, interval, offset),
);
bucket.doc_count += 1;
sub_agg.push(bucket.bucket_id.to_bucket_id(), doc);
}
}
} else {
for val in agg_data.column_block_accessor.iter_vals() {
let val = f64_from_fastfield_u64(val, req.field_type);
if bounds.contains(val) {
let bucket = store.get_or_create(
get_bucket_pos(val),
&mut self.bucket_id_provider,
|pos| get_bucket_key_from_pos(pos as f64, interval, offset),
);
bucket.doc_count += 1;
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
});
bucket.doc_count += 1;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
// `checked_sub` is `None` when densifying shrank the accounted memory; only account growth.
if let Some(mem_delta) = self
.get_memory_consumption(parent_bucket_id)
.checked_sub(mem_pre)
{
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
@@ -544,7 +385,9 @@ impl<B: BucketIdSlot> SegmentAggregationCollector for SegmentHistogramCollector<
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets::default());
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
@@ -561,7 +404,7 @@ impl<B: BucketIdSlot> SegmentAggregationCollector for SegmentHistogramCollector<
}
}
impl<B: BucketIdSlot> SegmentHistogramCollector<B> {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
}
@@ -570,12 +413,11 @@ impl<B: BucketIdSlot> SegmentHistogramCollector<B> {
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets<B>,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let filled = histogram.into_filled_entries();
let mut buckets = Vec::with_capacity(filled.len());
let mut buckets = Vec::with_capacity(histogram.buckets.len());
for bucket in filled {
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
buckets.push(bucket_res?);
@@ -599,18 +441,19 @@ impl<B: BucketIdSlot> SegmentHistogramCollector<B> {
None
};
let mut req_data = agg_data.per_request.histogram_req_data[node.idx_in_req_data].clone();
normalize_histogram_req(&mut req_data)?;
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
}
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
min: f64::MIN,
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
agg_data
.context
.limits
.add_memory_consumed(req_data.get_memory_consumption() as u64)?;
let dense_range = compute_dense_range(
&req_data.accessor,
req_data.field_type,
req_data.req.interval,
req_data.offset,
req_data.bounds,
);
let sub_agg = sub_agg.map(BufferedSubAggs::new);
Ok(Self {
@@ -618,155 +461,15 @@ impl<B: BucketIdSlot> SegmentHistogramCollector<B> {
sub_agg,
req_data,
bucket_id_provider: BucketIdProvider::default(),
dense_range,
})
}
}
impl SegmentHistogramCollector<()> {
/// Builds a histogram collector whose parent `t` is a dense histogram filled from
/// `counts[t * num_time_buckets .. (t + 1) * num_time_buckets]` (row-major). Used by the fused
/// terms×histogram collector to turn its flat 2D counters into the regular intermediate result,
/// so cross-segment merging is shared with the general path.
pub(crate) fn from_dense_rows(
req_data: HistogramAggReqData,
base_pos: i64,
num_time_buckets: usize,
counts: &[u32],
) -> Self {
let interval = req_data.req.interval;
let offset = req_data.offset;
let num_parents = counts.len().checked_div(num_time_buckets).unwrap_or(0);
let parent_buckets = (0..num_parents)
.map(|t| {
let row = &counts[t * num_time_buckets..(t + 1) * num_time_buckets];
let buckets = row
.iter()
.enumerate()
.map(|(b, &doc_count)| SegmentHistogramBucketEntry {
key: get_bucket_key_from_pos(
(base_pos + b as i64) as f64,
interval,
offset,
),
doc_count: doc_count as u64,
bucket_id: (),
})
.collect();
HistogramBuckets::Dense { base_pos, buckets }
})
.collect();
Self {
parent_buckets,
sub_agg: None,
req_data,
bucket_id_provider: BucketIdProvider::default(),
dense_range: None,
}
}
}
/// Validates and normalizes a histogram request in place: applies date ns-normalization (for a
/// `histogram` on a date column) and resolves `bounds`/`offset` from the request.
fn normalize_histogram_req(req_data: &mut HistogramAggReqData) -> crate::Result<()> {
req_data.req.validate()?;
if req_data.field_type == ColumnType::DateTime && !req_data.is_date_histogram {
req_data.req.normalize_date_time();
}
req_data.bounds = req_data.req.hard_bounds.unwrap_or(HistogramBounds {
min: f64::MIN,
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
// Drop `hard_bounds` that can't exclude any value (the column's range already sits inside
// them): the per-doc `bounds.contains` check is then a no-op, so collapsing to the unbounded
// sentinel lets the histogram hot loop skip it and the fused term×histogram path derive
// per-term counts from the grid. Only this collect-time filter is touched — empty-bucket
// emission reads `req.hard_bounds` directly (see `get_req_min_max`), and `hard_bounds` only
// ever clips that range, so a wider-than-data bound leaves the result unchanged.
if req_data.req.hard_bounds.is_some() {
let col_min = f64_from_fastfield_u64(req_data.accessor.min_value(), req_data.field_type);
let col_max = f64_from_fastfield_u64(req_data.accessor.max_value(), req_data.field_type);
if col_min >= req_data.bounds.min && col_max <= req_data.bounds.max {
req_data.bounds = HistogramBounds {
min: f64::MIN,
max: f64::MAX,
};
}
}
Ok(())
}
/// Clones and normalizes (resolving interval/offset/bounds) the histogram request at `node`, and
/// returns it together with its dense bucket range — or `None` if the column has no usable range.
/// Used by the fused terms×histogram collector, which then owns the normalized request.
pub(crate) fn prepare_histogram_dense_range(
agg_data: &AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Option<(HistogramAggReqData, DenseRange)>> {
let mut req_data = agg_data.per_request.histogram_req_data[node.idx_in_req_data].clone();
normalize_histogram_req(&mut req_data)?;
let dense_range = compute_dense_range(
&req_data.accessor,
req_data.field_type,
req_data.req.interval,
req_data.offset,
req_data.bounds,
);
Ok(dense_range.map(|range| (req_data, range)))
}
/// Builds a boxed histogram (or date histogram) segment collector, picking the bucket-id storage
/// based on whether there are sub aggregations: `()` (no id stored) when there are none, otherwise
/// [`BucketId`].
pub(crate) fn build_segment_histogram_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
if node.children.is_empty() {
Ok(Box::new(
SegmentHistogramCollector::<()>::from_req_and_validate(agg_data, node)?,
))
} else {
Ok(Box::new(
SegmentHistogramCollector::<BucketId>::from_req_and_validate(agg_data, node)?,
))
}
}
#[inline]
pub(crate) fn get_bucket_pos_f64(val: f64, interval: f64, offset: f64) -> f64 {
fn get_bucket_pos_f64(val: f64, interval: f64, offset: f64) -> f64 {
((val - offset) / interval).floor()
}
/// Computes the dense bucket range for a column from its min/max value (clamped to the histogram
/// bounds), or `None` if there are no values within bounds (or the range overflows `usize`).
///
/// There is no upper bound on the range: whether dense storage is actually used is decided later,
/// per parent bucket, by [`HistogramBuckets::maybe_densify`] based on the memory it would save.
///
/// The column min/max bound every value the collector can see, so a `Vec` sized to this range can
/// be indexed by `bucket_pos - base_pos` without any out-of-bounds check on the hot path.
fn compute_dense_range(
accessor: &Column<u64>,
field_type: ColumnType,
interval: f64,
offset: f64,
bounds: HistogramBounds,
) -> Option<DenseRange> {
let col_min = f64_from_fastfield_u64(accessor.min_value(), field_type);
let col_max = f64_from_fastfield_u64(accessor.max_value(), field_type);
let lo = col_min.max(bounds.min);
let hi = col_max.min(bounds.max);
if lo > hi {
return None;
}
let base_pos = get_bucket_pos_f64(lo, interval, offset) as i64;
let top_pos = get_bucket_pos_f64(hi, interval, offset) as i64;
let len = usize::try_from(top_pos.checked_sub(base_pos)?.checked_add(1)?).ok()?;
(len > 0).then_some(DenseRange { base_pos, len })
}
#[inline]
fn get_bucket_key_from_pos(bucket_pos: f64, interval: f64, offset: f64) -> f64 {
bucket_pos * interval + offset
@@ -1071,62 +774,6 @@ mod tests {
Ok(())
}
#[test]
fn histogram_dense_storage_test() -> crate::Result<()> {
histogram_dense_storage_test_with_opt(false)?;
histogram_dense_storage_test_with_opt(true)?;
Ok(())
}
/// Exercises the switch from sparse hash map to dense `Vec` storage. The switch happens at a
/// block boundary (a block is `COLLECT_BLOCK_BUFFER_LEN` = 64 docs), so we need many docs in a
/// single segment, densely covering the bucket range. `with_sub_agg` toggles the `iter_vals`
/// fast path vs. the `iter_docid_vals` path used when there is a sub aggregation.
fn histogram_dense_storage_test_with_opt(with_sub_agg: bool) -> crate::Result<()> {
let num_buckets = 50usize;
let docs_per_bucket = 10usize;
// Value `k` repeated `docs_per_bucket` times for each bucket `k`, so every value in bucket
// `k` equals `k` and the per-bucket average is exactly `k`.
let values: Vec<f64> = (0..num_buckets * docs_per_bucket)
.map(|i| (i % num_buckets) as f64)
.collect();
// `merge_segments = true` collapses the per-value segments into a single segment with all
// the docs, which is collected in 64-doc blocks and therefore switches to dense storage.
let index = get_test_index_from_values(true, &values)?;
let agg_req: Aggregations = serde_json::from_value(if with_sub_agg {
json!({
"histogram": {
"histogram": { "field": "score_f64", "interval": 1.0 },
"aggs": { "avg": { "avg": { "field": "score_f64" } } }
}
})
} else {
json!({
"histogram": {
"histogram": { "field": "score_f64", "interval": 1.0 }
}
})
})
.unwrap();
let res = exec_request(agg_req, &index)?;
for k in 0..num_buckets {
assert_eq!(res["histogram"]["buckets"][k]["key"], k as f64);
assert_eq!(
res["histogram"]["buckets"][k]["doc_count"],
docs_per_bucket as u64
);
if with_sub_agg {
assert_eq!(res["histogram"]["buckets"][k]["avg"]["value"], k as f64);
}
}
assert_eq!(res["histogram"]["buckets"][num_buckets], Value::Null);
Ok(())
}
#[test]
fn histogram_memory_limit() -> crate::Result<()> {
let index = get_test_index_with_num_docs(true, 100)?;
@@ -1421,55 +1068,6 @@ mod tests {
Ok(())
}
#[test]
fn histogram_non_binding_hard_bounds_test_multi_segment() -> crate::Result<()> {
histogram_non_binding_hard_bounds_test_with_opt(false)
}
#[test]
fn histogram_non_binding_hard_bounds_test_single_segment() -> crate::Result<()> {
histogram_non_binding_hard_bounds_test_with_opt(true)
}
/// `hard_bounds` wider than the data (here with mid-interval edges, to cover the "bound cuts a
/// bucket" case) can't exclude any value, so the result must be identical to the same request
/// without bounds. Guards the normalization that collapses such bounds to the unbounded
/// sentinel so the hot loop / fused path can skip the per-doc bounds check.
fn histogram_non_binding_hard_bounds_test_with_opt(merge_segments: bool) -> crate::Result<()> {
let values = vec![10.0, 12.0, 14.0, 16.0, 10.0, 13.0, 10.0, 12.0];
let index = get_test_index_from_values(merge_segments, &values)?;
// Mid-interval edges, but wider than the data range [10, 16] -> they exclude nothing.
let with_bounds: Aggregations = serde_json::from_value(json!({
"histogram": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": 9.5, "max": 16.5 }
}
}
}))
.unwrap();
let no_bounds: Aggregations = serde_json::from_value(json!({
"histogram": {
"histogram": { "field": "score_f64", "interval": 1.0 }
}
}))
.unwrap();
let res_bounds = exec_request(with_bounds, &index)?;
let res_plain = exec_request(no_bounds, &index)?;
// Dropping a non-binding bound must not change anything.
assert_eq!(res_bounds, res_plain);
// Sanity: buckets span the data range with gaps filled (min_doc_count defaults to 0).
assert_eq!(res_bounds["histogram"]["buckets"][0]["key"], 10.0);
assert_eq!(res_bounds["histogram"]["buckets"][0]["doc_count"], 3);
assert_eq!(res_bounds["histogram"]["buckets"][6]["key"], 16.0);
assert_eq!(res_bounds["histogram"]["buckets"][6]["doc_count"], 1);
assert_eq!(res_bounds["histogram"]["buckets"][7], Value::Null);
Ok(())
}
#[test]
fn histogram_empty_result_behaviour_test_single_segment() -> crate::Result<()> {
histogram_empty_result_behaviour_test_with_opt(true)

View File

@@ -29,8 +29,6 @@ use crate::aggregation::{format_date, BucketId, Key};
use crate::error::DataCorruption;
use crate::TantivyError;
mod term_histogram;
/// Contains all information required by the SegmentTermCollector to perform the
/// terms aggregation on a segment.
#[derive(Debug, Clone)]
@@ -376,21 +374,9 @@ pub(crate) fn build_segment_term_collector(
// Let's see if we can use a vec to aggregate our data
// instead of a hashmap.
let col_max_value = terms_req_data.accessor.max_value();
let max_column_val: u64 =
let max_term_id: u64 =
col_max_value.max(terms_req_data.missing_value_for_accessor.unwrap_or(0u64));
// Fused fast path: low-cardinality terms × a single `histogram`/`date_histogram` leaf over full
// columns with a small enough bucket grid. Anything else falls through to the general path.
if let Some(collector) = term_histogram::maybe_build_collector(
req_data,
node,
&terms_req_data,
max_column_val,
is_top_level,
)? {
return Ok(collector);
}
let sub_agg_collector = if has_sub_aggregations {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
@@ -399,30 +385,30 @@ pub(crate) fn build_segment_term_collector(
let mut bucket_id_provider = BucketIdProvider::default();
// Decide which bucket storage is best suited for this aggregation.
if is_top_level && max_column_val < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations {
let term_buckets = VecTermBucketsNoAgg::new(max_column_val + 1, &mut bucket_id_provider);
if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations {
let term_buckets = VecTermBucketsNoAgg::new(max_term_id + 1, &mut bucket_id_provider);
let collector: SegmentTermCollector<_, HighCardSubAggBuffer> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg: None,
bucket_id_provider,
max_term_id: max_column_val,
max_term_id,
terms_req_data,
};
Ok(Box::new(collector))
} else if is_top_level && max_column_val < MAX_NUM_TERMS_FOR_VEC {
let term_buckets = VecTermBuckets::new(max_column_val + 1, &mut bucket_id_provider);
} else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC {
let term_buckets = VecTermBuckets::new(max_term_id + 1, &mut bucket_id_provider);
let sub_agg = sub_agg_collector.map(LowCardBufferedSubAggs::new);
let collector: SegmentTermCollector<_, LowCardSubAggBuffer> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
max_term_id: max_column_val,
max_term_id,
terms_req_data,
};
Ok(Box::new(collector))
} else if max_column_val < 8_000_000 && is_top_level {
} else if max_term_id < 8_000_000 && is_top_level {
let term_buckets: PagedTermMap =
PagedTermMap::new(max_column_val + 1, &mut bucket_id_provider);
PagedTermMap::new(max_term_id + 1, &mut bucket_id_provider);
// Build sub-aggregation blueprint (flat pairs)
let sub_agg = sub_agg_collector.map(BufferedSubAggs::new);
let collector: SegmentTermCollector<PagedTermMap, HighCardSubAggBuffer> =
@@ -430,7 +416,7 @@ pub(crate) fn build_segment_term_collector(
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
max_term_id: max_column_val,
max_term_id,
terms_req_data,
};
Ok(Box::new(collector))
@@ -443,7 +429,7 @@ pub(crate) fn build_segment_term_collector(
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
max_term_id: max_column_val,
max_term_id,
terms_req_data,
};
Ok(Box::new(collector))

View File

@@ -1,585 +0,0 @@
//! Fused collector for the very common shape `terms` (low cardinality) × a single
//! `histogram`/`date_histogram` sub-aggregation with nothing nested below it.
//!
//! See [`SegmentTermHistogramCollector`] for the approach and [`maybe_build_collector`] for the
//! conditions under which it is used.
use columnar::ColumnBlockAccessor;
use super::{Bucket, SegmentTermCollector, TermsAggReqData, VecTermBuckets};
use crate::aggregation::agg_data::{AggKind, AggRefNode, AggregationsSegmentCtx};
use crate::aggregation::bucket::{
get_bucket_pos_f64, prepare_histogram_dense_range, HistogramAggReqData,
SegmentHistogramCollector,
};
use crate::aggregation::buffered_sub_aggs::LowCardSubAggBuffer;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::{f64_from_fastfield_u64, BucketId};
/// Maximum number of cells (`num_terms × num_time_buckets`) in the fused flat 2D grid. Above this
/// the grid would be too large/cache-unfriendly, so we fall back to the general buffered path.
/// `1 << 14` cells = 128 KB of `u64` counters, comfortably L2-resident.
///
/// Since we are only at the top-level, this won't be multiplied by any parent buckets.
const MAX_FUSED_GRID_BUCKETS: usize = 16384;
/// Fused collector for `terms` (low cardinality) × a single `histogram`/`date_histogram` leaf with
/// nothing nested below it, when the resulting `num_terms × num_time_buckets` grid is small (see
/// [`MAX_FUSED_GRID_BUCKETS`]).
///
/// It keeps a flat, fully dense 2D counter grid (`counts[term * num_time_buckets + bucket]`) and a
/// per-term total. A single pass reads both the term and histogram columns in document order and
/// bumps the counters directly — no doc-id buffering, no per-term scattered re-fetch, no dynamic
/// dispatch on flush, no per-bucket key/id storage during collection (keys are derived from the
/// index at the end).
///
/// At result time the flat grid is expanded back into the regular term map + histogram storage and
/// handed to the shared intermediate-result builders, so cross-segment merging is identical to the
/// general path.
#[derive(Debug)]
pub(crate) struct SegmentTermHistogramCollector {
/// Per-term count of docs *outside* `hard_bounds` (still in `doc_count`, but in no bucket).
/// Per-term total = this + the term's `counts` row-sum; left empty when there are no hard
/// bounds (every doc is in-bounds, so there's no remainder to track).
term_counts: Vec<u32>,
/// Flattened `[num_terms * num_time_buckets]` histogram counters (`u32`, see
/// `term_counts`).
///
/// Each term id get its own contiguous slice of `num_time_buckets` histogram counter.
/// When we count all docs (#nofilter), we can derive the per-term total as the sum over that
/// term's slice.
counts: Vec<u32>,
/// Histogram buckets per term (the dense time-range length).
num_time_buckets: usize,
/// `bucket_pos` mapped to time-bucket index 0.
base_pos: i64,
terms_req_data: TermsAggReqData,
/// The (cloned, normalized) histogram request: its column + interval/offset/bounds.
hist_req_data: HistogramAggReqData,
/// Private block accessors for both columns. We read them together, so each needs its own
/// (the shared `agg_data` scratch accessor only holds one block at a time). Owning them keeps
/// `collect` independent of `agg_data`.
term_block: ColumnBlockAccessor<u64>,
hist_block: ColumnBlockAccessor<u64>,
/// No hard bounds, so every doc is in-bounds.
all_docs_in_bounds: bool,
/// Both columns are full (fused-path precondition); cached so `collect` skips the per-block
/// cardinality lookup in `fetch_block`.
is_full: bool,
}
impl SegmentAggregationCollector for SegmentTermHistogramCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
debug_assert_eq!(
parent_bucket_id, 0,
"fused term-histogram collector is top-level only"
);
// Expand the flat grid back into the regular structures and reuse the shared builders, so
// ordering/cut-off/dict handling and cross-segment merging match the general path exactly.
let mut bucket_id_provider = BucketIdProvider::default();
// Per-term total = histogram row-sum (in-bounds) + `term_counts` (out-of-bounds remainder,
// empty when there are no hard bounds).
let term_buckets = VecTermBuckets {
buckets: self
.counts
.chunks_exact(self.num_time_buckets)
.enumerate()
.map(|(term_id, row)| {
let in_bounds: u32 = row.iter().sum();
let out_of_bounds = self.term_counts.get(term_id).copied().unwrap_or(0);
Bucket {
count: in_bounds + out_of_bounds,
bucket_id: bucket_id_provider.next_bucket_id(),
}
})
.collect(),
};
let mut histogram = SegmentHistogramCollector::<()>::from_dense_rows(
self.hist_req_data.clone(),
self.base_pos,
self.num_time_buckets,
&self.counts,
);
let name = self.terms_req_data.name.clone();
let bucket = SegmentTermCollector::<VecTermBuckets, LowCardSubAggBuffer>::into_intermediate_bucket_result(
&self.terms_req_data,
Some(&mut histogram as &mut dyn SegmentAggregationCollector),
term_buckets,
agg_data,
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
#[inline]
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
_agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(
parent_bucket_id, 0,
"fused term-histogram collector is top-level only"
);
// Fetch both columns into our own accessors (we read them together, so they can't share the
// single `agg_data` scratch accessor). The collector owns all its inputs, so `collect`
// doesn't touch `agg_data`.
self.term_block
.fetch_block_with_is_full(docs, &self.terms_req_data.accessor, self.is_full);
self.hist_block
.fetch_block_with_is_full(docs, &self.hist_req_data.accessor, self.is_full);
// Hoist the loop-invariant fields into locals: the optimizer can't prove the
// `self.counts`/`self.term_counts` writes don't alias these `self` fields, so it can't keep
// them in registers and re-reads them from memory every iteration — ~15% slower on
// `terms_status_with_date_histogram` when read straight from `self`.
// Note: check which are actually relevant.
let field_type = self.hist_req_data.field_type;
let bounds = self.hist_req_data.bounds;
let interval = self.hist_req_data.req.interval;
let offset = self.hist_req_data.offset;
let base_pos = self.base_pos;
let num_time_buckets = self.num_time_buckets;
let all_docs_in_bounds = self.all_docs_in_bounds;
let term_counts = &mut self.term_counts;
let counts = &mut self.counts;
// Both columns are full (checked at construction), so values align with `docs` positionally
// and are read together in one pass.
// In-bounds docs bump the `counts` grid, out-of-bounds bump `term_counts`; deriving the
// total at flush avoids a per-doc `term_counts` RMW that serializes on
// store-to-load forwarding.
for (term_id, hist_raw) in self.term_block.iter_vals().zip(self.hist_block.iter_vals()) {
let term_id = term_id as usize;
let val = f64_from_fastfield_u64(hist_raw, field_type);
if all_docs_in_bounds || bounds.contains(val) {
let bucket = (get_bucket_pos_f64(val, interval, offset) as i64 - base_pos) as usize;
debug_assert!(
bucket < num_time_buckets,
"histogram bucket outside dense range"
);
counts[term_id * num_time_buckets + bucket] += 1;
} else {
term_counts[term_id] += 1;
}
}
Ok(())
}
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Nothing is buffered: `collect` writes the flat grid directly.
Ok(())
}
fn prepare_max_bucket(
&mut self,
_max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
// Top-level: the flat grid is allocated up front.
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
None
}
}
/// Builds the fused terms×histogram collector for a single top-level parent, when the shape is
/// eligible. Returns `Ok(None)` to fall back to the general buffered terms path.
///
/// Eligibility: top-level, low-cardinality terms over a full column with no missing/include-exclude
/// handling; a single `histogram`/`date_histogram` leaf (no nesting below it) over a full column;
/// and a `num_terms × num_time_buckets` grid no larger than [`MAX_FUSED_GRID_BUCKETS`].
pub(super) fn maybe_build_collector(
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
terms_req_data: &TermsAggReqData,
col_max_val: u64,
is_top_level: bool,
) -> crate::Result<Option<Box<dyn SegmentAggregationCollector>>> {
// Both columns must be full (one value per doc) so their values align positionally with `docs`
// and we can zip them. Requiring full columns also makes the terms agg's `missing` config a
// no-op (`fetch_block_with_missing` early-returns on full columns), so we needn't check for it.
//
// We don't cap the term cardinality here: the flat grid is bounded by the total cell count
// (`num_terms * num_time_buckets <= MAX_FUSED_GRID_BUCKETS`) checked below, which subsumes it.
//
// We only allow this at the top-level, since we don't know how many buckets are created. We
// are less likely to get enough docs for the preallocation to be worth and there's a risk of
// using too much memory. We could check the maximum theoretical buckets up-front and pass
// them down.
let fuseable = is_top_level
// TODO: We can easily support this
&& terms_req_data.allowed_term_ids.is_none()
&& terms_req_data.accessor.get_cardinality().is_full()
// The flat counters are `u32`, bumped once per value, so no count can exceed the column's
// value count. (Essentially always true here: the column is full, so its value count
// equals the doc count, and `DocId` is `u32`.)
&& terms_req_data.accessor.values.num_vals() < u32::MAX
&& node.children.len() == 1
&& matches!(
node.children[0].kind,
AggKind::Histogram | AggKind::DateHistogram
)
&& node.children[0].children.is_empty()
&& agg_data.per_request.histogram_req_data[node.children[0].idx_in_req_data]
.accessor
.get_cardinality()
.is_full();
if !fuseable {
return Ok(None);
}
// Clone + normalize the histogram request and get its dense bucket range; only take the fused
// path when the flat `num_terms × num_time_buckets` grid is small enough.
let Some((hist_req_data, range)) = prepare_histogram_dense_range(agg_data, &node.children[0])?
else {
return Ok(None);
};
let num_terms = col_max_val.saturating_add(1) as usize;
if num_terms.saturating_mul(range.len) > MAX_FUSED_GRID_BUCKETS {
return Ok(None);
}
// No hard bounds means every doc is in-bounds, letting `collect` short-circuit the bounds
// check — and leaving `term_counts` (the out-of-bounds remainder) unused, so we skip allocating
// it.
let all_docs_in_bounds =
hist_req_data.bounds.min == f64::MIN && hist_req_data.bounds.max == f64::MAX;
let counts = vec![0u32; num_terms * range.len];
let term_counts = if all_docs_in_bounds {
Vec::new()
} else {
vec![0u32; num_terms]
};
// Charge both grids to the aggregation memory limit.
agg_data.context.limits.add_memory_consumed(
((counts.len() + term_counts.len()) * std::mem::size_of::<u32>()) as u64,
)?;
Ok(Some(Box::new(SegmentTermHistogramCollector {
term_counts,
counts,
num_time_buckets: range.len,
base_pos: range.base_pos,
terms_req_data: terms_req_data.clone(),
hist_req_data,
term_block: ColumnBlockAccessor::default(),
hist_block: ColumnBlockAccessor::default(),
all_docs_in_bounds,
is_full: terms_req_data.accessor.get_cardinality().is_full(),
})))
}
#[cfg(test)]
mod tests {
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{
exec_request, exec_request_with_query_and_memory_limit,
get_test_index_from_values_and_terms,
};
use crate::aggregation::AggregationLimitsGuard;
/// Hand-computed correctness check for the fused terms×histogram fast path
/// ([`super::SegmentTermHistogramCollector`]): low-cardinality terms × a histogram leaf over
/// full columns, exercised single- and multi-segment.
#[test]
fn fused_term_histogram_test() -> crate::Result<()> {
fused_term_histogram_with_opt(false)?;
fused_term_histogram_with_opt(true)?;
Ok(())
}
fn fused_term_histogram_with_opt(merge_segments: bool) -> crate::Result<()> {
// 300 docs: term = {a, b, c} by i % 3, histogram value = i % 20 (interval 1 => buckets
// 0..19). gcd(3, 20) = 1, so every (term, bucket) pair occurs exactly 300 / 60 = 5 times.
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
// Two segments, to also exercise cross-segment merging of the fused per-term histograms.
let segments = vec![docs[..150].to_vec(), docs[150..].to_vec()];
let index = get_test_index_from_values_and_terms(merge_segments, &segments)?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "order": { "_key": "asc" } },
"aggs": {
"histo": { "histogram": { "field": "score_f64", "interval": 1.0 } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..20usize {
assert_eq!(histo[b]["key"], b as f64, "term {term} bucket {b}");
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {b}");
}
assert_eq!(histo[20], serde_json::Value::Null);
}
assert_eq!(res["by_term"]["buckets"][3], serde_json::Value::Null);
Ok(())
}
/// A `missing` config on a *full* term column still takes the fused path (the string sentinel
/// is just `col_max + 1`, so the column stays low-cardinality). Since no doc is missing, the
/// real term buckets must be exactly as without `missing`.
#[test]
fn fused_term_histogram_with_missing_on_full_column() -> crate::Result<()> {
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "missing": "MISSING", "order": { "_key": "asc" } },
"aggs": {
"histo": { "histogram": { "field": "score_f64", "interval": 1.0 } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// Column is full, so "MISSING" never applies: a, b, c are unchanged (100 docs, 5 per
// bucket).
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..20usize {
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {b}");
}
}
Ok(())
}
/// Term cardinality above the general path's `MAX_NUM_TERMS_FOR_VEC` (100) still fuses: the
/// flat grid is bounded by the total cell count (`num_terms * num_time_buckets`), not the
/// term count.
#[test]
fn fused_term_histogram_many_terms() -> crate::Result<()> {
let num_terms = 150usize;
let docs_per_term = 2usize;
// All docs share histogram value 0 (a single bucket), so the grid is 150 x 1 = 150 cells.
let docs: Vec<(f64, String)> = (0..num_terms * docs_per_term)
.map(|i| (0.0, format!("t{:03}", i % num_terms)))
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "size": 1000, "order": { "_key": "asc" } },
"aggs": {
"histo": { "histogram": { "field": "score_f64", "interval": 1.0 } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
let buckets = res["by_term"]["buckets"].as_array().unwrap();
assert_eq!(buckets.len(), num_terms);
for (i, bucket) in buckets.iter().enumerate() {
assert_eq!(bucket["key"], format!("t{i:03}"));
assert_eq!(bucket["doc_count"], docs_per_term as u64);
assert_eq!(bucket["histo"]["buckets"][0]["key"], 0.0);
assert_eq!(
bucket["histo"]["buckets"][0]["doc_count"],
docs_per_term as u64
);
}
Ok(())
}
/// `hard_bounds` exercises the non-derived `term_counts` branch: a term's `doc_count` must
/// count *every* doc with that term, including docs whose histogram value is outside the
/// bounds (those are excluded from the histogram buckets but still counted for the term). This
/// is the case where the per-doc `term_counts` increment cannot be replaced by the grid
/// row-sum.
#[test]
fn fused_term_histogram_with_hard_bounds() -> crate::Result<()> {
// 300 docs: term = {a, b, c} by i % 3, value = i % 20. Per term: 100 docs, each value in
// 0..=19 occurring 5 times.
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
// hard_bounds [5, 14] (inclusive) keeps only values 5..=14 in the histogram (10 buckets);
// values 0..=4 and 15..=19 are out of bounds.
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "order": { "_key": "asc" } },
"aggs": {
"histo": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": 5.0, "max": 14.0 }
}
}
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
// doc_count includes the 50 per-term docs whose value is outside [5, 14].
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..10usize {
let key = 5 + b;
assert_eq!(histo[b]["key"], key as f64, "term {term} bucket key {key}");
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {key}");
}
// Only the 10 in-bounds buckets exist.
assert_eq!(histo[10], serde_json::Value::Null);
}
Ok(())
}
/// Non-binding `hard_bounds` (wider than the data, with mid-interval edges) must still produce
/// exact results via the derive-from-grid path: since no doc is out of bounds, normalization
/// drops the bound, every doc lands in the dense range, and each term's total equals its
/// histogram row-sum. This is the case that previously fell back to the per-doc counter only
/// because `bounds != [MIN, MAX]`.
#[test]
fn fused_term_histogram_with_non_binding_hard_bounds() -> crate::Result<()> {
// 300 docs: term = {a, b, c} by i % 3, value = i % 20. Data values span [0, 19].
let docs: Vec<(f64, String)> = (0..300u64)
.map(|i| {
(
(i % 20) as f64,
["a", "b", "c"][(i % 3) as usize].to_string(),
)
})
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
// Bounds wider than [0, 19], with mid-interval edges -> they exclude nothing.
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id", "order": { "_key": "asc" } },
"aggs": {
"histo": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": -0.5, "max": 19.5 }
}
}
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
for (term_idx, term) in ["a", "b", "c"].iter().enumerate() {
assert_eq!(res["by_term"]["buckets"][term_idx]["key"], *term);
// Every doc is in-bounds, so the per-term total is the full 100 (as without bounds).
assert_eq!(res["by_term"]["buckets"][term_idx]["doc_count"], 100);
let histo = &res["by_term"]["buckets"][term_idx]["histo"]["buckets"];
for b in 0..20usize {
assert_eq!(histo[b]["key"], b as f64, "term {term} bucket {b}");
assert_eq!(histo[b]["doc_count"], 5, "term {term} bucket {b}");
}
assert_eq!(histo[20], serde_json::Value::Null);
}
Ok(())
}
/// Regression: with hard bounds the fused path allocates `term_counts` (one `u32`/term) on top
/// of the grid, and that allocation must be charged to the memory limit. With many terms and a
/// single time bucket the two are equal in size, so a limit admitting the grid alone but not
/// grid + `term_counts` must fail.
#[test]
fn fused_term_histogram_hard_bounds_charges_term_counts() -> crate::Result<()> {
// 16k distinct terms, one doc each; values alternate in/out of the single-bucket bounds
// [5, 5] so the bounds bind and `term_counts` is allocated. num_terms=16000,
// num_time_buckets=1 => `counts` and `term_counts` are ~64 KB each.
let docs: Vec<(f64, String)> = (0..16_000u64)
.map(|i| (if i % 2 == 0 { 5.0 } else { 10.0 }, format!("t{i:05}")))
.collect();
let index = get_test_index_from_values_and_terms(true, &[docs])?;
let agg_req: Aggregations = serde_json::from_value(serde_json::json!({
"by_term": {
"terms": { "field": "string_id" },
"aggs": {
"histo": {
"histogram": {
"field": "score_f64",
"interval": 1.0,
"hard_bounds": { "min": 5.0, "max": 5.0 }
}
}
}
}
}))
.unwrap();
// ~96 KB admits the grid (~64 KB) but not grid + `term_counts` (~128 KB).
let err = exec_request_with_query_and_memory_limit(
agg_req,
&index,
None,
AggregationLimitsGuard::new(Some(96_000), None),
)
.unwrap_err();
assert!(
err.to_string().contains("memory limit was exceeded"),
"expected a memory-limit error, got: {err}"
);
Ok(())
}
}

View File

@@ -138,7 +138,6 @@ impl SubAggBuffer for HighCardSubAggBuffer {
}
}
#[inline]
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id % NUM_PARTITIONS as u32;
let slot = &mut self.partitions[idx as usize];
@@ -197,7 +196,6 @@ impl SubAggBuffer for LowCardSubAggBuffer {
}
}
#[inline]
fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
let idx = bucket_id as usize;
if self.per_bucket_docs.len() <= idx {

View File

@@ -377,22 +377,7 @@ impl IntermediateMetricResult {
MetricResult::ExtendedStats(intermediate_stats.finalize())
}
IntermediateMetricResult::Sum(intermediate_sum) => {
// By default match Elasticsearch: empty / all-missing sum
// buckets serialize as `"value": 0`, not `"value": null`.
// The non-ES `none_if_no_match` flag on `SumAggregation`
// opts into SQL-style `null` for downstream consumers.
let none_if_no_match = req
.agg
.as_sum()
.and_then(|sum| sum.none_if_no_match)
.unwrap_or(false);
let value = intermediate_sum.finalize();
if none_if_no_match {
MetricResult::Sum(value.into())
} else {
let value = Some(value.unwrap_or(0.0));
MetricResult::Sum(value.into())
}
MetricResult::Sum(intermediate_sum.finalize().into())
}
IntermediateMetricResult::Percentiles(percentiles) => MetricResult::Percentiles(
percentiles

View File

@@ -27,16 +27,6 @@ pub struct SumAggregation {
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(default, deserialize_with = "deserialize_option_f64")]
pub missing: Option<f64>,
/// Non-Elasticsearch extension. When `Some(true)`, the serialized result
/// returns `"value": null` if no values were collected (all documents had
/// missing/NULL values for the field), matching the behavior of `min`,
/// `max`, and `avg`. When `None` or `Some(false)` (the default) the
/// result returns `"value": 0`, matching Elasticsearch.
///
/// Intended for SQL-style consumers where `SUM` of zero rows is `NULL`
/// and must be distinguishable from a bucket that genuinely sums to `0`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub none_if_no_match: Option<bool>,
}
impl SumAggregation {
@@ -45,7 +35,6 @@ impl SumAggregation {
Self {
field: field_name,
missing: None,
none_if_no_match: None,
}
}
/// Returns the field name the aggregation is computed on.
@@ -70,104 +59,8 @@ impl IntermediateSum {
pub fn merge_fruits(&mut self, other: IntermediateSum) {
self.stats.merge_fruits(other.stats);
}
/// Computes the final sum value.
///
/// Returns `None` when no values were collected, matching the Rust-side
/// behavior of `IntermediateMin`, `IntermediateMax`, and
/// `IntermediateAvg`. The Elasticsearch-vs-SQL choice for the
/// user-visible result is made at the boundary in
/// [`IntermediateMetricResult::into_final_metric_result`]: by default
/// `None` is coerced to `Some(0.0)` to match Elasticsearch
/// (`"value": 0`), and the [`SumAggregation::none_if_no_match`] flag
/// opts out of that coercion for SQL-style consumers.
/// Computes the final minimum value.
pub fn finalize(&self) -> Option<f64> {
let stats = self.stats.finalize();
if stats.count == 0 {
None
} else {
Some(stats.sum)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sum_finalize_returns_none_when_no_values() {
// Default IntermediateSum has count=0 — finalize should return None,
// matching MIN/MAX/AVG behavior for all-NULL groups.
let sum = IntermediateSum::default();
assert_eq!(sum.finalize(), None);
}
#[test]
fn test_sum_finalize_returns_value_when_has_values() {
let mut sum = IntermediateSum::default();
// Merge in a result that has actual values
let stats = IntermediateStats {
count: 3,
sum: 42.0,
min: 10.0,
max: 20.0,
..Default::default()
};
let other = IntermediateSum::from_stats(stats);
sum.merge_fruits(other);
assert_eq!(sum.finalize(), Some(42.0));
}
#[test]
fn test_sum_merge_two_empty_still_none() {
let mut a = IntermediateSum::default();
let b = IntermediateSum::default();
a.merge_fruits(b);
assert_eq!(a.finalize(), None);
}
#[test]
fn test_sum_aggregation_empty_index_default_matches_es() -> crate::Result<()> {
use serde_json::json;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
// Empty index — sum has no values to collect.
let values: Vec<Vec<&str>> = vec![];
let index = get_test_index_from_terms(false, &values)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"score_sum": { "sum": { "field": "score" } }
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// Default: match Elasticsearch — empty sum serializes as 0, not null.
assert_eq!(res["score_sum"]["value"], 0.0);
Ok(())
}
#[test]
fn test_sum_aggregation_empty_index_none_if_no_match_opt_in() -> crate::Result<()> {
use serde_json::json;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
let values: Vec<Vec<&str>> = vec![];
let index = get_test_index_from_terms(false, &values)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"score_sum": { "sum": { "field": "score", "none_if_no_match": true } }
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
// Opt-in non-ES extension — empty sum serializes as null.
assert!(
res["score_sum"]["value"].is_null(),
"expected null, got {:?}",
res["score_sum"]["value"]
);
Ok(())
Some(self.stats.finalize().sum)
}
}

View File

@@ -138,6 +138,31 @@ pub trait DocSet: Send {
buffer.len()
}
/// Fills a given mutable buffer with the next doc ids smaller than `horizon`.
///
/// Unlike [`DocSet::fill_buffer`], this method must not advance past a doc id greater than or
/// equal to `horizon`.
fn fill_buffer_up_to(
&mut self,
horizon: DocId,
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
if self.doc() == TERMINATED {
return 0;
}
for (pos, buffer_val) in buffer.iter_mut().enumerate() {
let doc = self.doc();
if doc >= horizon {
return pos;
}
*buffer_val = doc;
if self.advance() == TERMINATED {
return pos + 1;
}
}
buffer.len()
}
/// Returns the current document
/// Right after creating a new `DocSet`, the docset points to the first document.
///
@@ -251,6 +276,14 @@ impl DocSet for &mut dyn DocSet {
(**self).fill_buffer(buffer)
}
fn fill_buffer_up_to(
&mut self,
horizon: DocId,
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
(**self).fill_buffer_up_to(horizon, buffer)
}
fn fill_bitset_block(
&mut self,
min_doc: DocId,

View File

@@ -240,6 +240,42 @@ impl BlockSegmentPostings {
self.freq_decoder.output_array()
}
pub(crate) fn copy_docs_and_term_freqs(
&self,
block_offset: usize,
horizon: DocId,
docs: &mut [DocId],
term_freqs: &mut [u32],
) -> usize {
debug_assert_eq!(docs.len(), term_freqs.len());
let block_docs = self.docs();
let remaining_docs_in_block = block_docs.len().saturating_sub(block_offset);
let max_len = remaining_docs_in_block.min(docs.len());
if max_len == 0 {
return 0;
}
let source_docs = &block_docs[block_offset..block_offset + max_len];
let len = if source_docs[max_len - 1] < horizon {
max_len
} else {
source_docs
.iter()
.position(|&doc| doc >= horizon)
.unwrap_or(max_len)
};
docs[..len].copy_from_slice(&source_docs[..len]);
let block_freqs = self.freq_output_array();
if block_freqs.len() >= block_offset + len {
term_freqs[..len].copy_from_slice(&block_freqs[block_offset..block_offset + len]);
} else {
term_freqs[..len].fill(1);
}
len
}
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
@@ -287,33 +323,6 @@ impl BlockSegmentPostings {
doc
}
/// Returns the number of documents with a doc id strictly smaller than `target`
/// (i.e. the *rank* of `target` in this posting list).
///
/// This jumps to the block that may contain `target` through the skip list, so no
/// skipped block is decoded; a single block is then decoded to locate `target`
/// within it. The cost is therefore `O(number_of_skip_list_entries)` plus one block
/// decode, rather than `O(doc_freq)`.
///
/// Like [`Self::seek`], the underlying cursor only ever moves forward. This method
/// must be called with **non-decreasing** `target` values (galloping); calling it
/// with a `target` smaller than a previous one yields an incorrect result. `target`
/// must be a valid doc id (i.e. `target <= TERMINATED`), exactly as for `seek`.
///
/// Edge cases: returns `0` when `target` is smaller than every doc id, and
/// `doc_freq()` when `target` is larger than every doc id.
pub fn rank(&mut self, target: DocId) -> u32 {
if self.doc_freq == 0 {
return 0;
}
// `within` = number of docs in the landed block with a doc id < target.
let within = self.seek(target);
// `remaining_docs` counts the landed block and everything after it, so the
// difference is the number of docs in all blocks strictly before it.
let docs_before_block = self.doc_freq - self.skip_reader.remaining_docs();
docs_before_block + within as u32
}
pub(crate) fn position_offset(&self) -> u64 {
self.skip_reader.position_offset()
}
@@ -595,38 +604,4 @@ mod tests {
assert_eq!(block_segments.docs(), &[1, 3, 5]);
Ok(())
}
#[test]
fn test_block_segment_postings_rank() -> crate::Result<()> {
// ~8 blocks worth of docs so the skip list is actually exercised.
let docs: Vec<DocId> = (0..1000u32).map(|i| i * 3).collect();
let mut block_postings = build_block_postings(&docs[..])?;
let doc_freq = block_postings.doc_freq();
// rank(target) must equal the number of docs strictly below target.
// Targets are queried in non-decreasing order, as the API requires.
// `target` values must be a valid doc id (<= TERMINATED) and non-decreasing.
let targets = [
0u32, 1, 2, 3, 4, 299, 300, 301, 1500, 2996, 2997, 3000, 10_000,
];
for &target in &targets {
let expected = docs.iter().filter(|&&d| d < target).count() as u32;
assert_eq!(
block_postings.rank(target),
expected,
"rank({target}) mismatch"
);
}
// Edge cases: below the first doc -> 0, above the last doc -> doc_freq.
let mut fresh = build_block_postings(&docs[..])?;
assert_eq!(fresh.rank(0), 0);
let mut fresh = build_block_postings(&docs[..])?;
assert_eq!(fresh.rank(1_000_000), doc_freq);
// Empty postings: rank is always 0.
let mut empty = BlockSegmentPostings::empty();
assert_eq!(empty.rank(42), 0);
Ok(())
}
}

View File

@@ -532,6 +532,16 @@ pub(crate) mod tests {
fn score(&mut self) -> Score {
self.0.score()
}
#[inline]
fn can_score_doc(&self) -> bool {
self.0.can_score_doc()
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
self.0.score_doc(doc, term_freq)
}
}
pub fn test_skip_against_unoptimized<F: Fn() -> Box<dyn DocSet>>(

View File

@@ -1,6 +1,6 @@
use common::HasLen;
use crate::docset::DocSet;
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
use crate::fastfield::AliveBitSet;
use crate::positions::PositionReader;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
@@ -151,6 +151,34 @@ impl SegmentPostings {
position_reader,
}
}
pub(crate) fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
let mut num_elems = 0;
while num_elems < COLLECT_BLOCK_BUFFER_LEN && self.doc() < horizon {
let copied = self.block_cursor.copy_docs_and_term_freqs(
self.cur,
horizon,
&mut docs[num_elems..],
&mut term_freqs[num_elems..],
);
if copied == 0 {
break;
}
num_elems += copied;
self.cur += copied;
if self.cur == COMPRESSION_BLOCK_SIZE {
self.cur = 0;
self.block_cursor.advance();
}
}
num_elems
}
}
impl DocSet for SegmentPostings {

View File

@@ -187,12 +187,6 @@ impl SkipReader {
self.last_doc_in_block
}
/// Number of docs from the start of the current block to the end of the postings
/// (i.e. the current block plus every block after it).
pub(crate) fn remaining_docs(&self) -> u32 {
self.remaining_docs
}
pub fn position_offset(&self) -> u64 {
self.position_offset
}

View File

@@ -109,6 +109,16 @@ impl Scorer for AllScorer {
fn score(&mut self) -> Score {
1.0
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
1.0
}
}
#[cfg(test)]

View File

@@ -1,5 +1,9 @@
use std::cell::RefCell;
use std::num::NonZeroUsize;
use std::sync::Arc;
use lru::LruCache;
use crate::fieldnorm::FieldNormReader;
use crate::query::Explanation;
use crate::schema::Field;
@@ -59,7 +63,9 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
}
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
const BM25_TF_CACHE_CAPACITY: usize = 64;
fn compute_tf_cache_uncached(average_fieldnorm: Score) -> Arc<[Score; 256]> {
let mut cache: [Score; 256] = [0.0; 256];
for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
@@ -68,6 +74,36 @@ fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
Arc::new(cache)
}
thread_local! {
static TF_CACHES: RefCell<LruCache<u32, Arc<[Score; 256]>>> = RefCell::new(LruCache::new(
NonZeroUsize::new(BM25_TF_CACHE_CAPACITY).unwrap(),
));
}
/// The cache is shared across all [Bm25Weight] with the same average fieldnorm on the same thread.
/// It is stored in a thread local LRU cache.
///
/// On one query all terms on the same field will share the same average fieldnorm, and thus the
/// same cache. This will lower cache pressure.
///
/// Even between queries (on the same thread), the cache will be reused, which allows the cache to
/// better learn the memory address of the cache and access patterns.
///
/// Thread local is used in order to be defensive about potential contention on the cache.
fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> {
let cache_key = average_fieldnorm.to_bits();
TF_CACHES.with(|cache_by_average_fieldnorm| {
let mut cache_by_average_fieldnorm = cache_by_average_fieldnorm.borrow_mut();
if let Some(cache) = cache_by_average_fieldnorm.get(&cache_key) {
return cache.clone();
}
let cache = compute_tf_cache_uncached(average_fieldnorm);
cache_by_average_fieldnorm.put(cache_key, cache.clone());
cache
})
}
/// A struct used for computing BM25 scores.
#[derive(Clone)]
pub struct Bm25Weight {
@@ -229,7 +265,7 @@ impl Bm25Weight {
#[cfg(test)]
mod tests {
use super::idf;
use super::{idf, Bm25Weight};
use crate::{assert_nearly_equals, Score};
#[test]
@@ -237,4 +273,12 @@ mod tests {
let score: Score = 2.0;
assert_nearly_equals!(idf(1, 2), score.ln());
}
#[test]
fn test_bm25_tf_cache_is_shared_for_same_average_fieldnorm() {
let weight1 = Bm25Weight::for_one_term(1, 10, 3.0);
let weight2 = Bm25Weight::for_one_term(2, 10, 3.0);
assert!(std::sync::Arc::ptr_eq(&weight1.cache, &weight2.cache));
}
}

View File

@@ -91,10 +91,14 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
num_docs: u32,
) -> Box<dyn Scorer> {
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let union_scorer =
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
SpecializedScorer::TermUnion(mut term_scorers) => {
if term_scorers.len() == 1 {
Box::new(term_scorers.pop().unwrap())
} else {
let union_scorer =
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
}
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers

View File

@@ -112,6 +112,14 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
self.underlying.fill_buffer(buffer)
}
fn fill_buffer_up_to(
&mut self,
horizon: DocId,
buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.underlying.fill_buffer_up_to(horizon, buffer)
}
fn doc(&self) -> u32 {
self.underlying.doc()
}
@@ -138,6 +146,27 @@ impl<S: Scorer> Scorer for BoostScorer<S> {
fn score(&mut self) -> Score {
self.underlying.score() * self.boost
}
#[inline]
fn can_score_doc(&self) -> bool {
self.underlying.can_score_doc()
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
self.underlying.score_doc(doc, term_freq) * self.boost
}
#[inline]
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.underlying
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
}
}
#[cfg(test)]

View File

@@ -1,6 +1,6 @@
use std::fmt;
use crate::docset::{SeekDangerResult, COLLECT_BLOCK_BUFFER_LEN};
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term};
@@ -119,10 +119,6 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
self.docset.seek(target)
}
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
self.docset.seek_danger(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
self.docset.fill_buffer(buffer)
}
@@ -145,6 +141,16 @@ impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
fn score(&mut self) -> Score {
self.score
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
self.score
}
}
#[cfg(test)]

View File

@@ -315,6 +315,20 @@ mod tests {
fn score(&mut self) -> Score {
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, doc: DocId, _term_freq: u32) -> Score {
self.foo
.iter()
.find(|(candidate_doc, _)| *candidate_doc == doc)
.map(|(_, score)| *score)
.unwrap_or(0.0)
}
}
#[test]

View File

@@ -59,6 +59,16 @@ impl Scorer for EmptyScorer {
fn score(&mut self) -> Score {
0.0
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
0.0
}
}
#[cfg(test)]

View File

@@ -3,7 +3,6 @@ use std::ops::RangeInclusive;
use columnar::Column;
use crate::docset::SeekDangerResult;
use crate::{DocId, DocSet, TERMINATED};
/// Helper to have a cursor over a vec of docids
@@ -185,37 +184,6 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
doc
}
/// `seek_danger` only needs to answer whether `target` itself matches, so it does a cheap
/// point lookup on the column instead of scanning forward to materialize the next match (the
/// expensive part of a regular `seek`).
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
// Covers `target == TERMINATED` and any target past the last doc: no match is possible.
if target >= self.column.num_docs() {
return SeekDangerResult::SeekLowerBound(TERMINATED);
}
if self.is_last_seek_distance_large(target) {
self.reset_fetch_range();
}
self.last_seek_pos_opt = Some(target);
let is_match = self
.column
.values_for_doc(target)
.any(|value| self.value_range.contains(&value));
if is_match {
// Leave the docset in a valid state positioned on `target`, so `doc()` returns it and a
// following `advance()` resumes the scan right after it.
self.loaded_docs.get_cleared_data().push(target);
self.next_fetch_start = target + 1;
SeekDangerResult::Found
} else {
// `target` is not in the docset. The next match is strictly greater than `target`, so
// `target + 1` is a valid lower bound. We may leave the docset in an invalid state.
SeekDangerResult::SeekLowerBound(target + 1)
}
}
fn size_hint(&self) -> u32 {
// TODO: Implement a better size hint
self.column.num_docs() / 10
@@ -241,148 +209,12 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
#[cfg(test)]
mod tests {
use std::ops::{Bound, RangeInclusive};
use std::ops::Bound;
use columnar::Column;
use super::RangeDocSet;
use crate::collector::Count;
use crate::directory::RamDirectory;
use crate::docset::{SeekDangerResult, TERMINATED};
use crate::query::RangeQuery;
use crate::{schema, DocSet, Index, IndexBuilder, TantivyDocument, Term};
/// Builds a single-segment index where doc `i` carries `values_for_doc(i)` in a u64 fast
/// field, then returns its column so we can drive a `RangeDocSet` directly.
fn build_u64_column(
num_docs: usize,
values_for_doc: impl Fn(usize) -> Vec<u64>,
) -> Column<u64> {
let mut schema_builder = schema::SchemaBuilder::new();
let value_field = schema_builder.add_u64_field("value", schema::FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests().unwrap();
for i in 0..num_docs {
let mut doc = TantivyDocument::new();
for v in values_for_doc(i) {
doc.add_u64(value_field, v);
}
writer.add_document(doc).unwrap();
}
writer.commit().unwrap();
}
let searcher = index.reader().unwrap().searcher();
assert_eq!(searcher.segment_readers().len(), 1);
searcher
.segment_reader(0)
.fast_fields()
.u64("value")
.unwrap()
}
fn range_docset(
value_range: RangeInclusive<u64>,
num_docs: usize,
values_for_doc: impl Fn(usize) -> Vec<u64>,
) -> RangeDocSet<u64> {
RangeDocSet::new(value_range, build_u64_column(num_docs, values_for_doc))
}
#[test]
fn seek_danger_found_leaves_valid_state() {
// Even docs match the range, odd docs do not.
let mut docset = range_docset(0..=0, 100, |i| vec![(i % 2) as u64]);
// Matching target: `Found`, and the docset is positioned exactly on it.
assert_eq!(docset.seek_danger(10), SeekDangerResult::Found);
assert_eq!(docset.doc(), 10);
// A following advance resumes the scan right after the found doc.
assert_eq!(docset.advance(), 12);
assert_eq!(docset.doc(), 12);
}
#[test]
fn seek_danger_miss_returns_lower_bound() {
let mut docset = range_docset(0..=0, 100, |i| vec![(i % 2) as u64]);
// Odd target does not match: lower bound is strictly greater than the target and never
// skips past the next real match (here doc 12, the first even doc after 11).
match docset.seek_danger(11) {
SeekDangerResult::SeekLowerBound(lower_bound) => {
assert!(lower_bound > 11);
assert!(lower_bound <= 12);
}
SeekDangerResult::Found => panic!("11 should not match"),
}
// After a miss we may be in an invalid state; another seek_danger recovers it.
assert_eq!(docset.seek_danger(12), SeekDangerResult::Found);
assert_eq!(docset.doc(), 12);
}
#[test]
fn seek_danger_terminated_and_out_of_bounds() {
let mut docset = range_docset(0..=0, 10, |i| vec![(i % 2) as u64]);
assert_eq!(
docset.seek_danger(TERMINATED),
SeekDangerResult::SeekLowerBound(TERMINATED)
);
// A target past the last doc has no possible match either.
assert_eq!(
docset.seek_danger(10),
SeekDangerResult::SeekLowerBound(TERMINATED)
);
}
#[test]
fn seek_danger_multivalued() {
// Doc `i` holds values [i, i+1]; the range {5} matches docs 4 and 5.
let mut docset = range_docset(5..=5, 20, |i| vec![i as u64, i as u64 + 1]);
assert_eq!(docset.seek_danger(4), SeekDangerResult::Found);
assert_eq!(docset.doc(), 4);
assert_eq!(docset.advance(), 5);
// No further match after doc 5.
assert_eq!(docset.advance(), TERMINATED);
}
#[test]
fn seek_danger_matches_seek() {
// Cross-check seek_danger against the true next match for every target, on a column with a
// few sparse matches.
let matches = [3u32, 7, 50, 51, 99];
let num_docs = 100;
let values_for_doc = |i: usize| {
vec![if matches.contains(&(i as u32)) {
1u64
} else {
0u64
}]
};
for target in 0..num_docs as u32 {
// The first matching doc greater than or equal to `target`, i.e. what `seek` returns.
let expected = matches
.iter()
.copied()
.find(|&m| m >= target)
.unwrap_or(TERMINATED);
let mut danger = range_docset(1..=1, num_docs, values_for_doc);
match danger.seek_danger(target) {
SeekDangerResult::Found => {
assert_eq!(expected, target, "target {target} reported Found");
assert_eq!(danger.doc(), target);
}
SeekDangerResult::SeekLowerBound(lower_bound) => {
assert_ne!(expected, target, "target {target} should have been Found");
assert!(lower_bound > target);
// The lower bound must never skip past the true next match.
assert!(lower_bound <= expected);
}
}
}
}
use crate::{schema, IndexBuilder, TantivyDocument, Term};
#[test]
fn range_query_fast_optional_field_minimum() {

View File

@@ -1,5 +1,40 @@
use crate::docset::{DocSet, TERMINATED};
use crate::query::Scorer;
use crate::Score;
use crate::{DocId, Score};
struct ScoreOnlyScorer {
doc: DocId,
score: Score,
}
impl DocSet for ScoreOnlyScorer {
fn advance(&mut self) -> DocId {
self.doc = TERMINATED;
TERMINATED
}
fn doc(&self) -> DocId {
self.doc
}
fn size_hint(&self) -> u32 {
1
}
}
impl Scorer for ScoreOnlyScorer {
fn score(&mut self) -> Score {
self.score
}
fn can_score_doc(&self) -> bool {
true
}
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
self.score
}
}
/// The `ScoreCombiner` trait defines how to compute
/// an overall score given a list of scores.
@@ -10,6 +45,17 @@ pub trait ScoreCombiner: Default + Clone + Send + Copy + 'static {
/// or not.
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer);
/// Aggregates the score combiner with an already computed score.
fn update_score(&mut self, doc: DocId, score: Score) {
let mut scorer = ScoreOnlyScorer { doc, score };
self.update(&mut scorer);
}
/// Returns true if this combiner needs scorer scores to compute its state.
fn requires_scoring() -> bool {
true
}
/// Clears the score combiner state back to its initial state.
fn clear(&mut self);
@@ -27,6 +73,12 @@ pub struct DoNothingCombiner;
impl ScoreCombiner for DoNothingCombiner {
fn update<TScorer: Scorer>(&mut self, _scorer: &mut TScorer) {}
fn update_score(&mut self, _doc: DocId, _score: Score) {}
fn requires_scoring() -> bool {
false
}
fn clear(&mut self) {}
#[inline]
@@ -42,10 +94,16 @@ pub struct SumCombiner {
}
impl ScoreCombiner for SumCombiner {
#[inline]
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
self.score += scorer.score();
}
#[inline]
fn update_score(&mut self, _doc: DocId, score: Score) {
self.score += score;
}
fn clear(&mut self) {
self.score = 0.0;
}
@@ -77,12 +135,19 @@ impl DisjunctionMaxCombiner {
}
impl ScoreCombiner for DisjunctionMaxCombiner {
#[inline]
fn update<TScorer: Scorer>(&mut self, scorer: &mut TScorer) {
let score = scorer.score();
self.max = Score::max(score, self.max);
self.sum += score;
}
#[inline]
fn update_score(&mut self, _doc: DocId, score: Score) {
self.max = Score::max(score, self.max);
self.sum += score;
}
fn clear(&mut self) {
self.max = 0.0;
self.sum = 0.0;

View File

@@ -2,8 +2,8 @@ use std::ops::DerefMut;
use downcast_rs::impl_downcast;
use crate::docset::DocSet;
use crate::Score;
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
use crate::{DocId, Score};
/// Scored set of documents matching a query within a specific segment.
///
@@ -13,6 +13,36 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
///
/// This method will perform a bit of computation and is not cached.
fn score(&mut self) -> Score;
/// Returns true if [`Scorer::score_doc`] can score buffered docs without
/// repositioning the scorer.
///
/// Scorers whose [`Scorer::score_doc`] needs term frequencies must also override
/// [`Scorer::fill_buffer_up_to_with_term_freqs`].
fn can_score_doc(&self) -> bool {
false
}
/// Returns the score for `doc` with its term frequency.
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
panic!(
"score_doc is not supported by this scorer. You need check can_score_doc() before \
calling this method."
)
}
/// Fills docs up to `horizon`.
///
/// The default implementation does not fill `term_freqs`. Scorers whose
/// [`Scorer::score_doc`] reads term frequencies must override this method.
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
_term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
DocSet::fill_buffer_up_to(self, horizon, docs)
}
}
impl_downcast!(Scorer);
@@ -22,4 +52,25 @@ impl Scorer for Box<dyn Scorer> {
fn score(&mut self) -> Score {
self.deref_mut().score()
}
#[inline]
fn can_score_doc(&self) -> bool {
self.as_ref().can_score_doc()
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
self.deref_mut().score_doc(doc, term_freq)
}
#[inline]
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.deref_mut()
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
}
}

View File

@@ -1,4 +1,4 @@
use crate::docset::DocSet;
use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN};
use crate::fieldnorm::FieldNormReader;
use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings};
use crate::query::bm25::Bm25Weight;
@@ -147,6 +147,27 @@ impl Scorer for TermScorer {
let term_freq = self.term_freq();
self.similarity_weight.score(fieldnorm_id, term_freq)
}
#[inline]
fn can_score_doc(&self) -> bool {
true
}
#[inline]
fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score {
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
self.similarity_weight.score(fieldnorm_id, term_freq)
}
fn fill_buffer_up_to_with_term_freqs(
&mut self,
horizon: DocId,
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
) -> usize {
self.postings
.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs)
}
}
#[cfg(test)]

View File

@@ -10,23 +10,7 @@ use crate::{DocId, Score};
// of upcoming document IDs (the "horizon").
const HORIZON_NUM_TINYBITSETS: usize = HORIZON as usize / 64;
const HORIZON: u32 = 64u32 * 64u32;
// `drain_filter` is not stable yet.
// This function is similar except that it does is not unstable, and
// it does not keep the original vector ordering.
//
// Elements are dropped and not yielded.
fn unordered_drain_filter<T, P>(v: &mut Vec<T>, mut predicate: P)
where P: FnMut(&mut T) -> bool {
let mut i = 0;
while i < v.len() {
if predicate(&mut v[i]) {
v.swap_remove(i);
} else {
i += 1;
}
}
}
const GROUPED_INSERT_MAX_BUCKET_SPAN: u32 = 2;
/// Creates a `DocSet` that iterate through the union of two or more `DocSet`s.
pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
@@ -53,31 +37,213 @@ pub struct BufferedUnionScorer<TScorer, TScoreCombiner = DoNothingCombiner> {
score: Score,
/// Number of documents in the segment.
num_docs: u32,
/// Scratch buffer for block-based refill.
refill_docs: [DocId; COLLECT_BLOCK_BUFFER_LEN],
/// Scratch buffer for term frequencies matching `refill_docs`.
refill_term_freqs: [u32; COLLECT_BLOCK_BUFFER_LEN],
/// Whether all children support scoring buffered docs after advancing.
use_score_doc_refill: bool,
}
#[inline]
fn union_bucket(
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
bucket_pos: u32,
tinyset: TinySet,
) {
debug_assert!((bucket_pos as usize) < HORIZON_NUM_TINYBITSETS);
// `bucket` comes from a doc delta below `HORIZON`; there are exactly
// `HORIZON / 64` buckets in the refill window.
bitsets[bucket_pos as usize] = bitsets[bucket_pos as usize].union(tinyset);
}
#[inline]
fn insert_delta(bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], delta: DocId) {
debug_assert!(delta < HORIZON);
// `delta < HORIZON`, so `delta / 64` is in the bitset array. The bit
// offset is reduced modulo 64 before being inserted in the TinySet.
bitsets[delta as usize / 64].insert_mut(delta % 64u32);
}
fn insert_and_score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
min_doc: DocId,
) {
debug_assert!(docs.windows(2).all(|pair| pair[0] < pair[1]));
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc < HORIZON);
let first_delta = docs[0] - min_doc;
let last_delta = docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc;
let first_bucket = first_delta / 64;
let last_bucket = last_delta / 64;
// Common for very dense scorers: 64 distinct doc ids in one 64-doc bucket
// means all bits in that bucket are present.
if first_bucket == last_bucket {
union_bucket(bitsets, first_bucket, TinySet::full());
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
return;
}
// 64 sorted distinct integers spanning exactly 64 values are consecutive.
// If they cross a TinySet boundary, this is just the suffix of the first
// bucket plus the prefix of the second bucket.
if last_delta - first_delta == COLLECT_BLOCK_BUFFER_LEN as u32 - 1 {
union_bucket(
bitsets,
first_bucket,
TinySet::range_greater_or_equal(first_delta % 64u32),
);
union_bucket(
bitsets,
last_bucket,
TinySet::range_lower((last_delta + 1) % 64u32),
);
score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc);
return;
}
// Grouping wins only for very dense buffers that hit the same TinySet many
// times. Once the 64 docs are spread farther, a straight pass is cheaper.
if last_bucket - first_bucket <= GROUPED_INSERT_MAX_BUCKET_SPAN {
let mut bucket = first_bucket;
let mut tinyset = TinySet::empty();
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
let delta = doc - min_doc;
let delta_bucket = delta / 64;
if delta_bucket != bucket {
union_bucket(bitsets, bucket, tinyset);
bucket = delta_bucket;
tinyset = TinySet::empty();
}
tinyset.insert_mut(delta % 64u32);
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, delta, doc, score);
}
union_bucket(bitsets, bucket, tinyset);
} else {
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
let delta = doc - min_doc;
insert_delta(bitsets, delta);
// TODO: score_doc access the field_norm reader for each _term_, instead of once per
// doc. We could optimize this by caching the field norm for the doc, and
// reusing it for all terms in the doc.
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, delta, doc, score);
}
}
}
#[inline]
fn update_score_combiner<TScoreCombiner: ScoreCombiner>(
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
delta: DocId,
doc: DocId,
score: Score,
) {
debug_assert!(delta < HORIZON);
// Full and partial refill only buffer docs below `horizon`, so their
// deltas are always in the score-combiner window.
score_combiner[delta as usize].update_score(doc, score);
}
fn score_full_buffer<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
min_doc: DocId,
) {
for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) {
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, doc - min_doc, doc, score);
}
}
fn refill_scorer_with_score_docs<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
min_doc: DocId,
horizon: DocId,
) {
loop {
let len = scorer.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs);
if len == COLLECT_BLOCK_BUFFER_LEN {
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] != TERMINATED);
debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] < horizon);
insert_and_score_full_buffer(
scorer,
docs,
term_freqs,
bitsets,
score_combiner,
min_doc,
);
} else {
for (&doc, &term_freq) in docs[..len].iter().zip(term_freqs[..len].iter()) {
let delta = doc - min_doc;
insert_delta(bitsets, delta);
let score = scorer.score_doc(doc, term_freq);
update_score_combiner(score_combiner, delta, doc, score);
}
break;
}
}
}
fn refill_scorer_from_current_doc<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorer: &mut TScorer,
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
min_doc: DocId,
horizon: DocId,
) {
loop {
let doc = scorer.doc();
if doc >= horizon {
break;
}
let delta = doc - min_doc;
insert_delta(bitsets, delta);
debug_assert!(delta < HORIZON);
score_combiner[delta as usize].update(scorer);
scorer.advance();
}
}
fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
scorers: &mut Vec<TScorer>,
bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS],
score_combiner: &mut [TScoreCombiner; HORIZON as usize],
docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN],
min_doc: DocId,
use_score_doc_refill: bool,
) {
unordered_drain_filter(scorers, |scorer| {
let horizon = min_doc + HORIZON;
loop {
let doc = scorer.doc();
if doc >= horizon {
return false;
}
// add this document
let delta = doc - min_doc;
bitsets[(delta / 64) as usize].insert_mut(delta % 64u32);
score_combiner[delta as usize].update(scorer);
if scorer.advance() == TERMINATED {
// remove the docset, it has been entirely consumed.
return true;
}
let horizon = min_doc + HORIZON;
for scorer in scorers.iter_mut() {
if use_score_doc_refill {
refill_scorer_with_score_docs(
scorer,
bitsets,
score_combiner,
docs,
term_freqs,
min_doc,
horizon,
);
} else {
refill_scorer_from_current_doc(scorer, bitsets, score_combiner, min_doc, horizon);
}
});
}
scorers.retain(|scorer| scorer.doc() != TERMINATED);
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
@@ -87,6 +253,8 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
score_combiner_fn: impl FnOnce() -> TScoreCombiner,
num_docs: u32,
) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
let use_score_doc_refill =
TScoreCombiner::requires_scoring() && docsets.iter().all(Scorer::can_score_doc);
let non_empty_docsets: Vec<TScorer> = docsets
.into_iter()
.filter(|docset| docset.doc() != TERMINATED)
@@ -100,6 +268,9 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
doc: 0,
score: 0.0,
num_docs,
refill_docs: [TERMINATED; COLLECT_BLOCK_BUFFER_LEN],
refill_term_freqs: [1u32; COLLECT_BLOCK_BUFFER_LEN],
use_score_doc_refill,
};
if union.refill() {
union.advance();
@@ -120,7 +291,10 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
&mut self.docsets,
&mut self.bitsets,
&mut self.scores,
&mut self.refill_docs,
&mut self.refill_term_freqs,
min_doc,
self.use_score_doc_refill,
);
true
} else {
@@ -248,12 +422,12 @@ where
// The target is outside of the buffered horizon.
// advance all docsets to a doc >= to the target.
unordered_drain_filter(&mut self.docsets, |docset| {
for docset in &mut self.docsets {
if docset.doc() < target {
docset.seek(target);
}
docset.doc() == TERMINATED
});
}
self.docsets.retain(|docset| docset.doc() != TERMINATED);
// at this point all of the docsets
// are positioned on a doc >= to the target.

View File

@@ -10,6 +10,8 @@ pub use simple_union::SimpleUnion;
mod tests {
use std::collections::BTreeSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use common::BitSet;
@@ -18,8 +20,8 @@ mod tests {
use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::union::bitset_union::BitSetPostingUnion;
use crate::query::{BitSetDocSet, ConstScorer, VecDocSet};
use crate::{tests, DocId};
use crate::query::{BitSetDocSet, ConstScorer, Scorer, VecDocSet};
use crate::{tests, DocId, Score};
fn vec_doc_set_from_docs_list(
docs_list: &[Vec<DocId>],
@@ -66,6 +68,61 @@ mod tests {
}
BitSetDocSet::from(doc_bitset)
}
struct CountingScorer {
docset: VecDocSet,
score_calls: Arc<AtomicUsize>,
score_doc_calls: Arc<AtomicUsize>,
}
impl CountingScorer {
fn new(
doc_ids: Vec<DocId>,
score_calls: Arc<AtomicUsize>,
score_doc_calls: Arc<AtomicUsize>,
) -> Self {
CountingScorer {
docset: VecDocSet::from(doc_ids),
score_calls,
score_doc_calls,
}
}
}
impl DocSet for CountingScorer {
fn advance(&mut self) -> DocId {
self.docset.advance()
}
fn seek(&mut self, target: DocId) -> DocId {
self.docset.seek(target)
}
fn doc(&self) -> DocId {
self.docset.doc()
}
fn size_hint(&self) -> u32 {
self.docset.size_hint()
}
}
impl Scorer for CountingScorer {
fn score(&mut self) -> Score {
self.score_calls.fetch_add(1, Ordering::SeqCst);
1.0
}
fn can_score_doc(&self) -> bool {
true
}
fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score {
self.score_doc_calls.fetch_add(1, Ordering::SeqCst);
1.0
}
}
fn aux_test_union(docs_list: &[Vec<DocId>]) {
for constructor in [
posting_list_union_from_docs_list,
@@ -168,6 +225,22 @@ mod tests {
]);
}
#[test]
fn test_do_nothing_combiner_does_not_score_buffered_docs() {
let score_calls = Arc::new(AtomicUsize::new(0));
let score_doc_calls = Arc::new(AtomicUsize::new(0));
let scorers = vec![
CountingScorer::new(vec![1, 3, 5], score_calls.clone(), score_doc_calls.clone()),
CountingScorer::new(vec![2, 3, 6], score_calls.clone(), score_doc_calls.clone()),
];
let mut union = BufferedUnionScorer::build(scorers, DoNothingCombiner::default, 10);
assert_eq!(union.count_including_deleted(), 5);
assert_eq!(score_calls.load(Ordering::SeqCst), 0);
assert_eq!(score_doc_calls.load(Ordering::SeqCst), 0);
}
fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
for constructor in [
posting_list_union_from_docs_list,